hespas.estimator.roofline_estimator

Classes

RooflineEstimator([hw_config])

Exceptions

RooflineMissingDatatypeError

exception hespas.estimator.roofline_estimator.RooflineMissingDatatypeError

Bases: Exception

class hespas.estimator.roofline_estimator.RooflineEstimator(hw_config=None, **kwargs)

Bases: Estimator

allow_multiprocess = True
peak_flops = <hespas.estimator.config_option.ConfigOption object>
memory_bandwidth = <hespas.estimator.config_option.ConfigOption object>
tdp_W = <hespas.estimator.config_option.ConfigOption object>
hbm_power_ratio = <hespas.estimator.config_option.ConfigOption object>
per_datatype_flops = <hespas.estimator.config_option.ConfigOption object>
warn_on_unknown_type = <hespas.estimator.config_option.ConfigOption object>
error_on_unknown_type = <hespas.estimator.config_option.ConfigOption object>
kernel_launch_overhead_s = <hespas.estimator.config_option.ConfigOption object>
flops_per_element = <hespas.estimator.config_option.ConfigOption object>
memory_compute_parallelism = <hespas.estimator.config_option.ConfigOption object>
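The peak_flops, memory_bandwidth and memory_compute_parallelism options suggest a roofline-style cost model. The sketch below assumes the textbook roofline formulation and is not necessarily what compute_runtime does internally. Compute time is FLOPs / peak_flops and memory time is bytes accessed / memory_bandwidth. memory_compute_parallelism is *assumed* to be a boolean that selects overlap (take the max) versus serialization (take the sum). RooflineParams and roofline_time are hypothetical names. Kernel launch overhead is left out because the hook lists show __add_kernel_launch_overhead running separately as a post-run hook.

```python
from dataclasses import dataclass


@dataclass
class RooflineParams:
    """Hypothetical stand-in for the estimator's config options."""
    peak_flops: float                         # FLOP/s
    memory_bandwidth: float                   # bytes/s
    memory_compute_parallelism: bool = True   # assumed: True = memory and compute overlap


def roofline_time(flops: float, bytes_accessed: float, p: RooflineParams) -> float:
    """Textbook roofline estimate for a single operator, in seconds."""
    compute_time = flops / p.peak_flops
    mem_time = bytes_accessed / p.memory_bandwidth
    if p.memory_compute_parallelism:
        # Overlapped: the slower of the two resources bounds the runtime.
        return max(compute_time, mem_time)
    # Serialized: memory traffic and compute happen one after the other.
    return compute_time + mem_time


# A 1024^3 matmul in bf16 (2 bytes/element) on a hypothetical 100 TFLOP/s, 2 TB/s device.
p = RooflineParams(peak_flops=100e12, memory_bandwidth=2e12)
flops = 2 * 1024**3
bytes_accessed = 3 * 1024 * 1024 * 2  # read A and B, write C
print(roofline_time(flops, bytes_accessed, p))
```

With these made-up numbers the compute time (about 21 µs) exceeds the memory time (about 3 µs), so the operator is compute-bound and the max picks the compute term.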
DEFAULT_FLOPS_PER_ELEMENT = {'stablehlo.abs': 1, 'stablehlo.add': 1, 'stablehlo.and': 1, 'stablehlo.clamp': 1, 'stablehlo.compare': 1, 'stablehlo.convert': 1, 'stablehlo.cosine': 50, 'stablehlo.divide': 4, 'stablehlo.exponential': 50, 'stablehlo.is_finite': 1, 'stablehlo.log': 50, 'stablehlo.logistic': 50, 'stablehlo.maximum': 1, 'stablehlo.minimum': 1, 'stablehlo.multiply': 1, 'stablehlo.negate': 1, 'stablehlo.not': 1, 'stablehlo.or': 1, 'stablehlo.power': 50, 'stablehlo.round_nearest_even': 1, 'stablehlo.rsqrt': 4, 'stablehlo.select': 1, 'stablehlo.sign': 1, 'stablehlo.sine': 50, 'stablehlo.sqrt': 4, 'stablehlo.subtract': 1, 'stablehlo.tanh': 50, 'stablehlo.xor': 1}
TENSOR_CORE_OPS = ('stablehlo.dot', 'stablehlo.dot_general', 'stablehlo.convolution')
TENSOR_CORE_PROMOTIONS = {'f32': 'tf32'}
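A minimal sketch of how DEFAULT_FLOPS_PER_ELEMENT, TENSOR_CORE_OPS and TENSOR_CORE_PROMOTIONS could fit together. It makes two *assumptions*, neither confirmed by this page. First, an elementwise op costs its per-element weight times the number of output elements. Second, f32 tensor-core ops are priced at the tf32 rate, while all other ops keep their datatype. The helper names are hypothetical, and only a subset of the weight table is copied.

```python
import math
from typing import Sequence

# Subset of DEFAULT_FLOPS_PER_ELEMENT, copied from the attribute above.
FLOPS_PER_ELEMENT = {
    "stablehlo.add": 1,
    "stablehlo.divide": 4,
    "stablehlo.exponential": 50,
}
TENSOR_CORE_OPS = ("stablehlo.dot", "stablehlo.dot_general", "stablehlo.convolution")
TENSOR_CORE_PROMOTIONS = {"f32": "tf32"}


def elementwise_flops(op_name: str, output_shape: Sequence[int]) -> int:
    """Assumed model: per-element weight x number of output elements."""
    return FLOPS_PER_ELEMENT[op_name] * math.prod(output_shape)


def effective_datatype(op_name: str, datatype: str) -> str:
    """Assumed model: only tensor-core ops are promoted (e.g. f32 -> tf32)."""
    if op_name in TENSOR_CORE_OPS:
        return TENSOR_CORE_PROMOTIONS.get(datatype, datatype)
    return datatype


print(elementwise_flops("stablehlo.exponential", (128, 256)))  # 1638400
print(effective_datatype("stablehlo.dot_general", "f32"))      # tf32
```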
__get_datatype_str(datatype)
__get_datatype_str_by_op(op_info)
__get_flops_by_datatype_str(datatype_str)
__get_flops_by_datatype(datatype)
__get_flops_by_op(op_info)
__get_flops(op_info)
__get_opname_flops_mult(op_name)
__get_total_flops_by_opname(op_name, flops)
generate_op_result(op_info, compute_time, mem_time, flops, bytes_accessed, datatype_str)
compute_runtime(op_info, flops, bytes_accessed)
__add_roofline_stats(stats_tree)
__setup_per_datatype_flops()
__setup_roofline_stats()
__setup_flops_per_element_map()
__setup_seen_unknown_custom_kernel()
__setup_module_roofline_stats(module)
__setup_per_op_roofline_stats(op_info)
__get_per_op_stats(op_info, result)
__add_per_datatype_tree(stats_tree, datatype)
__merge_lower_stats_tree(upper_stats_tree, lower_stats_tree)
__get_bytes_flops(stats_tree)
__get_module_bytes_flops(module, result)
__add_kernel_launch_overhead(module, result)
__get_total_bytes_flops(module, result)
handle_elementwise_binary(op_info)
handle_clamp(op_info)
handle_free_ops(op_info)
handle_noflop_ops(op_info)
handle_concatenate(op_info)
handle_gather(op_info)
handle_scatter(op_info)
handle_convolution(op_info)
handle_unary_elemwise(op_info)
handle_select(op_info)
handle_reduce(op_info)
handle_reduce_window(op_info)
handle_select_and_scatter(op_info)
handle_sort(op_info)
handle_dot_general(op_info)

Calculate FLOPs for StableHLO dot_general operation.

FLOP Calculation:

  • For each output element, a dot product is performed across the contracting dimensions.

  • Each dot product involves product(contracting_dims) multiply-add operations.

  • Each multiply-add = 2 FLOPs (1 multiply + 1 add).

  • Total FLOPs = 2 x product(output_shape) x product(contracting_dimension_sizes)

Example: matrix multiplication A[M,K] x B[K,N] = C[M,N]

  • Output elements: M x N

  • Contracting dimension size: K

  • FLOPs = 2 x M x N x K
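The formula above can be checked with a short sketch. dot_general_flops is a hypothetical helper, not the estimator's own code. It takes the LHS shape, the result shape, and the LHS contracting dimension indices, which would come from the op's dot dimension numbers.

```python
import math
from typing import Sequence


def dot_general_flops(lhs_shape: Sequence[int],
                      result_shape: Sequence[int],
                      lhs_contracting_dims: Sequence[int]) -> int:
    """2 x product(result_shape) x product(contracting dimension sizes)."""
    # Each output element is a dot product over all contracting dimensions.
    contracting_size = math.prod(lhs_shape[d] for d in lhs_contracting_dims)
    # One multiply-add (2 FLOPs) per contracting element, per output element.
    return 2 * math.prod(result_shape) * contracting_size


# A[M,K] x B[K,N] with M=64, K=128, N=32: contract LHS dimension 1.
print(dot_general_flops((64, 128), (64, 32), [1]))  # 524288
```

Batch dimensions need no special handling: they appear in the result shape, so a batch of 8 of the same matmul counts 8 times the FLOPs.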

handle_ragged_dot(op_info)

Ragged dot introduces a group dimension for the ragged dimension. For example, in mode 1 the ragged dimension is M. It is split into G groups, and each value of the group-sizes tensor gives the number of rows of M assigned to the corresponding group. There are three modes, and each treats a different dimension (m, k, or b, respectively) as the ragged one:

  • Mode 1: apply the dot_general roofline.

  • Mode 2: assume an average contracting size of k = K/G per group. This is the only mode in which raggedness can reduce the FLOP count.

  • Mode 3: assume all batches are used, so apply the dot_general roofline.

Signatures for modes (lhs, rhs, group sizes -> result):

  • 1 [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]

  • 2 [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]

  • 3 [b,m,k], [b,k,n], [g] -> [b,m,n]
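The per-mode rules above can be sketched as follows. The sketch assumes the sizes b, m, k, n and g have already been read from the operand shapes in the signatures. ragged_dot_flops is a hypothetical helper, not the estimator's code. Mode 2 prices each of the G result slices with the average contracting size K/G.

```python
def ragged_dot_flops(mode: int, b: int, m: int, k: int, n: int, g: int) -> float:
    """FLOPs for ragged dot under the per-mode assumptions described above."""
    if mode in (1, 3):
        # Ragged m (mode 1) or ragged b (mode 3): same as dense dot_general, result [b,m,n].
        return float(2 * b * m * n * k)
    if mode == 2:
        # Ragged k: result is [g,b,m,n]; each group contracts over K/G on average.
        avg_k = k / g
        return 2 * (g * b * m * n) * avg_k
    raise ValueError(f"unknown ragged dot mode: {mode}")


for mode in (1, 2, 3):
    print(mode, ragged_dot_flops(mode, b=2, m=16, k=64, n=8, g=4))
```

With b=2, m=16, k=64, n=8, g=4 all three modes give 32768 FLOPs. A dense dot_general over the mode-2 result shape [g,b,m,n] with the full K contraction would count 2 x g x b x m x n x K = 131072, which is G times more. That gap is the reduction the mode-2 rule captures.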

handle_slice_ops(op_info)
handle_dynamic_update_slice(op_info)
handle_fusion(op_info)
handle_custom_call(op_info)
_cache_hit_hooks = [<function Estimator.__count_cache_hits>, <function Estimator.__get_cached_module_times>, <function Estimator.__print_cached_runtime>]
_cache_miss_hooks = []
_default_op_handler(op_info: OpInfo) → OpResult

Default operator handler for descendant classes: it raises an exception if the operator is not known. This behavior can be overridden through the @register_default_op_handler decorator, but not directly. This method should not be overridden or called manually.

Parameters:

op_info – The operator whose time would be estimated. This method does not perform any estimation; it only raises an exception.

Returns:

The estimator result for this operator. The return type only matches that of the other operator handlers; this method never returns.

Raises:

InvalidOpError – Raises an InvalidOpError for any unknown operator
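The lookup-with-fallback behavior described above can be sketched as a plain dictionary dispatch, keyed by operator name like _op_handlers. This is an illustration under stated assumptions, not hespas's implementation. OP_HANDLERS, estimate and the string-based handler signature are hypothetical. The handlers return a bare float instead of an OpResult, and InvalidOpError is redefined locally. Passing a different `default` stands in for what @register_default_op_handler enables; that decorator's real signature is not shown on this page.

```python
from typing import Callable, Dict


class InvalidOpError(Exception):
    """Local stand-in for the InvalidOpError named above (hypothetical definition)."""


Handler = Callable[[str], float]


def default_op_handler(op_name: str) -> float:
    # Mirrors the documented behavior: unknown operators raise, nothing is returned.
    raise InvalidOpError(f"no handler registered for {op_name!r}")


# Hypothetical dispatch table keyed by operator name.
OP_HANDLERS: Dict[str, Handler] = {
    "stablehlo.reshape": lambda op_name: 0.0,  # placeholder cost for a "free" op
}


def estimate(op_name: str, default: Handler = default_op_handler) -> float:
    """Look up the operator's handler, falling back to the default one."""
    return OP_HANDLERS.get(op_name, default)(op_name)


print(estimate("stablehlo.reshape"))
```

Keeping the fallback separate from the table means unknown operators fail loudly by default. A subclass can still change that policy in one place without touching the per-operator entries.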

_init_hooks = [<function RooflineEstimator.__setup_per_datatype_flops>, <function RooflineEstimator.__setup_roofline_stats>, <function RooflineEstimator.__setup_flops_per_element_map>, <function RooflineEstimator.__setup_seen_unknown_custom_kernel>]
_metadata_hooks = []
_module_metadata_hooks = []
_op_handlers = {'func.call': <function RooflineEstimator.handle_free_ops>, 'func.return': <function RooflineEstimator.handle_free_ops>, 'mhlo.bitcast': <function RooflineEstimator.handle_free_ops>, 'mhlo.copy': <function RooflineEstimator.handle_noflop_ops>, 'mhlo.fusion': <function RooflineEstimator.handle_fusion>, 'mhlo.ragged_dot': <function RooflineEstimator.handle_ragged_dot>, 'mhlo.return': <function RooflineEstimator.handle_free_ops>, 'stablehlo.abs': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.add': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.and': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.broadcast_in_dim': <function RooflineEstimator.handle_noflop_ops>, 'stablehlo.clamp': <function RooflineEstimator.handle_clamp>, 'stablehlo.compare': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.complex': <function RooflineEstimator.handle_free_ops>, 'stablehlo.concatenate': <function RooflineEstimator.handle_concatenate>, 'stablehlo.constant': <function RooflineEstimator.handle_free_ops>, 'stablehlo.convert': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.convolution': <function RooflineEstimator.handle_convolution>, 'stablehlo.cosine': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.custom_call': <function RooflineEstimator.handle_custom_call>, 'stablehlo.divide': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.dot': <function RooflineEstimator.handle_dot_general>, 'stablehlo.dot_general': <function RooflineEstimator.handle_dot_general>, 'stablehlo.dynamic_slice': <function RooflineEstimator.handle_slice_ops>, 'stablehlo.dynamic_update_slice': <function RooflineEstimator.handle_dynamic_update_slice>, 'stablehlo.exponential': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.gather': <function RooflineEstimator.handle_gather>, 'stablehlo.get_tuple_element': <function RooflineEstimator.handle_free_ops>, 'stablehlo.imag': <function RooflineEstimator.handle_free_ops>, 'stablehlo.iota': <function RooflineEstimator.handle_free_ops>, 'stablehlo.is_finite': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.log': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.logistic': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.maximum': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.minimum': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.multiply': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.negate': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.not': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.optimization_barrier': <function RooflineEstimator.handle_free_ops>, 'stablehlo.or': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.pad': <function RooflineEstimator.handle_noflop_ops>, 'stablehlo.partition_id': <function RooflineEstimator.handle_free_ops>, 'stablehlo.power': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.real': <function RooflineEstimator.handle_free_ops>, 'stablehlo.reduce': <function RooflineEstimator.handle_reduce>, 'stablehlo.reduce_precision': <function RooflineEstimator.handle_noflop_ops>, 'stablehlo.reduce_window': <function RooflineEstimator.handle_reduce_window>, 'stablehlo.replica_id': <function RooflineEstimator.handle_free_ops>, 'stablehlo.reshape': <function RooflineEstimator.handle_free_ops>, 'stablehlo.return': <function RooflineEstimator.handle_free_ops>, 'stablehlo.reverse': <function RooflineEstimator.handle_noflop_ops>, 'stablehlo.round_nearest_even': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.rsqrt': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.scatter': <function RooflineEstimator.handle_scatter>, 'stablehlo.select': <function RooflineEstimator.handle_select>, 'stablehlo.select_and_scatter': <function RooflineEstimator.handle_select_and_scatter>, 'stablehlo.sign': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.sine': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.slice': <function RooflineEstimator.handle_slice_ops>, 'stablehlo.sort': <function RooflineEstimator.handle_sort>, 'stablehlo.sqrt': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.subtract': <function RooflineEstimator.handle_elementwise_binary>, 'stablehlo.tanh': <function RooflineEstimator.handle_unary_elemwise>, 'stablehlo.transpose': <function RooflineEstimator.handle_free_ops>, 'stablehlo.xor': <function RooflineEstimator.handle_elementwise_binary>}
_post_estimate_hooks = [<function Estimator.__setup_per_op_tree>, <function Estimator.__get_total_estimate_time>, <function Estimator.__get_total_runtime>, <function Estimator.__count_processed>, <function Estimator.__get_per_op_runtime>, <function RooflineEstimator.__get_total_bytes_flops>]
_post_op_hooks = [<function Estimator.__get__module_op_times>, <function RooflineEstimator.__get_per_op_stats>]
_post_run_hooks = [<function Estimator.__get_module_runtime>, <function Estimator.__module_run_end_time>, <function Estimator.__print_run_runtime>, <function RooflineEstimator.__get_module_bytes_flops>, <function RooflineEstimator.__add_kernel_launch_overhead>]
_pre_estimate_hooks = [<function Estimator.__setup_per_module_stat_tree>, <function Estimator.__total_estimate_start_time>, <function Estimator.__print_start_line>, <function RooflineEstimator.__setup_module_roofline_stats>]
_pre_op_hooks = [<function Estimator.__setup_per_op_stat_tree>, <function RooflineEstimator.__setup_per_op_roofline_stats>]
_pre_run_hooks = [<function Estimator.__module_run_start_time>]
bases_order = {'': 0}
config_arguments = {'cache_dir': ('cache_dir', <hespas.estimator.config_option.ConfigOption object>), 'disable_cache': ('disable_cache', <hespas.estimator.config_option.ConfigOption object>), 'error_on_unknown_type': ('error_on_unknown_type', <hespas.estimator.config_option.ConfigOption object>), 'flops_per_element': ('flops_per_element', <hespas.estimator.config_option.ConfigOption object>), 'hbm_power_ratio': ('hbm_power_ratio', <hespas.estimator.config_option.ConfigOption object>), 'in_memory_only_cache': ('in_memory_only_cache', <hespas.estimator.config_option.ConfigOption object>), 'kernel_launch_overhead_s': ('kernel_launch_overhead_s', <hespas.estimator.config_option.ConfigOption object>), 'memory_bandwidth': ('memory_bandwidth', <hespas.estimator.config_option.ConfigOption object>), 'memory_compute_parallelism': ('memory_compute_parallelism', <hespas.estimator.config_option.ConfigOption object>), 'num_npus': ('num_npus', <hespas.estimator.config_option.ConfigOption object>), 'peak_flops': ('peak_flops', <hespas.estimator.config_option.ConfigOption object>), 'per_datatype_flops': ('per_datatype_flops', <hespas.estimator.config_option.ConfigOption object>), 'tdp_W': ('tdp_W', <hespas.estimator.config_option.ConfigOption object>), 'type': ('type', <hespas.estimator.config_option.ConfigOption object>), 'warn_on_unknown_type': ('warn_on_unknown_type', <hespas.estimator.config_option.ConfigOption object>)}
config_options = {'cache_dir': <hespas.estimator.config_option.ConfigOption object>, 'disable_cache': <hespas.estimator.config_option.ConfigOption object>, 'error_on_unknown_type': <hespas.estimator.config_option.ConfigOption object>, 'flops_per_element': <hespas.estimator.config_option.ConfigOption object>, 'hbm_power_ratio': <hespas.estimator.config_option.ConfigOption object>, 'in_memory_only_cache': <hespas.estimator.config_option.ConfigOption object>, 'kernel_launch_overhead_s': <hespas.estimator.config_option.ConfigOption object>, 'memory_bandwidth': <hespas.estimator.config_option.ConfigOption object>, 'memory_compute_parallelism': <hespas.estimator.config_option.ConfigOption object>, 'num_npus': <hespas.estimator.config_option.ConfigOption object>, 'peak_flops': <hespas.estimator.config_option.ConfigOption object>, 'per_datatype_flops': <hespas.estimator.config_option.ConfigOption object>, 'tdp_W': <hespas.estimator.config_option.ConfigOption object>, 'type': <hespas.estimator.config_option.ConfigOption object>, 'warn_on_unknown_type': <hespas.estimator.config_option.ConfigOption object>}
display_name = 'roofline'
display_name_map = {'': <class 'hespas.estimator.estimator.Estimator'>}