hespas.mlir_parser.mlir_splitter

Functions

create_digraph(operation_blocks)

create_digraph_with_deps(operation_blocks, ...)

individual_split(operations[, merge])

linear_split(operations[, block_lim, ...])

Linear sequential split algorithm for MLIR operations.

load_dependency_graph(pickle_path)

Load a dependency graph with MLIRModule objects from a pickle file.

main()

merge_split(elements, module_dep_graph)

parse_and_split_mlir(file_path, output_path)

Parses the StableHLO MLIR file and uses the provided split_fn to split the module.

register_split_fn(func)

save_dependency_graph(dep_graph, output_dir)

Save a dependency graph with MLIRModule objects to a pickle file.

split_by_opregion(module, operation_regions)

validate_dependency_graph(dep_graph)

Validate the integrity of a dependency graph with MLIRModule objects.

Classes

SeparatorPolicy(*values)

Exceptions

DependencyGraphValidationError

LoadDependencyException

Custom exception for dependency graph loading failures.

MLIRSplittingError

exception hespas.mlir_parser.mlir_splitter.LoadDependencyException

Bases: Exception

Custom exception for dependency graph loading failures.

exception hespas.mlir_parser.mlir_splitter.MLIRSplittingError

Bases: Exception

exception hespas.mlir_parser.mlir_splitter.DependencyGraphValidationError

Bases: Exception

hespas.mlir_parser.mlir_splitter.register_split_fn(func)
hespas.mlir_parser.mlir_splitter.split_by_opregion(module, operation_regions)
class hespas.mlir_parser.mlir_splitter.SeparatorPolicy(*values)

Bases: Enum

ISOLATED = 1
END_BLOCK = 2
START_BLOCK = 3
classmethod from_string(src)
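
The source does not document how from_string interprets src. As a minimal sketch only, the enum below mirrors the three documented members and assumes that from_string looks the member up by name, case-insensitively. SeparatorPolicySketch and the ValueError on unknown names are hypothetical and not part of hespas.

```python
from enum import Enum


class SeparatorPolicySketch(Enum):
    """Hypothetical stand-in that mirrors the documented member values."""

    ISOLATED = 1
    END_BLOCK = 2
    START_BLOCK = 3

    @classmethod
    def from_string(cls, src: str) -> "SeparatorPolicySketch":
        # Assumption: names match case-insensitively, ignoring outer whitespace.
        try:
            return cls[src.strip().upper()]
        except KeyError:
            raise ValueError(f"unknown separator policy: {src!r}") from None


policy = SeparatorPolicySketch.from_string("end_block")
```

A string-to-enum helper of this kind is typically used to map a command-line or config value onto the policy argument.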
hespas.mlir_parser.mlir_splitter.linear_split(operations, block_lim=1024, separator=<function is_communication_op>, separator_policy=SeparatorPolicy.ISOLATED, separator_node_type=NodeType.COMM_COLL_NODE)

Linear sequential split algorithm for MLIR operations.

The algorithm scans operations in order and groups computational ops into blocks of at most block_lim ops. Each communication op is isolated into its own singleton block. The result is a sequence of COMP and COMM blocks in execution order, plus a chain dependency graph in which each block depends on the one before it.

Args:

operations: Sequence of MLIR operations to split

block_lim: Maximum number of operations per COMP block (default: 1024)

Returns:

Tuple[List[(ops, block_type)], nx.DiGraph]: Sequential operation blocks and linear dependency graph
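
The scan described above can be sketched in a few lines. This is an illustration under stated assumptions, not the hespas implementation: operations are plain strings, the is_separator callback stands in for is_communication_op, block types are the strings "COMP" and "COMM", and the chain dependency graph is returned as a list of (i, i + 1) edges instead of an nx.DiGraph. The name linear_split_sketch is hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Block = Tuple[List[str], str]


def linear_split_sketch(
    operations: Sequence[str],
    is_separator: Callable[[str], bool],
    block_lim: int = 1024,
) -> Tuple[List[Block], List[Tuple[int, int]]]:
    """Group ops into COMP blocks of at most block_lim; isolate separators."""
    blocks: List[Block] = []
    current: List[str] = []

    def flush() -> None:
        # Close the COMP block under construction, if it holds any ops.
        if current:
            blocks.append((list(current), "COMP"))
            current.clear()

    for op in operations:
        if is_separator(op):
            # A separator ends the running COMP block and forms its own block.
            flush()
            blocks.append(([op], "COMM"))
        else:
            current.append(op)
            if len(current) == block_lim:
                flush()
    flush()

    # Chain dependency: block i must complete before block i + 1.
    edges = [(i, i + 1) for i in range(len(blocks) - 1)]
    return blocks, edges


ops = ["add", "mul", "all_reduce", "dot", "tanh", "exp"]
blocks, edges = linear_split_sketch(ops, lambda op: op == "all_reduce", block_lim=2)
```

With block_lim=2, the six ops above yield four blocks (COMP, COMM, COMP, COMP) joined by three chain edges.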

hespas.mlir_parser.mlir_splitter.merge_split(elements, module_dep_graph)
hespas.mlir_parser.mlir_splitter.individual_split(operations, merge=False)
hespas.mlir_parser.mlir_splitter.create_digraph_with_deps(operation_blocks, operation_dependencies)
hespas.mlir_parser.mlir_splitter.create_digraph(operation_blocks)
hespas.mlir_parser.mlir_splitter.parse_and_split_mlir(file_path, output_path, split_fn=<function linear_split>, num_threads=-1, *split_args, **split_kw_args)

Parses the StableHLO MLIR file and uses the provided split_fn to split the module. Returns a list of modules and metadata for each split.

Parameters:
  • file_path – Path to the input MLIR file.

  • output_path – Path to write the output (not used by this function, but expected by downstream steps).

  • split_fn – A function that takes a parsed module and returns a list of (ops, block_type) tuples.

hespas.mlir_parser.mlir_splitter.save_dependency_graph(dep_graph: DiGraph, output_dir: str | Path) → str

Save a dependency graph with MLIRModule objects to a pickle file.

This function serializes a NetworkX DiGraph containing MLIRModule objects in the node attributes to a pickle file for later restoration.

Args:

dep_graph (nx.DiGraph): Dependency graph with MLIRModule objects in node['mlir_module']

output_dir (Union[str, Path]): Directory to save the pickle file

Returns:

str: Path to the saved pickle file

Example:
>>> dep_graph = parse_and_split_mlir("input.mlir", "output")
>>> pickle_path = save_dependency_graph(dep_graph, "output")
>>> print(f"Graph saved to {pickle_path}")
hespas.mlir_parser.mlir_splitter.load_dependency_graph(pickle_path: str | Path) → DiGraph | None

Load a dependency graph with MLIRModule objects from a pickle file.

This function deserializes a NetworkX DiGraph containing MLIRModule objects from a pickle file. It handles MLIRModule class state restoration properly.

Args:

pickle_path (Union[str, Path]): Path to the pickle file

Returns:

Optional[nx.DiGraph]: Loaded dependency graph, or None if loading failed

Raises:

FileNotFoundError: If the pickle file doesn’t exist

pickle.PickleError: If the pickle file is corrupted or incompatible

Example:
>>> dep_graph = load_dependency_graph("output/dependency_graph.pkl")
>>> if dep_graph:
...     print(f"Loaded graph with {len(dep_graph.nodes)} nodes")
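
The save/load pair amounts to a pickle round trip. The sketch below shows that pattern using only the standard library. It makes several assumptions: a plain dict stands in for the nx.DiGraph, strings stand in for MLIRModule objects, and the file name dependency_graph.pkl is taken from the example path above. save_graph_sketch and load_graph_sketch are hypothetical helpers, not the hespas functions, and they skip any MLIRModule state restoration.

```python
import pickle
import tempfile
from pathlib import Path


def save_graph_sketch(graph, output_dir):
    """Pickle the graph into output_dir and return the file path as a str."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    path = out / "dependency_graph.pkl"  # assumed file name
    with path.open("wb") as fh:
        pickle.dump(graph, fh)
    return str(path)


def load_graph_sketch(pickle_path):
    """Unpickle a graph; raises FileNotFoundError if the path is missing."""
    with Path(pickle_path).open("rb") as fh:
        return pickle.load(fh)


# Stand-in graph: node attributes keyed by node id, plus a list of edges.
graph = {
    "nodes": {0: {"mlir_module": "module @block0"}, 1: {"mlir_module": "module @block1"}},
    "edges": [(0, 1)],
}

with tempfile.TemporaryDirectory() as tmp:
    saved_path = save_graph_sketch(graph, tmp)
    restored = load_graph_sketch(saved_path)
```

As with any pickle-based format, the file should only be loaded from a trusted source, since unpickling can execute arbitrary code.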
hespas.mlir_parser.mlir_splitter.validate_dependency_graph(dep_graph: DiGraph) → bool

Validate the integrity of a dependency graph with MLIRModule objects.

This function checks that the dependency graph is properly formed and all MLIRModule objects are valid.

Args:

dep_graph (nx.DiGraph): Dependency graph to validate

Returns:

bool: True if valid, False otherwise

Example:
>>> dep_graph = load_dependency_graph("output/dependency_graph.pkl")
>>> if validate_dependency_graph(dep_graph):
...     print("Graph is valid")
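
The source does not list which checks validate_dependency_graph performs. As a hedged illustration of what "properly formed" could mean, the sketch below assumes three checks: every edge endpoint is a known node, every node carries a non-empty 'mlir_module' attribute, and the dependency relation is acyclic. The function validate_graph_sketch and its dict-and-edge-list input are hypothetical stand-ins for the nx.DiGraph API.

```python
from graphlib import CycleError, TopologicalSorter


def validate_graph_sketch(nodes, edges):
    """Return True if the stand-in graph passes the three assumed checks."""
    # Check 1: every edge endpoint must be a known node.
    if any(u not in nodes or v not in nodes for u, v in edges):
        return False
    # Check 2: every node must carry a non-empty 'mlir_module' attribute.
    if any(not attrs.get("mlir_module") for attrs in nodes.values()):
        return False
    # Check 3: the dependency relation must be acyclic.
    predecessors = {node: set() for node in nodes}
    for u, v in edges:
        predecessors[v].add(u)
    try:
        TopologicalSorter(predecessors).prepare()
    except CycleError:
        return False
    return True


nodes = {0: {"mlir_module": "module @block0"}, 1: {"mlir_module": "module @block1"}}
is_valid = validate_graph_sketch(nodes, [(0, 1)])
has_cycle_ok = validate_graph_sketch(nodes, [(0, 1), (1, 0)])
```

Here is_valid is True for the two-node chain, while has_cycle_ok is False because the second edge list forms a cycle.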
hespas.mlir_parser.mlir_splitter.main()