"""This file contains benchmarks for sparse tensors. In particular, it
contains benchmarks for both mlir sparse tensor dialect and numpy so that they
can be compared against each other.
"""
import ctypes
import numpy as np
import os
import re
import time

from mlir import ir
from mlir import runtime as rt
from mlir.dialects import func
from mlir.dialects.linalg.opdsl import lang as dsl
from mlir.execution_engine import ExecutionEngine

from common import create_sparse_np_tensor
from common import emit_timer_func
from common import emit_benchmark_wrapped_main_func
from common import get_kernel_func_from_module
from common import setup_passes


# Structured-op definition of C[m, n] += A[m, k] * B[k, n]; it is used as the
# kernel body by `benchmark_sparse_mlir_multiplication` below.
@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
):
    """Helper function for mlir sparse matrix multiplication benchmark."""
    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]


def benchmark_sparse_mlir_multiplication():
    """Benchmark for mlir sparse matrix multiplication.

    Because it's an MLIR benchmark we need to return both a `compiler`
    function and a `runner` function.

    Returns:
        A `(compiler, runner)` pair. `compiler()` builds the benchmark module
        and returns the execution engine's `invoke` callable; `runner(invoke)`
        calls the compiled `main` once and returns the value written to the
        timer buffer as an `int`.
    """
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        f64 = ir.F64Type.get()
        # (1000 x 1500) * (1500 x 2000) -> (1000 x 2000), all f64.
        param1_type = ir.RankedTensorType.get([1000, 1500], f64)
        param2_type = ir.RankedTensorType.get([1500, 2000], f64)
        result_type = ir.RankedTensorType.get([1000, 2000], f64)
        with ir.InsertionPoint(module.body):
            @func.FuncOp.from_py_func(param1_type, param2_type, result_type)
            def sparse_kernel(x, y, z):
                return matmul_dsl(x, y, outs=[z])

    def compiler():
        """Build the timed benchmark module and return `engine.invoke`.

        Raises:
            AssertionError: if MLIR_C_RUNNER_UTILS or MLIR_RUNNER_UTILS does
                not name an existing path.
        """
        with ir.Context(), ir.Location.unknown():
            kernel_func = get_kernel_func_from_module(module)
            timer_func = emit_timer_func()
            wrapped_func = emit_benchmark_wrapped_main_func(
                kernel_func,
                timer_func
            )
            # The three functions are stitched together textually and
            # re-parsed into a single module.
            main_module_with_benchmark = ir.Module.parse(
                str(timer_func) + str(wrapped_func) + str(kernel_func)
            )
            setup_passes(main_module_with_benchmark)
            # An unset variable defaults to "", which fails the existence
            # check below.
            c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
            assert os.path.exists(c_runner_utils),\
                f"{c_runner_utils} does not exist." \
                f" Please pass a valid value for" \
                f" MLIR_C_RUNNER_UTILS environment variable."
            runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
            assert os.path.exists(runner_utils),\
                f"{runner_utils} does not exist." \
                f" Please pass a valid value for MLIR_RUNNER_UTILS" \
                f" environment variable."

            # NOTE(review): the positional `3` is presumably the optimization
            # level of the MLIR ExecutionEngine binding -- confirm.
            engine = ExecutionEngine(
                main_module_with_benchmark,
                3,
                shared_libs=[c_runner_utils, runner_utils]
            )
            return engine.invoke

    def runner(engine_invoke):
        """Invoke the compiled `main` once and return the recorded time.

        Args:
            engine_invoke: the callable returned by `compiler()`.

        Returns:
            The first element of the timer buffer after the call, as `int`.
        """
        compiled_program_args = []
        # NOTE(review): the leading `result_type` entry is presumably the
        # slot for the kernel's returned tensor, followed by the three kernel
        # operands -- confirm against the wrapper emitted in `common`.
        for argument_type in [
            result_type, param1_type, param2_type, result_type
        ]:
            # Recover the dimensions from the printed type: strip "tensor",
            # "<" and ">", split on "x", and drop the trailing element-type
            # token.
            argument_type_str = str(argument_type)
            dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
            dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
            if argument_type == result_type:
                argument = np.zeros(dimensions, np.float64)
            else:
                argument = create_sparse_np_tensor(dimensions, 1000)
            compiled_program_args.append(
                ctypes.pointer(
                    ctypes.pointer(rt.get_ranked_memref_descriptor(argument))
                )
            )
        # One-element int64 buffer passed as the last argument; its first
        # element is read back as the result after the call.
        np_timers_ns = np.array([0], dtype=np.int64)
        compiled_program_args.append(
            ctypes.pointer(
                ctypes.pointer(rt.get_ranked_memref_descriptor(np_timers_ns))
            )
        )
        engine_invoke("main", *compiled_program_args)
        return int(np_timers_ns[0])

    return compiler, runner


def benchmark_np_matrix_multiplication():
    """Benchmark for numpy matrix multiplication.

    Because it's a python benchmark, we don't have any `compiler` function
    returned. We just return the `runner` function.

    Returns:
        A `(None, runner)` pair; `runner()` returns the elapsed time of one
        `np.matmul` call in nanoseconds.
    """
    def runner():
        """Time a single (1000 x 1500) by (1500 x 2000) `np.matmul`."""
        argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
        argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))
        # Use the monotonic high-resolution counter rather than the wall
        # clock, so system clock adjustments cannot distort the interval.
        start_time = time.perf_counter_ns()
        np.matmul(argument1, argument2)
        return time.perf_counter_ns() - start_time

    return None, runner