"""Common utilities that are useful for all the benchmarks."""
import numpy as np

import mlir.all_passes_registration

from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager


def setup_passes(mlir_module):
    """Run the `sparse-compiler` pass pipeline on `mlir_module`.

    The pipeline options are fixed by this function (both strategy options
    are passed as 0, `vl` as 1 and `enable-simd-index32` as False). Nothing
    is returned; the pass manager is run directly on the given module.
    """
    opt = (
        "parallelization-strategy=0"
        " vectorization-strategy=0 vl=1 enable-simd-index32=False"
    )
    pipeline = f"sparse-compiler{{{opt}}}"
    PassManager.parse(pipeline).run(mlir_module)


def create_sparse_np_tensor(dimensions, number_of_elements):
    """Construct a dense float64 numpy tensor with a fixed number of nonzeros.

    Args:
        dimensions: Iterable of dimension sizes giving the tensor shape.
        number_of_elements: Number of entries to set to a nonzero value.

    Returns:
        A numpy array of shape `dimensions` that is zero everywhere except at
        `number_of_elements` distinct, randomly chosen positions, each of
        which holds a value drawn uniformly from [1, 100). If more nonzeros
        are requested than the tensor has entries, every entry is filled.
    """
    tensor = np.zeros(dimensions, np.float64)
    # Sampling positions independently would let two draws land on the same
    # entry and silently yield fewer nonzeros than requested, so distinct
    # flat positions are drawn without replacement instead.
    count = min(number_of_elements, tensor.size)
    # Guard the empty case: older numpy versions reject choice() on an empty
    # population even when no samples are requested.
    if count > 0:
        flat_positions = np.random.choice(
            tensor.size, size=count, replace=False)
        tensor.flat[flat_positions] = np.random.uniform(1, 100, size=count)
    return tensor


def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
    """Return the single operation contained in an mlir module.

    This only works for a module with exactly one region, one block, and one
    operation; each of these conditions is checked with an assertion. The
    checks use `assert`, so they are skipped when Python runs with `-O`.
    """
    regions = module.operation.regions
    assert len(regions) == 1, \
        "Expected kernel module to have only one region"
    blocks = regions[0].blocks
    assert len(blocks) == 1, \
        "Expected kernel module to have only one block"
    operations = blocks[0].operations
    assert len(operations) == 1, \
        "Expected kernel module to have only one operation"
    return operations[0]


def emit_timer_func() -> func.FuncOp:
    """Return a private declaration of the `nanoTime` function.

    The declared function takes no arguments and returns one signless i64.
    If nanoTime function is used, the `MLIR_RUNNER_UTILS` and
    `MLIR_C_RUNNER_UTILS` must be included.
    """
    i64_type = ir.IntegerType.get_signless(64)
    timer_decl = func.FuncOp(
        "nanoTime", ([], [i64_type]), visibility="private")
    timer_decl.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return timer_decl


def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
    """Build a public `main` function that times repeated kernel calls.

    Takes a function and a timer function, both represented as FuncOp
    objects, and returns a new function. The new function has the arguments
    of `kernel_func` plus one trailing 1-D memref of i64, and the same result
    types as `kernel_func`.

    Inside a loop, the call to `kernel_func` is wrapped between two calls to
    `timer_func`, and the difference of the two timer values is stored into
    the trailing memref at the current iteration index. The loop runs once
    per element of that memref (its size along dimension 0), not once per
    kernel result.

    The last `len(kernel_func.type.results)` kernel arguments are threaded
    through the loop as loop-carried values: each iteration passes them to
    the kernel and yields the kernel results as the values for the next
    iteration. The final loop results are returned from `main`.

    This function can be used to create a "time measuring" variant of a
    function.
    """
    i64_type = ir.IntegerType.get_signless(64)
    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
    wrapped_func = func.FuncOp(
        # Same signature and an extra i64 buffer to save timings.
        "main",
        (kernel_func.arguments.types + [memref_of_i64_type],
         kernel_func.type.results),
        visibility="public"
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    num_results = len(kernel_func.type.results)
    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # The trip count is the length of the timing buffer, so each
        # iteration has exactly one slot to record its measurement.
        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # The arguments just before the timing buffer seed the loop-carried
        # values.
        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = func.CallOp(timer_func, [])
            call = func.CallOp(
                kernel_func,
                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
            )
            end = func.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            scf.YieldOp(list(call.results))
        func.ReturnOp(loop)

    return wrapped_func