hespas.mlir_parser

MLIR parser and splitter for HESPAS. Allows top-level analysis of MLIR exports.

hespas.mlir_parser.parse_and_split_mlir(file_path, output_path, split_fn=<function linear_split>, num_threads=-1, *split_args, **split_kw_args)

Parses the StableHLO MLIR file and uses the provided split_fn to split the module. Returns a list of modules and metadata for each split.

Parameters:
  • file_path – Path to the input MLIR file.

  • output_path – Path for output files. Not used by this function itself; reserved for downstream steps.

  • split_fn – A function that takes a parsed module and returns a list of (ops, block_type) tuples.
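The split_fn contract can be illustrated with a toy splitter. This is a minimal sketch under stated assumptions, not the behavior of linear_split: the hypothetical toy_linear_split receives plain op-name strings instead of a parsed module, and tags blocks with the strings "COMM" and "COMP" instead of NodeType values. It gives each collective its own COMM block and groups the runs of other ops between collectives into COMP blocks, preserving program order.

```python
from typing import Iterable, List, Tuple

# Hypothetical stand-ins: a real split_fn receives a parsed MLIR module
# and tags blocks with NodeType values; plain strings are used here.
COLLECTIVES = {
    "stablehlo.all_reduce",
    "stablehlo.all_gather",
    "stablehlo.reduce_scatter",
    "stablehlo.all_to_all",
    "stablehlo.collective_permute",
}


def toy_linear_split(op_names: Iterable[str]) -> List[Tuple[List[str], str]]:
    """Group ops into (ops, block_type) tuples in program order.

    Each collective op becomes its own "COMM" block; consecutive
    non-collective ops are gathered into one "COMP" block.
    """
    blocks: List[Tuple[List[str], str]] = []
    current: List[str] = []
    for name in op_names:
        if name in COLLECTIVES:
            if current:  # close the pending computation run
                blocks.append((current, "COMP"))
                current = []
            blocks.append(([name], "COMM"))
        else:
            current.append(name)
    if current:  # flush trailing computation ops
        blocks.append((current, "COMP"))
    return blocks
```

For example, ["stablehlo.dot_general", "stablehlo.add", "stablehlo.all_reduce", "stablehlo.multiply"] yields three blocks: COMP, COMM, COMP.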

class hespas.mlir_parser.MLIRParser(*, mlir_string=None, mlir_path=None, mlir_module=None, input_sensitive=False)

Bases: object

A class for parsing MLIR (Multi-Level Intermediate Representation) code.

ir_context = None
conv_dim_number_re = re.compile('stablehlo\\.conv\\s*<\\s*(\\[\\s*([fb01])\\s*,\\s*([fb01])\\s*,\\s*([fb01])\\s*,\\s*([fb01])\\s*\\])\\s*x\\s*(\\[\\s*([io01])\\s*,\\s*([io01])\\s*,\\s*([io01])\\s*,\\s*([io01])\\s*\\])\\s*->\\s*(\\[\)
window_attr_re = re.compile('window\\s*=\\s*{([^}]*)\\s*}')
window_sub_attr_re = re.compile('([a-zA-Z_\\-0-9]+)\\s*=\\s*(\\[[^=]+\\])')
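The two window patterns above can be combined to pull the window attributes out of a convolution. The sketch below copies the patterns verbatim; the helper parse_window and the sample op text are hypothetical, and decoding the bracketed values with json.loads is an assumption that holds for integer and nested-integer lists.

```python
import json
import re

# Patterns copied from MLIRParser.window_attr_re and window_sub_attr_re.
window_attr_re = re.compile('window\\s*=\\s*{([^}]*)\\s*}')
window_sub_attr_re = re.compile('([a-zA-Z_\\-0-9]+)\\s*=\\s*(\\[[^=]+\\])')


def parse_window(op_text: str) -> dict:
    """Hypothetical helper: map each window sub-attribute to a Python list."""
    match = window_attr_re.search(op_text)
    if match is None:
        return {}  # the op has no window attribute
    body = match.group(1)  # e.g. "stride = [1, 1], pad = [[0, 0], [1, 1]]"
    # Bracketed values such as "[[0, 0], [1, 1]]" are valid JSON arrays.
    return {name: json.loads(value) for name, value in window_sub_attr_re.findall(body)}


conv = ("%0 = stablehlo.convolution(%arg0, %arg1) "
        "dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], "
        "window = {stride = [1, 1], pad = [[0, 0], [1, 1]], rhs_dilate = [1, 1]} "
        "{batch_group_count = 1 : i64, feature_group_count = 1 : i64}")
print(parse_window(conv))
```

Because [^=]+ stops at the next "=", each value capture ends at the last "]" before the following attribute name, which is what lets nested lists such as pad survive intact.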
static get_ir_context()
static get_module(mlir_string, context=None)
__init__(*, mlir_string=None, mlir_path=None, mlir_module=None, input_sensitive=False)

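Initializes the MLIRParser from an MLIR string, an MLIR file path, or an existing MLIR module.

Args:

mlir_string (str, optional): The MLIR code as a string.
mlir_path (str, optional): Path to an MLIR file.
mlir_module (optional): An already-parsed MLIR module.
input_sensitive (bool): If True, an operator's inputs are treated as part of the operator when determining uniqueness.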

property mlir_string
property operations
Extracts operators from the MLIR code using regular expressions.

operation_input: True includes operator inputs in the categorization

Returns:

list: A list of extracted operators.

property main_function

Returns the main function in the MLIR module.

property main_index: int

Returns the index of the main function in the MLIR module. Raises an error if the main function is not found.

property private_functions: list[Operation]

Returns a list of private functions in the MLIR module.

get_private_functions_map(funcs=None) → dict
get_private_functions_nx_tree(functions)
get_private_functions_ops(function)

Returns a list of operations in a private function, including nested calls.
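Expanding nested calls amounts to recursively inlining callee bodies. The sketch below is illustrative only: it uses plain recursion over a dict instead of the networkx tree suggested by get_private_functions_nx_tree, and the hypothetical expand_calls represents ops as strings, with "call @name" standing in for a call operation.

```python
from typing import Dict, List, Optional, Set


def expand_calls(
    functions: Dict[str, List[str]],
    name: str,
    _stack: Optional[Set[str]] = None,
) -> List[str]:
    """Hypothetical helper: inline the bodies of called functions into one flat list."""
    stack = set() if _stack is None else _stack
    if name in stack:
        raise ValueError(f"recursive call cycle through @{name}")
    stack.add(name)
    flat: List[str] = []
    for op in functions[name]:
        if op.startswith("call @"):
            callee = op[len("call @"):]
            flat.extend(expand_calls(functions, callee, stack))  # recurse into callee
        else:
            flat.append(op)
    stack.remove(name)  # allow the same callee again on a sibling path
    return flat


funcs = {
    "main": ["stablehlo.add", "call @helper", "stablehlo.multiply"],
    "helper": ["stablehlo.exponential", "call @leaf"],
    "leaf": ["stablehlo.tanh"],
}
print(expand_calls(funcs, "main"))
```

The stack set only tracks the current call chain, so a function called twice from different places is expanded twice, while a true recursion cycle raises instead of looping forever.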

parse_operation(operation)
property ops_list

Reads an MLIR file or MLIR string and returns a list of (operator_name, inputs, output) tuples.
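The effect of input_sensitive on uniqueness can be sketched over tuples of this shape. This assumes inputs are a tuple of type strings; count_unique_ops is a hypothetical helper, not part of MLIRParser. With input_sensitive=False, ops are keyed by name alone; with True, the input types become part of the key.

```python
from collections import Counter
from typing import Iterable, Tuple

# Assumed tuple shape: (operator_name, input_types, output_type).
Op = Tuple[str, Tuple[str, ...], str]


def count_unique_ops(ops: Iterable[Op], input_sensitive: bool = False) -> Counter:
    """Hypothetical helper: count ops keyed by name, or by name plus input types."""
    counts: Counter = Counter()
    for name, inputs, _output in ops:
        key = (name, inputs) if input_sensitive else name
        counts[key] += 1
    return counts


ops = [
    ("stablehlo.add", ("tensor<4xf32>", "tensor<4xf32>"), "tensor<4xf32>"),
    ("stablehlo.add", ("tensor<8xf32>", "tensor<8xf32>"), "tensor<8xf32>"),
    ("stablehlo.add", ("tensor<4xf32>", "tensor<4xf32>"), "tensor<4xf32>"),
]
print(count_unique_ops(ops))                        # one key, count 3
print(len(count_unique_ops(ops, input_sensitive=True)))  # two distinct keys
```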

get_mlir_function_inputs()

Reads an MLIR file and returns a list of function input types.

get_mlir_function_outputs()

Reads an MLIR file and returns a list of function output types.

class hespas.mlir_parser.MLIRModule(*, mlir_string: str | None = None, mlir_path: str | Path | None = None, mlir_module=None, block_type: NodeType | str | None = None, parent_module: str | None = None, **kwargs)

Bases: object

A wrapper class for MLIR module strings and their associated metadata.

This class encapsulates an MLIR module (as a string or file path) along with metadata about the module, such as block type, operation count, input/output dimensions, and other relevant information used in the HESPAS workload analysis.

Each module is automatically assigned a unique incrementing index starting from 0.

Attributes:

mlir_string (str): The MLIR module as a string representation.
idx (int): Unique incrementing index for this module.
metadata (Dict[str, Any]): Dictionary containing module metadata.

Example:
>>> # Create from MLIR string
>>> module = MLIRModule(
...     mlir_string="module { ... }",
...     block_type=NodeType.COMM_COLL_NODE,
...     parent_module="input.mlir",
... )
>>> print(module.op_count)
42
>>> print(module.idx)
0
>>>
>>> # Create from file path
>>> module = MLIRModule(
...     mlir_path="/path/to/module.mlir",  # File path
...     block_type=1
... )
>>> module.save_to_files("/output/dir")
_next_idx = 0
__init__(*, mlir_string: str | None = None, mlir_path: str | Path | None = None, mlir_module=None, block_type: NodeType | str | None = None, parent_module: str | None = None, **kwargs)

Initialize an MLIRModule instance.

Args:

mlir_string (str): Either the MLIR module as a string or a file path to an MLIR file.
mlir_path (str | Path, optional): Path to an MLIR file.
mlir_module (optional): An already-parsed MLIR module.
block_type (NodeType | str, optional): Type of the block (COMM or COMP).
parent_module (str, optional): File name of the module from which this was split.
**kwargs: Additional metadata fields.

classmethod reset_index_counter()

Reset the index counter to 0. Useful for testing.

property comm_bytes
property mlir_string: str

Get the MLIR module as a string.

property analyzer
property block_type: NodeType

Get the block type (COMM or COMP).

property is_communication_block: bool

Check if this is a communication block.

property is_computation_block: bool

Check if this is a computation block.

property op_count: int

Get the number of operations in this module.

property op_count_expanded: int

Get the number of operations in this module.

property ops_list
property parent_module: str | None

Get the parent module file name from which this was split.

property module_path: Path | None

Get the module file path, or None if not set. The path is set when the module is created from a file, or when it is created from a string and then written to a file.

property module_file: str

Get the module file name.

property collective: str | None

Get the collective operation name (for COMM blocks).

property replica_groups: list[list[int]] | None

Get the parsed replica_groups for this collective (for COMM blocks).

Returns a list of groups, where each group is a list of device ids, or None if not a collective or if the collective uses a different attribute (e.g. collective_permute with source_target_pairs).
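A minimal sketch of how such an attribute might be parsed from op text, assuming the StableHLO dense form replica_groups = dense<[[0, 1], [2, 3]]>. The helper parse_replica_groups and its regex are hypothetical; it handles only the nested-list form and returns None when the attribute is absent, as for collective_permute.

```python
import json
import re
from typing import List, Optional

# Hypothetical pattern for the nested-list form of the attribute.
REPLICA_GROUPS_RE = re.compile(r"replica_groups\s*=\s*dense<(\[.*?\])>")


def parse_replica_groups(op_text: str) -> Optional[List[List[int]]]:
    """Return the device-id groups, or None if the op has no replica_groups."""
    match = REPLICA_GROUPS_RE.search(op_text)
    if match is None:
        return None
    # The lazy .*? stops at the first "]>", i.e. the end of the outer list.
    return json.loads(match.group(1))


all_reduce = ('%1 = "stablehlo.all_reduce"(%0) ({...}) '
              '{replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} '
              ': (tensor<8xf32>) -> tensor<8xf32>')
print(parse_replica_groups(all_reduce))
```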

property input_dims: List[str]

Get the input dimensions as string representations (computed dynamically).

property output_dims: List[str]

Get the output dimensions as string representations (computed dynamically).

property metadata: Dict[str, Any]

Get metadata dictionary (computed dynamically for backward compatibility).

Note: This property exists for backward compatibility. It’s recommended to use individual properties directly instead of accessing metadata.

property input_dims_mlir: List

Get the input dimensions as MLIR RankedTensorType objects.

If the objects don’t exist (e.g., after unpickling), reconstruct them from the MLIR string.

Returns:

List: List of MLIR RankedTensorType objects

property output_dims_mlir: List

Get the output dimensions as MLIR RankedTensorType objects.

If the objects don’t exist (e.g., after unpickling), reconstruct them from the MLIR string.

Returns:

List: List of MLIR RankedTensorType objects
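The "reconstruct after unpickling" behavior follows a common lazy-cache pattern: drop the cached type objects from the pickled state, presumably because MLIR binding objects cannot be pickled, and rebuild them from the text on first access. The sketch below is hypothetical throughout: FakeTensorType stands in for RankedTensorType, and LazyDims parses a function-signature string instead of using the MLIR bindings.

```python
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class FakeTensorType:
    """Stand-in for an MLIR RankedTensorType."""
    shape: Tuple[int, ...]
    element_type: str


# Matches e.g. "tensor<2x3xf32>" -> ("2x3x", "f32") and "tensor<f32>" -> ("", "f32").
TENSOR_RE = re.compile(r"tensor<((?:\d+x)*)(\w+)>")


class LazyDims:
    def __init__(self, signature: str):
        self.signature = signature  # e.g. "(tensor<2x3xf32>) -> tensor<3xf32>"
        self._input_types: Optional[List[FakeTensorType]] = None

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_input_types"] = None  # never serialize the cached objects
        return state

    @property
    def input_dims_mlir(self) -> List[FakeTensorType]:
        if self._input_types is None:  # rebuild lazily, e.g. after unpickling
            inputs = self.signature.split("->")[0]
            self._input_types = [
                FakeTensorType(tuple(int(d) for d in dims.split("x") if d), elem)
                for dims, elem in TENSOR_RE.findall(inputs)
            ]
        return self._input_types
```

After a pickle round trip, _input_types is None and the first access to input_dims_mlir rebuilds an equal list from the stored string.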

ensure_dir(output_dir)
save_mlir(output_path)
get_json_metadata()
save_json(output_path)
get_paths(output_dir, base_name=None)
save_to_files(output_dir: str | Path, base_name: str | None = None) → tuple

Save the MLIR module and metadata to separate files.

Args:

output_dir (Union[str, Path]): Directory in which to save the files.
base_name (str, optional): Base name for the files (without extension). If not provided, “mini_module_{idx}” is used.

Returns:

tuple: Paths to (mlir_file, json_file)

Example:
>>> module.save_to_files("/output")
("/output/mini_module_0.mlir", "/output/mini_module_0.json")
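A standalone sketch of the save behavior described above, assuming the MLIR text is written verbatim and the metadata is written as JSON. The function save_module_files and its parameters are hypothetical; returning pathlib.Path objects is a choice made here, not confirmed for save_to_files.

```python
import json
from pathlib import Path
from typing import Optional, Tuple, Union


def save_module_files(
    mlir_text: str,
    metadata: dict,
    output_dir: Union[str, Path],
    idx: int,
    base_name: Optional[str] = None,
) -> Tuple[Path, Path]:
    """Hypothetical helper: write <stem>.mlir and <stem>.json, return both paths."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)  # create the directory if missing
    stem = base_name or f"mini_module_{idx}"  # default naming scheme
    mlir_file = out / f"{stem}.mlir"
    json_file = out / f"{stem}.json"
    mlir_file.write_text(mlir_text)
    json_file.write_text(json.dumps(metadata, indent=2))
    return mlir_file, json_file
```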
get_output_files(output_dir, base_name=None)
retrieve_mlir_from_file() → None

Retrieve and load the MLIR module content from the associated file.

This method reads the MLIR string from the file specified in the module_file attribute and returns it. It is useful for reloading the MLIR content after it has been cleared or modified.

Returns:

str: The MLIR module content as a string.

update_metadata(**kwargs) → None

Update metadata fields by setting instance variables.

Args:

**kwargs: Metadata fields to update

Example:
>>> module.update_metadata(op_count=50, custom_field="value")
get_metadata_copy() → Dict[str, Any]

Get a copy of the metadata dictionary (computed dynamically).

Returns: Dict[str, Any]: A copy of the metadata
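Storing metadata as instance attributes and deriving the dictionary on demand can be sketched as follows. MetadataSketch is a hypothetical illustration of the pattern described by update_metadata, metadata, and get_metadata_copy, not the MLIRModule implementation; in particular, filtering out underscore-prefixed attributes is an assumption.

```python
import copy
from typing import Any, Dict, Optional


class MetadataSketch:
    """Hypothetical illustration: metadata is derived from instance attributes."""

    def __init__(self, block_type: str, parent_module: Optional[str] = None):
        self.block_type = block_type
        self.parent_module = parent_module

    def update_metadata(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)  # each field becomes an instance attribute

    @property
    def metadata(self) -> Dict[str, Any]:
        # Rebuilt on every access, so it always reflects the current attributes.
        return {k: v for k, v in vars(self).items() if not k.startswith("_")}

    def get_metadata_copy(self) -> Dict[str, Any]:
        return copy.deepcopy(self.metadata)  # caller edits never reach the instance


m = MetadataSketch("COMP", "input.mlir")
m.update_metadata(op_count=50, custom_field="value")
print(m.metadata)
```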

property hash
classmethod get_next_idx() → int

Get the current value of the class-level index counter.

Returns:

int: Current value of _next_idx

classmethod set_next_idx(value: int) → None

Set the class-level index counter to a specific value.

Args:

value (int): New value for _next_idx
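The class-level index counter (_next_idx together with reset_index_counter, get_next_idx, and set_next_idx) follows a standard pattern. IndexedSketch below is a hypothetical minimal version for illustration only: each instance takes the current counter value as its idx and then increments the shared class attribute.

```python
class IndexedSketch:
    """Hypothetical illustration of a class-level auto-incrementing index."""

    _next_idx = 0  # shared by all instances

    def __init__(self) -> None:
        cls = type(self)
        self.idx = cls._next_idx
        cls._next_idx += 1  # not atomic; concurrent construction would need a lock

    @classmethod
    def reset_index_counter(cls) -> None:
        cls._next_idx = 0

    @classmethod
    def get_next_idx(cls) -> int:
        return cls._next_idx

    @classmethod
    def set_next_idx(cls, value: int) -> None:
        cls._next_idx = value


IndexedSketch.reset_index_counter()
a, b = IndexedSketch(), IndexedSketch()
print(a.idx, b.idx, IndexedSketch.get_next_idx())
```

set_next_idx is handy when restoring state, for example continuing numbering after modules were loaded from disk, so that new modules do not reuse existing indices.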

static get_abspath(path)
static get_abspath_dir(path)

Modules

mlir_analyzer

mlir_common

mlir_module

MLIR Module wrapper class for handling split MLIR modules with metadata.

mlir_parser

mlir_splitter