hespas.mlir_parser.mlir_common
Functions
- add_classifier
- append_private_functions
- create_new_module_with_operations
- get_dot_general_dimensions: Extract batch and contracting dimensions from a stablehlo.dot_general op.
- get_external_inputs: Return a dict mapping each argument to its type.
- get_external_outputs: Return a dict mapping each argument name to its type.
- is_communication_op: Check if the operation is a StableHLO communication op.
- is_communication_op_str
- is_convolution_op: Check if the operation is a StableHLO convolution op.
- is_dot_general_op: Check if the operation is a StableHLO dot_general op.
- is_reduce_op: Check if the operation is a StableHLO reduce op.
- is_return_op: Check if the operation is a function return op.
- parse_replica_from_match
- parse_replica_groups: Extract replica_groups from a StableHLO collective operation's MLIR string.
- store_private_functions: Write private functions to a single file for debugging purposes.
Classes
- class hespas.mlir_parser.mlir_common.NodeType(*values)
  Bases: Enum
  - INVALID_NODE = 0
  - METADATA_NODE = 1
  - MEM_LOAD_NODE = 2
  - MEM_STORE_NODE = 3
  - COMP_NODE = 4
  - COMM_SEND_NODE = 5
  - COMM_RECV_NODE = 6
  - COMM_COLL_NODE = 7
- hespas.mlir_parser.mlir_common.add_classifier(func)
- hespas.mlir_parser.mlir_common.is_communication_op_str(op_name)
- hespas.mlir_parser.mlir_common.is_convolution_op(op)
Check if the operation is a StableHLO convolution op.
- hespas.mlir_parser.mlir_common.is_dot_general_op(op)
Check if the operation is a StableHLO dot_general op.
- hespas.mlir_parser.mlir_common.is_communication_op(op)
Check if the operation is a StableHLO communication op.
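The reference does not say which op names count as communication ops. As a minimal sketch, assuming a plain name lookup, the snippet below mirrors NodeType and maps a few real StableHLO communication op names to its COMM_* members. The table contents and the helpers comm_node_type and is_communication_name are hypothetical illustrations, not the hespas implementation of is_communication_op_str.

```python
from __future__ import annotations

from enum import Enum


class NodeType(Enum):
    # Member values copied from the NodeType listing in this module.
    INVALID_NODE = 0
    METADATA_NODE = 1
    MEM_LOAD_NODE = 2
    MEM_STORE_NODE = 3
    COMP_NODE = 4
    COMM_SEND_NODE = 5
    COMM_RECV_NODE = 6
    COMM_COLL_NODE = 7


# Hypothetical lookup table: these are real StableHLO op names, but which
# names hespas actually treats as communication ops is an assumption.
_COMM_NODE_TYPES = {
    "stablehlo.send": NodeType.COMM_SEND_NODE,
    "stablehlo.recv": NodeType.COMM_RECV_NODE,
    "stablehlo.all_reduce": NodeType.COMM_COLL_NODE,
    "stablehlo.all_gather": NodeType.COMM_COLL_NODE,
    "stablehlo.all_to_all": NodeType.COMM_COLL_NODE,
    "stablehlo.reduce_scatter": NodeType.COMM_COLL_NODE,
    "stablehlo.collective_permute": NodeType.COMM_COLL_NODE,
}


def comm_node_type(op_name: str) -> NodeType | None:
    """Return the COMM_* NodeType for a communication op name, else None."""
    # Accept both the bare name and the quoted generic form '"stablehlo.send"'.
    return _COMM_NODE_TYPES.get(op_name.strip().strip('"'))


def is_communication_name(op_name: str) -> bool:
    """String-level check in the spirit of is_communication_op_str (sketch)."""
    return comm_node_type(op_name) is not None
```

Here comm_node_type("stablehlo.all_gather") returns NodeType.COMM_COLL_NODE, and a compute op such as "stablehlo.add" returns None.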
- hespas.mlir_parser.mlir_common.is_reduce_op(op)
Check if the operation is a StableHLO reduce op.
- hespas.mlir_parser.mlir_common.is_return_op(op)
Check if the operation is a function return op.
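The op predicates can be read as name checks on the operation. A minimal sketch, assuming the ops come from the upstream MLIR Python bindings, where op.operation.name yields qualified names such as "stablehlo.dot_general". The exact names matched by hespas (for instance, that is_reduce_op matches stablehlo.reduce and is_return_op matches func.return) and the fake_op stub are assumptions for illustration:

```python
from types import SimpleNamespace


def _op_name(op) -> str:
    # Assumption: op comes from the MLIR Python bindings, where
    # op.operation.name is the qualified name, e.g. "stablehlo.dot_general".
    return op.operation.name


def is_convolution_op(op) -> bool:
    return _op_name(op) == "stablehlo.convolution"


def is_dot_general_op(op) -> bool:
    return _op_name(op) == "stablehlo.dot_general"


def is_reduce_op(op) -> bool:
    return _op_name(op) == "stablehlo.reduce"


def is_return_op(op) -> bool:
    # Matches the function-level terminator only; region terminators such
    # as "stablehlo.return" are intentionally not treated as function returns.
    return _op_name(op) == "func.return"


def fake_op(name: str):
    """Stand-in object exposing .operation.name, for trying the predicates."""
    return SimpleNamespace(operation=SimpleNamespace(name=name))
```

Comparing the full dialect-qualified name keeps "stablehlo.return" (a region terminator) from being mistaken for a function return.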
- hespas.mlir_parser.mlir_common.get_external_inputs(block)
Return a dict mapping each argument to its type.
- hespas.mlir_parser.mlir_common.get_external_outputs(block)
Return a dict mapping each argument name to its type.
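The reference does not define "external". The sketch below assumes external inputs are values that an op in the block consumes but no earlier op in the block defines, and that external outputs are the operands of the block's terminator (assumed to be its last op). It uses a hypothetical ToyOp class instead of real MLIR objects so it runs without MLIR installed:

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Hypothetical stand-in for an MLIR operation (not a hespas type)."""
    name: str
    operands: dict[str, str] = field(default_factory=dict)  # SSA name -> type
    results: dict[str, str] = field(default_factory=dict)   # SSA name -> type


def toy_external_inputs(ops: list[ToyOp]) -> dict[str, str]:
    """Values an op consumes that no earlier op in the block defines."""
    defined: set[str] = set()
    inputs: dict[str, str] = {}
    for op in ops:
        for value, ty in op.operands.items():
            if value not in defined:
                inputs.setdefault(value, ty)
        defined.update(op.results)
    return inputs


def toy_external_outputs(ops: list[ToyOp]) -> dict[str, str]:
    """Values handed to the terminator (assumed to be the last op)."""
    return dict(ops[-1].operands) if ops else {}


example_block = [
    ToyOp("stablehlo.add", {"%arg0": "tensor<4xf32>", "%arg1": "tensor<4xf32>"}, {"%0": "tensor<4xf32>"}),
    ToyOp("stablehlo.multiply", {"%0": "tensor<4xf32>", "%arg1": "tensor<4xf32>"}, {"%1": "tensor<4xf32>"}),
    ToyOp("func.return", {"%1": "tensor<4xf32>"}),
]
```

For example_block, toy_external_inputs returns {"%arg0": "tensor<4xf32>", "%arg1": "tensor<4xf32>"} and toy_external_outputs returns {"%1": "tensor<4xf32>"}.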
- hespas.mlir_parser.mlir_common.get_dot_general_dimensions(dot_op)
Extract batch and contracting dimensions from a stablehlo.dot_general op.
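The return format of get_dot_general_dimensions is not documented. As an illustration only, the sketch below pulls the same information from the pretty-printed text of a dot_general op rather than from an op object; the helper name and the dict keys (lhs_batch, rhs_batch, lhs_contracting, rhs_contracting) are hypothetical:

```python
import re

# Pretty-printed StableHLO form, e.g.
#   stablehlo.dot_general %a, %b, batching_dims = [0] x [0], contracting_dims = [2] x [1]
# The generic attribute form (#stablehlo.dot<lhs_batching_dimensions = ...>)
# is not handled by this sketch.
_DIMS_PATTERN = r"{key}\s*=\s*\[([\d,\s]*)\]\s*x\s*\[([\d,\s]*)\]"


def _int_list(text: str) -> list[int]:
    return [int(tok) for tok in text.split(",") if tok.strip()]


def dot_general_dims_from_text(op_text: str) -> dict[str, list[int]]:
    """Return lhs/rhs batch and contracting dims; missing groups become []."""
    dims: dict[str, list[int]] = {}
    for key, label in (("batching_dims", "batch"), ("contracting_dims", "contracting")):
        match = re.search(_DIMS_PATTERN.format(key=key), op_text)
        lhs, rhs = (match.group(1), match.group(2)) if match else ("", "")
        dims[f"lhs_{label}"] = _int_list(lhs)
        dims[f"rhs_{label}"] = _int_list(rhs)
    return dims
```

For an op printed with batching_dims = [0] x [0], contracting_dims = [2] x [1], the sketch returns {"lhs_batch": [0], "rhs_batch": [0], "lhs_contracting": [2], "rhs_contracting": [1]}; a plain matmul without batching_dims yields empty batch lists.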
- hespas.mlir_parser.mlir_common.parse_replica_groups(mlir_string: str) → list[list[int]] | None
Extract replica_groups from a StableHLO collective operation’s MLIR string.
Parses the replica_groups dense attribute that appears in collective operations such as stablehlo.all_reduce, stablehlo.all_gather, etc.
- Supported formats:
  dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> → [[0, 1, 2, 3]]
  dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> → [[0, 1], [2, 3]]
  dense<0> : tensor<1x1xi64> → [[0]]
For stablehlo.collective_permute, which uses source_target_pairs instead of replica_groups, the function returns None.
- Returns:
  A list of groups, where each group is a list of integer device ids, or None if no replica_groups attribute is found.
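The supported formats can be handled with one regular expression plus a splat expansion. A minimal sketch, assuming the attribute is printed as replica_groups = dense<...> : tensor<NxMxi64>; parse_replica_groups_sketch is a hypothetical name, not the library function:

```python
from __future__ import annotations

import json
import re

# Captures the dense payload and the NxM shape of the replica_groups attribute.
_REPLICA_RE = re.compile(
    r"replica_groups\s*=\s*dense<(.*?)>\s*:\s*tensor<(\d+)x(\d+)xi\d+>",
    re.DOTALL,
)


def parse_replica_groups_sketch(mlir_string: str) -> list[list[int]] | None:
    match = _REPLICA_RE.search(mlir_string)
    if match is None:
        # No replica_groups attribute, e.g. stablehlo.collective_permute,
        # which carries source_target_pairs instead.
        return None
    raw = match.group(1).strip()
    num_groups, group_size = int(match.group(2)), int(match.group(3))
    if not raw:
        return []
    value = json.loads(raw)
    if isinstance(value, int):
        # Splat form such as dense<0> : tensor<1x1xi64>: every slot holds `value`.
        return [[value] * group_size for _ in range(num_groups)]
    return [[int(v) for v in group] for group in value]
```

When each of the three payloads above appears as replica_groups = dense<...> : tensor<...>, the sketch returns [[0, 1, 2, 3]], [[0, 1], [2, 3]] and [[0]], and it returns None for a collective_permute string. The splat branch is where a helper shaped like parse_replica_from_match(raw_value, num_groups, group_size) would fit, although that function's actual behavior is not documented here.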
- hespas.mlir_parser.mlir_common.parse_replica_from_match(raw_value, num_groups, group_size)
- hespas.mlir_parser.mlir_common.store_private_functions(priv_funcs, output_path, context=None)
Write private functions to a single file for debugging purposes.
This function writes all private functions extracted by the MLIR parser to a file named private_functions.mlir in the specified output path. This is useful for debugging and analyzing the private functions in isolation.
- Args:
  priv_funcs: The private functions extracted by the MLIR parser.
  output_path (str): The directory path where the private_functions.mlir file will be saved.
- Returns:
list: A list of private functions extracted by the MLIR parser.
- Note:
This function must be executed within an MLIR context.
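The behavior described above amounts to printing each function and writing the results to one file. A minimal sketch under stated assumptions: each element of priv_funcs prints its IR via str(), and the output directory may be created if missing. store_private_functions_sketch is hypothetical and omits the real function's context argument:

```python
import os


def store_private_functions_sketch(priv_funcs, output_path: str) -> list:
    """Write every function's textual IR to <output_path>/private_functions.mlir."""
    funcs = list(priv_funcs)  # materialize once so generators are not exhausted
    # Assumption: create the directory if it does not exist yet.
    os.makedirs(output_path, exist_ok=True)
    target = os.path.join(output_path, "private_functions.mlir")
    with open(target, "w", encoding="utf-8") as fh:
        for func in funcs:
            # str() of an MLIR operation is its printed IR; plain strings pass through.
            fh.write(str(func))
            fh.write("\n")
    return funcs
```

With real MLIR function ops, the str() calls would still need to run inside the MLIR context, as the note above requires.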
- hespas.mlir_parser.mlir_common.append_private_functions(private_funcs, module=None, context=None)
- hespas.mlir_parser.mlir_common.create_new_module_with_operations(block, private_funcs=None, context=None)