hespas.mlir_parser.mlir_common

Functions

add_classifier(func)

append_private_functions(private_funcs[, ...])

create_new_module_with_operations(block[, ...])

get_dot_general_dimensions(dot_op)

Extract batch and contracting dimensions from a stablehlo.dot_general op.

get_external_inputs(block)

Return a dict mapping each external input argument to its type.

get_external_outputs(block)

Return a dict mapping each external output's argument name to its type.

is_communication_op(op)

Check if the operation is a StableHLO communication op.

is_communication_op_str(op_name)

is_convolution_op(op)

Check if the operation is a StableHLO convolution op.

is_dot_general_op(op)

Check if the operation is a StableHLO dot_general op.

is_reduce_op(op)

Check if the operation is a StableHLO reduce op.

is_return_op(op)

Check if the operation is a function return op.

parse_replica_from_match(raw_value, ...)

parse_replica_groups(mlir_string)

Extract replica_groups from a StableHLO collective operation's MLIR string.

store_private_functions(priv_funcs, output_path)

Write private functions to a single file for debugging purposes.

Classes

NodeType(*values)

class hespas.mlir_parser.mlir_common.NodeType(*values)

Bases: Enum

INVALID_NODE = 0
METADATA_NODE = 1
MEM_LOAD_NODE = 2
MEM_STORE_NODE = 3
COMP_NODE = 4
COMM_SEND_NODE = 5
COMM_RECV_NODE = 6
COMM_COLL_NODE = 7
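
Because NodeType derives from Enum, members can be looked up by value or by name in the usual way. The sketch below re-declares the enum locally with the values listed above so it runs without hespas installed. The COMM_NODE_TYPES grouping and the is_comm_node helper are hypothetical illustrations, not part of the documented module.

```python
from enum import Enum


class NodeType(Enum):
    # Local mirror of hespas.mlir_parser.mlir_common.NodeType (values copied from the docs)
    INVALID_NODE = 0
    METADATA_NODE = 1
    MEM_LOAD_NODE = 2
    MEM_STORE_NODE = 3
    COMP_NODE = 4
    COMM_SEND_NODE = 5
    COMM_RECV_NODE = 6
    COMM_COLL_NODE = 7


# Hypothetical grouping of the three communication node kinds
COMM_NODE_TYPES = frozenset(
    {NodeType.COMM_SEND_NODE, NodeType.COMM_RECV_NODE, NodeType.COMM_COLL_NODE}
)


def is_comm_node(node_type: NodeType) -> bool:
    """Return True for send, receive, and collective nodes (hypothetical helper)."""
    return node_type in COMM_NODE_TYPES


print(NodeType(4))                        # lookup by value: NodeType.COMP_NODE
print(NodeType["COMM_COLL_NODE"].value)   # lookup by name, then the integer value: 7
```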

hespas.mlir_parser.mlir_common.add_classifier(func)

hespas.mlir_parser.mlir_common.is_communication_op_str(op_name)

hespas.mlir_parser.mlir_common.is_convolution_op(op)

Check if the operation is a StableHLO convolution op.

hespas.mlir_parser.mlir_common.is_dot_general_op(op)

Check if the operation is a StableHLO dot_general op.

hespas.mlir_parser.mlir_common.is_communication_op(op)

Check if the operation is a StableHLO communication op.
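
is_communication_op_str has no docstring, but its name suggests the same check performed on an op-name string. A minimal sketch, assuming communication ops are recognized by their fully qualified StableHLO names; the set below (collectives plus send/recv) is an assumption, and the module's actual list may differ. An op-level check could delegate to this by passing the op's name, assuming the op object exposes it.

```python
# Hypothetical set of StableHLO communication ops; the real module's list may differ.
COMMUNICATION_OPS = frozenset({
    "stablehlo.all_reduce",
    "stablehlo.all_gather",
    "stablehlo.reduce_scatter",
    "stablehlo.all_to_all",
    "stablehlo.collective_permute",
    "stablehlo.collective_broadcast",
    "stablehlo.send",
    "stablehlo.recv",
})


def is_communication_op_str(op_name: str) -> bool:
    """Return True if op_name names a StableHLO collective or point-to-point op."""
    return op_name in COMMUNICATION_OPS


print(is_communication_op_str("stablehlo.all_reduce"))   # True
print(is_communication_op_str("stablehlo.dot_general"))  # False
```

A frozenset keeps membership tests constant-time and prevents the set from being modified at runtime.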

hespas.mlir_parser.mlir_common.is_reduce_op(op)

Check if the operation is a StableHLO reduce op.

hespas.mlir_parser.mlir_common.is_return_op(op)

Check if the operation is a function return op.
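
The single-op predicates above can be sketched as comparisons on the op's fully qualified name. This assumes each predicate matches exactly one name and that "function return op" means func.return. The *_str functions are hypothetical string-based stand-ins; the documented functions take an MLIR op rather than a string.

```python
# Hypothetical name-based versions of the predicates documented above.
CONVOLUTION_OP = "stablehlo.convolution"
DOT_GENERAL_OP = "stablehlo.dot_general"
REDUCE_OP = "stablehlo.reduce"
RETURN_OP = "func.return"  # assumption: the function-level terminator


def is_convolution_op_str(op_name: str) -> bool:
    return op_name == CONVOLUTION_OP


def is_dot_general_op_str(op_name: str) -> bool:
    return op_name == DOT_GENERAL_OP


def is_reduce_op_str(op_name: str) -> bool:
    return op_name == REDUCE_OP


def is_return_op_str(op_name: str) -> bool:
    # stablehlo.return (the terminator of regions nested inside ops such as
    # stablehlo.reduce) is deliberately not matched here.
    return op_name == RETURN_OP


print(is_dot_general_op_str("stablehlo.dot_general"))  # True
print(is_return_op_str("stablehlo.return"))            # False
```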

hespas.mlir_parser.mlir_common.get_external_inputs(block)

Return a dict mapping each external input argument to its type.

hespas.mlir_parser.mlir_common.get_external_outputs(block)

Return a dict mapping each external output's argument name to its type.
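
To illustrate what "external" likely means here, the sketch below uses a toy block representation instead of real MLIR objects. It assumes external inputs are values consumed in the block but not defined by any op in it, and external outputs are block-defined values passed to the func.return op. ToyOp, external_inputs, and external_outputs are hypothetical names; the documented functions take an MLIR block.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Hypothetical stand-in for an MLIR operation inside a block."""
    name: str
    operands: dict[str, str]  # SSA value name -> type, for each value the op consumes
    results: dict[str, str] = field(default_factory=dict)  # SSA value name -> type, for values it defines


def external_inputs(ops: list[ToyOp]) -> dict[str, str]:
    """Values consumed in the block but not defined by any op in it, mapped to their types."""
    defined = {name for op in ops for name in op.results}
    return {
        name: ty
        for op in ops
        for name, ty in op.operands.items()
        if name not in defined
    }


def external_outputs(ops: list[ToyOp]) -> dict[str, str]:
    """Block-defined values passed to func.return, mapped to their types."""
    defined = {name: ty for op in ops for name, ty in op.results.items()}
    ret = next((op for op in ops if op.name == "func.return"), None)
    if ret is None:
        return {}
    return {name: defined[name] for name in ret.operands if name in defined}


block = [
    ToyOp("stablehlo.dot_general",
          {"%arg0": "tensor<4x8xf32>", "%arg1": "tensor<8x16xf32>"},
          {"%0": "tensor<4x16xf32>"}),
    ToyOp("stablehlo.add",
          {"%0": "tensor<4x16xf32>", "%arg2": "tensor<4x16xf32>"},
          {"%1": "tensor<4x16xf32>"}),
    ToyOp("func.return", {"%1": "tensor<4x16xf32>"}),
]
print(external_inputs(block))   # %arg0, %arg1 and %arg2 with their types
print(external_outputs(block))  # {'%1': 'tensor<4x16xf32>'}
```

The %0 value is produced and consumed inside the block, so it appears in neither result.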

hespas.mlir_parser.mlir_common.get_dot_general_dimensions(dot_op)

Extract batch and contracting dimensions from a stablehlo.dot_general op.
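
A minimal text-based sketch of the same extraction, assuming the op is available in StableHLO's pretty-printed form, where dimension numbers appear as `batching_dims = [...] x [...]` and `contracting_dims = [...] x [...]` (the batching clause is omitted when there are none). The function name and the returned dict layout are hypothetical; the documented function takes an op object, and its return format is not specified here.

```python
from __future__ import annotations

import re

_DIM_LIST = r"\[([\d,\s]*)\]"  # captures the contents of "[0, 1]" or "[]"


def _ints(text: str) -> list[int]:
    return [int(t) for t in text.split(",") if t.strip()]


def dot_general_dims_from_text(op_text: str) -> dict[str, list[int]]:
    """Parse lhs/rhs batching and contracting dims from pretty-printed dot_general text."""
    dims = {"lhs_batch": [], "rhs_batch": [], "lhs_contract": [], "rhs_contract": []}
    for key, keyword in (("batch", "batching_dims"), ("contract", "contracting_dims")):
        m = re.search(keyword + r"\s*=\s*" + _DIM_LIST + r"\s*x\s*" + _DIM_LIST, op_text)
        if m:  # a missing clause leaves both lists empty
            dims["lhs_" + key] = _ints(m.group(1))
            dims["rhs_" + key] = _ints(m.group(2))
    return dims


text = ("%0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], "
        "contracting_dims = [2] x [1] : (tensor<2x3x4xf32>, tensor<2x4x5xf32>) "
        "-> tensor<2x3x5xf32>")
print(dot_general_dims_from_text(text))
```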

hespas.mlir_parser.mlir_common.parse_replica_groups(mlir_string: str) → list[list[int]] | None

Extract replica_groups from a StableHLO collective operation’s MLIR string.

Parses the replica_groups dense attribute that appears in collective operations such as stablehlo.all_reduce and stablehlo.all_gather.

Supported formats:
  • dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> → [[0, 1, 2, 3]]

  • dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> → [[0, 1], [2, 3]]

  • dense<0> : tensor<1x1xi64> → [[0]]

For stablehlo.collective_permute, which uses source_target_pairs instead of replica_groups, the function returns None.

Returns:

A list of groups, where each group is a list of integer device ids, or None if no replica_groups attribute is found.
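
The three supported formats can be handled with one regex and a reshape step. A minimal sketch, assuming the attribute text always has the form `replica_groups = dense<...> : tensor<NxMxi64>`. parse_replica_groups_sketch is a hypothetical name, and groups_from_dense is a hypothetical helper that takes the same parameters as parse_replica_from_match; whether that function actually behaves this way is an assumption.

```python
from __future__ import annotations

import re

# Matches e.g. "replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>"
_REPLICA_GROUPS_RE = re.compile(
    r"replica_groups\s*=\s*dense<(.+?)>\s*:\s*tensor<(\d+)x(\d+)xi64>"
)


def groups_from_dense(raw_value: str, num_groups: int, group_size: int) -> list[list[int]]:
    """Turn the text inside dense<...> into num_groups lists of group_size ids (hypothetical)."""
    values = [int(v) for v in re.findall(r"-?\d+", raw_value)]
    if len(values) == 1:
        # Splat form such as dense<0>: one scalar fills the whole tensor.
        values = values * (num_groups * group_size)
    if len(values) != num_groups * group_size:
        raise ValueError(f"expected {num_groups * group_size} ids, got {len(values)}")
    return [values[i * group_size:(i + 1) * group_size] for i in range(num_groups)]


def parse_replica_groups_sketch(mlir_string: str) -> list[list[int]] | None:
    """Return replica groups from an op's MLIR text, or None if the attribute is absent."""
    match = _REPLICA_GROUPS_RE.search(mlir_string)
    if match is None:
        # e.g. stablehlo.collective_permute, which carries source_target_pairs instead
        return None
    raw_value, num_groups, group_size = match.group(1), int(match.group(2)), int(match.group(3))
    return groups_from_dense(raw_value, num_groups, group_size)


print(parse_replica_groups_sketch("replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>"))
```

Reshaping from the tensor type, rather than trusting the bracket nesting, lets the same code path handle both the nested-list and the splat forms.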

hespas.mlir_parser.mlir_common.parse_replica_from_match(raw_value, num_groups, group_size)

hespas.mlir_parser.mlir_common.store_private_functions(priv_funcs, output_path, context=None)

Write private functions to a single file for debugging purposes.

This function writes the private functions extracted by the MLIR parser to a file named private_functions.mlir in the directory given by output_path. This is useful for debugging and analyzing the private functions in isolation.

Args:

priv_funcs: The private functions extracted by the MLIR parser.
output_path (str): The directory path where the private_functions.mlir file will be saved.
context: An MLIR context (optional; defaults to None).

Returns:

list: A list of private functions extracted by the MLIR parser.

Note:

This function must be executed within an MLIR context.
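
A minimal stdlib-only sketch of the write step, assuming each private function can be rendered to text with str(), which is what plain strings do and, as an assumption here, what MLIR Python op objects do too. The sketch writes functions one after another rather than wrapping them in a module, and it creates the output directory if missing. store_private_functions_sketch is a hypothetical name, not the module's implementation.

```python
import os


def store_private_functions_sketch(priv_funcs, output_path):
    """Write each private function's text to <output_path>/private_functions.mlir."""
    os.makedirs(output_path, exist_ok=True)  # assumption: create the directory if missing
    file_path = os.path.join(output_path, "private_functions.mlir")
    with open(file_path, "w", encoding="utf-8") as f:
        for func in priv_funcs:
            # Separate consecutive functions with one blank line.
            f.write(str(func).rstrip("\n") + "\n\n")
    return list(priv_funcs)  # mirror the documented "list of private functions" return
```

Usage: `store_private_functions_sketch(["func.func private @f() {\n  return\n}"], "debug_out")` writes debug_out/private_functions.mlir and returns the input as a list.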

hespas.mlir_parser.mlir_common.append_private_functions(private_funcs, module=None, context=None)

hespas.mlir_parser.mlir_common.create_new_module_with_operations(block, private_funcs=None, context=None)