hespas.mlir_parser.mlir_splitter
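Functions
- create_digraph
- create_digraph_with_deps
- individual_split
- linear_split: Linear sequential split algorithm for MLIR operations.
- load_dependency_graph: Load a dependency graph with MLIRModule objects from a pickle file.
- main
- merge_split
- parse_and_split_mlir: Parses the StableHLO MLIR file and uses the provided split_fn to split the module.
- register_split_fn
- save_dependency_graph: Save a dependency graph with MLIRModule objects to a pickle file.
- split_by_opregion
- validate_dependency_graph: Validate the integrity of a dependency graph with MLIRModule objects.
Classes
- SeparatorPolicy
Exceptions
- LoadDependencyException: Custom exception for dependency graph loading failures.
- MLIRSplittingError
- DependencyGraphValidationError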
- exception hespas.mlir_parser.mlir_splitter.LoadDependencyException
Bases: Exception
Custom exception for dependency graph loading failures.
- exception hespas.mlir_parser.mlir_splitter.MLIRSplittingError
Bases: Exception
- exception hespas.mlir_parser.mlir_splitter.DependencyGraphValidationError
Bases: Exception
- hespas.mlir_parser.mlir_splitter.register_split_fn(func)
- hespas.mlir_parser.mlir_splitter.split_by_opregion(module, operation_regions)
- class hespas.mlir_parser.mlir_splitter.SeparatorPolicy(*values)
Bases: Enum
- ISOLATED = 1
- END_BLOCK = 2
- START_BLOCK = 3
- classmethod from_string(src)
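The entry does not document what from_string accepts. A minimal sketch, assuming it looks up a member by name case-insensitively (so "isolated" and "END_BLOCK" both resolve); SeparatorPolicySketch and its error message are hypothetical and only mirror the three documented members.

```python
from enum import Enum


class SeparatorPolicySketch(Enum):
    """Hypothetical stand-in mirroring the documented SeparatorPolicy members."""

    ISOLATED = 1
    END_BLOCK = 2
    START_BLOCK = 3

    @classmethod
    def from_string(cls, src):
        # Assumed behavior: trim, upper-case, and accept "-" in place of "_".
        key = src.strip().upper().replace("-", "_")
        try:
            return cls[key]  # Enum lookup by member name
        except KeyError:
            valid = ", ".join(member.name for member in cls)
            raise ValueError(
                f"unknown separator policy {src!r}; expected one of: {valid}"
            ) from None
```

Raising ValueError instead of the bare KeyError from cls[...] gives a command-line caller a readable message that lists the valid choices.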
- hespas.mlir_parser.mlir_splitter.linear_split(operations, block_lim=1024, separator=<function is_communication_op>, separator_policy=SeparatorPolicy.ISOLATED, separator_node_type=NodeType.COMM_COLL_NODE)
Linear sequential split algorithm for MLIR operations.
The algorithm scans operations in order and groups computational ops into blocks of at most block_lim ops. With the default separator (is_communication_op) and separator_policy=SeparatorPolicy.ISOLATED, every communication op is placed in its own singleton block. The result is a sequence of COMP and COMM blocks in execution order, plus a chain dependency graph that links each block to the next.
- Args:
operations: Sequence of MLIR operations to split
block_lim: Maximum number of operations per COMP block (default: 1024)
- Returns:
Tuple[List[(ops, block_type)], nx.DiGraph]: Sequential operation blocks and linear dependency graph
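The default behavior described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: operations are plain strings, is_separator stands in for is_communication_op, the "COMP"/"COMM" labels stand in for the real block types, and the chain graph is returned as an edge list instead of an nx.DiGraph. Only the ISOLATED policy is modeled; END_BLOCK and START_BLOCK are not.

```python
from typing import Callable, List, Sequence, Tuple

COMP, COMM = "COMP", "COMM"  # hypothetical block-type labels


def linear_split_sketch(
    operations: Sequence[str],
    is_separator: Callable[[str], bool],
    block_lim: int = 1024,
) -> Tuple[List[Tuple[List[str], str]], List[Tuple[int, int]]]:
    """Group ops into COMP blocks of at most block_lim ops; isolate separators."""
    if block_lim < 1:
        raise ValueError("block_lim must be positive")
    blocks: List[Tuple[List[str], str]] = []
    current: List[str] = []
    for op in operations:
        if is_separator(op):
            # Flush pending compute ops, then emit the separator on its own.
            if current:
                blocks.append((current, COMP))
                current = []
            blocks.append(([op], COMM))
        else:
            current.append(op)
            if len(current) == block_lim:  # block is full: close it
                blocks.append((current, COMP))
                current = []
    if current:  # trailing compute ops
        blocks.append((current, COMP))
    # Chain dependency graph: block i -> block i + 1, as an edge list.
    edges = [(i, i + 1) for i in range(len(blocks) - 1)]
    return blocks, edges
```

A compute run is flushed both when it reaches block_lim and when a separator arrives, so a COMP block can be shorter than block_lim, for example a single op just before a communication op.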
- hespas.mlir_parser.mlir_splitter.merge_split(elements, module_dep_graph)
- hespas.mlir_parser.mlir_splitter.individual_split(operations, merge=False)
- hespas.mlir_parser.mlir_splitter.create_digraph_with_deps(operation_blocks, operation_dependencies)
- hespas.mlir_parser.mlir_splitter.create_digraph(operation_blocks)
- hespas.mlir_parser.mlir_splitter.parse_and_split_mlir(file_path, output_path, split_fn=<function linear_split>, num_threads=-1, *split_args, **split_kw_args)
Parses the StableHLO MLIR file and uses the provided split_fn to split the module. Returns a list of modules and metadata for each split.
- Parameters:
file_path – Path to the input MLIR file.
output_path – Path where output is written (not used by this function itself; assumed by downstream steps).
split_fn – A function that takes a parsed module and returns a list of (ops, block_type) tuples.
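Because split_args and split_kw_args come after split_fn and num_threads in the signature, Python's argument binding decides how extra options reach the split function. The sketch below assumes parse_and_split_mlir forwards them to split_fn, as the parameter names suggest; parse_and_split_stub and record_split are hypothetical stand-ins with the same parameter layout and do no MLIR parsing.

```python
def record_split(module, block_lim=1024):
    """Hypothetical split function that just reports what it received."""
    return {"module": module, "block_lim": block_lim}


def parse_and_split_stub(file_path, output_path, split_fn=record_split,
                         num_threads=-1, *split_args, **split_kw_args):
    """Same parameter layout as parse_and_split_mlir; no real parsing."""
    module = f"<module parsed from {file_path}>"  # stand-in for the parsed module
    # Extra positional and keyword arguments are forwarded unchanged.
    return split_fn(module, *split_args, **split_kw_args)


# Keyword options can be passed directly and land in split_kw_args.
by_keyword = parse_and_split_stub("input.mlir", "output", block_lim=256)

# Positional extras only reach split_args once split_fn and num_threads are given.
by_position = parse_and_split_stub("input.mlir", "output", record_split, -1, 64)
```

A call such as parse_and_split_stub("input.mlir", "output", 64) binds 64 to split_fn rather than to a split argument, so keyword arguments are the safer way to pass options such as block_lim.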
- hespas.mlir_parser.mlir_splitter.save_dependency_graph(dep_graph: DiGraph, output_dir: str | Path) → str
Save a dependency graph with MLIRModule objects to a pickle file.
This function serializes a NetworkX DiGraph containing MLIRModule objects in the node attributes to a pickle file for later restoration.
- Args:
dep_graph (nx.DiGraph): Dependency graph with MLIRModule objects in node['mlir_module']
output_dir (Union[str, Path]): Directory to save the pickle file
- Returns:
str: Path to the saved pickle file
- Example:
>>> dep_graph = parse_and_split_mlir("input.mlir", "output")
>>> pickle_path = save_dependency_graph(dep_graph, "output")
>>> print(f"Graph saved to {pickle_path}")
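A minimal sketch of the save step using only the standard library. The file name dependency_graph.pkl is an assumption taken from the load_dependency_graph example; save_graph_sketch is hypothetical, and a plain dict stands in for the graph (any picklable object, including an nx.DiGraph, is written the same way).

```python
import pickle
import tempfile
from pathlib import Path
from typing import Union

# Assumed file name, matching the path used in the load example.
PICKLE_NAME = "dependency_graph.pkl"


def save_graph_sketch(dep_graph: object, output_dir: Union[str, Path]) -> str:
    """Pickle dep_graph into output_dir and return the file path as a string."""
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # create output_dir if missing
    pickle_path = out_dir / PICKLE_NAME
    with pickle_path.open("wb") as fh:
        pickle.dump(dep_graph, fh, protocol=pickle.HIGHEST_PROTOCOL)
    return str(pickle_path)


# Round-trip check in a throwaway directory.
graph = {"nodes": {0: {"mlir_module": "stub"}}, "edges": []}
with tempfile.TemporaryDirectory() as tmp:
    saved_path = save_graph_sketch(graph, Path(tmp) / "output")
    with open(saved_path, "rb") as fh:
        restored = pickle.load(fh)
```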
- hespas.mlir_parser.mlir_splitter.load_dependency_graph(pickle_path: str | Path) → DiGraph | None
Load a dependency graph with MLIRModule objects from a pickle file.
This function deserializes a NetworkX DiGraph containing MLIRModule objects from a pickle file and restores the class state of each MLIRModule object.
- Args:
pickle_path (Union[str, Path]): Path to the pickle file
- Returns:
Optional[nx.DiGraph]: Loaded dependency graph, or None if loading failed
- Raises:
FileNotFoundError: If the pickle file does not exist
pickle.PickleError: If the pickle file is corrupted or incompatible
- Example:
>>> dep_graph = load_dependency_graph("output/dependency_graph.pkl")
>>> if dep_graph:
...     print(f"Loaded graph with {len(dep_graph.nodes)} nodes")
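The entry says MLIRModule state is restored on load but not how. One common Python pattern, sketched here as an assumption, is for the class to drop derived or unpicklable fields in __getstate__ and rebuild them in __setstate__. ModuleStub and load_graph_sketch are hypothetical; the sketch raises FileNotFoundError for a missing file and returns None when the payload is not the expected dict, which is one possible reading of the Returns and Raises sections of this entry.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Union


class ModuleStub:
    """Hypothetical stand-in for MLIRModule: source text plus a derived cache."""

    def __init__(self, text: str):
        self.text = text
        self._parsed = self._parse()

    def _parse(self):
        # Placeholder for an expensive or unpicklable parsed object.
        return self.text.split()

    def __getstate__(self):
        # Serialize everything except the derived cache.
        state = self.__dict__.copy()
        state.pop("_parsed", None)
        return state

    def __setstate__(self, state):
        # Restore saved fields, then rebuild the cache that was dropped.
        self.__dict__.update(state)
        self._parsed = self._parse()


def load_graph_sketch(pickle_path: Union[str, Path]) -> Optional[dict]:
    """Unpickle a dict-based graph; return None if the payload is not a dict."""
    path = Path(pickle_path)
    if not path.exists():
        raise FileNotFoundError(path)
    with path.open("rb") as fh:
        graph = pickle.load(fh)  # may raise pickle.UnpicklingError
    return graph if isinstance(graph, dict) else None


with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "dependency_graph.pkl"
    with good.open("wb") as fh:
        pickle.dump({"nodes": {0: {"mlir_module": ModuleStub("func.func @main")}}}, fh)
    loaded = load_graph_sketch(good)
    other = Path(tmp) / "other.pkl"
    with other.open("wb") as fh:
        pickle.dump([1, 2], fh)  # wrong payload type
    loaded_other = load_graph_sketch(other)
```

Only unpickle files from trusted sources: pickle.load can execute arbitrary code embedded in the file.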
- hespas.mlir_parser.mlir_splitter.validate_dependency_graph(dep_graph: DiGraph) → bool
Validate the integrity of a dependency graph with MLIRModule objects.
This function checks that the dependency graph is properly formed and all MLIRModule objects are valid.
- Args:
dep_graph (nx.DiGraph): Dependency graph to validate
- Returns:
bool: True if valid, False otherwise
- Example:
>>> dep_graph = load_dependency_graph("output/dependency_graph.pkl")
>>> if validate_dependency_graph(dep_graph):
...     print("Graph is valid")
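The entry does not list which checks are performed. A minimal sketch of checks one might run, all of them assumptions: every edge references a known node, every node carries an mlir_module attribute (the key named in save_dependency_graph), and the graph is acyclic. validate_graph_sketch is hypothetical and takes plain node and edge collections; the acyclicity test uses the standard-library graphlib module (Python 3.9+).

```python
from graphlib import CycleError, TopologicalSorter
from typing import Dict, Hashable, Iterable, Tuple


def validate_graph_sketch(
    nodes: Dict[Hashable, dict], edges: Iterable[Tuple[Hashable, Hashable]]
) -> bool:
    """Return True if the node/edge data passes the hypothetical checks."""
    edges = list(edges)
    # Every edge endpoint must be a known node.
    if any(src not in nodes or dst not in nodes for src, dst in edges):
        return False
    # Every node must carry a non-None 'mlir_module' attribute.
    if any(attrs.get("mlir_module") is None for attrs in nodes.values()):
        return False
    # The dependency graph must be acyclic: dst depends on src.
    sorter = TopologicalSorter({node: set() for node in nodes})
    for src, dst in edges:
        sorter.add(dst, src)
    try:
        sorter.prepare()  # raises CycleError if any cycle exists
    except CycleError:
        return False
    return True
```

On a real nx.DiGraph, the acyclicity check could instead use nx.is_directed_acyclic_graph(dep_graph).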
- hespas.mlir_parser.mlir_splitter.main()