1"""This file contains benchmarks for sparse tensors. In particular, it
2contains benchmarks for both mlir sparse tensor dialect and numpy so that they
3can be compared against each other.
4"""
5import ctypes
6import numpy as np
7import os
8import re
9import time
10
11from mlir import ir
12from mlir import runtime as rt
13from mlir.dialects import func
14from mlir.dialects.linalg.opdsl import lang as dsl
15from mlir.execution_engine import ExecutionEngine
16
17from common import create_sparse_np_tensor
18from common import emit_timer_func
19from common import emit_benchmark_wrapped_main_func
20from common import get_kernel_func_from_module
21from common import setup_passes
22
23
@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
    """Helper function for mlir sparse matrix multiplication benchmark."""
    # NOTE(review): the docstring above is intentionally left verbatim; the
    # opdsl decorator appears to record it as the op's documentation, so
    # confirm before rewording it.
    #
    # Operand shapes are symbolic: A is (M x K), B is (K x N) and C, the
    # operand flagged `output=True`, is (M x N).
    m, n, k = dsl.D.m, dsl.D.n, dsl.D.k
    # Accumulate A[m, k] * B[k, n] into C[m, n]; `k` appears only on the
    # right-hand side.
    C[m, n] += A[m, k] * B[k, n]
32
33
def benchmark_sparse_mlir_multiplication():
    """Build the MLIR sparse matrix multiplication benchmark.

    Because it's an MLIR benchmark, two callables are returned:

    * ``compiler()`` assembles the benchmark module (timer function, wrapped
      main function and the kernel), runs ``setup_passes`` on it and returns
      the ``invoke`` method of a freshly created ``ExecutionEngine``.
    * ``runner(engine_invoke)`` builds the input buffers, invokes ``main``
      once and returns, as an ``int``, the value the compiled program left in
      the one-element timer buffer (nanoseconds, judging by the buffer name).
    """
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        f64 = ir.F64Type.get()
        # (1000 x 1500) times (1500 x 2000) gives (1000 x 2000).
        lhs_type = ir.RankedTensorType.get([1000, 1500], f64)
        rhs_type = ir.RankedTensorType.get([1500, 2000], f64)
        out_type = ir.RankedTensorType.get([1000, 2000], f64)
        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(lhs_type, rhs_type, out_type)
            def sparse_kernel(x, y, z):
                return matmul_dsl(x, y, outs=[z])

    def compiler():
        """Compile the benchmark and return the engine's ``invoke`` method."""
        with ir.Context(), ir.Location.unknown():
            kernel = get_kernel_func_from_module(module)
            timer = emit_timer_func()
            wrapper = emit_benchmark_wrapped_main_func(kernel, timer)
            # The three functions are printed and re-parsed as a single
            # module, in the order timer, wrapper, kernel.
            benchmark_module = ir.Module.parse(
                "".join(str(fn) for fn in (timer, wrapper, kernel))
            )
            setup_passes(benchmark_module)

            # Both runtime support libraries are located through environment
            # variables. NOTE: these checks are `assert` statements, so they
            # are skipped when Python runs with -O.
            shared_libs = []
            for env_name in ("MLIR_C_RUNNER_UTILS", "MLIR_RUNNER_UTILS"):
                lib_path = os.getenv(env_name, "")
                assert os.path.exists(lib_path), (
                    f"{lib_path} does not exist."
                    f" Please pass a valid value for {env_name}"
                    " environment variable."
                )
                shared_libs.append(lib_path)

            engine = ExecutionEngine(
                benchmark_module,
                3,
                shared_libs=shared_libs,
            )
            return engine.invoke

    def runner(engine_invoke):
        """Invoke the compiled ``main`` once and return the timer value."""

        def as_invoke_arg(array):
            # Each argument is passed as a pointer to a pointer to the
            # array's ranked memref descriptor.
            descriptor = rt.get_ranked_memref_descriptor(array)
            return ctypes.pointer(ctypes.pointer(descriptor))

        invoke_args = []
        # A result-shaped buffer comes first, followed by the two operands
        # and the output tensor. Presumably the leading buffer receives the
        # kernel's return value; confirm against the wrapped main function.
        for tensor_type in (out_type, lhs_type, rhs_type, out_type):
            # Recover the static dimensions from the printed type, e.g.
            # "tensor<1000x1500xf64>" -> [1000, 1500]; the trailing element
            # type token is discarded.
            stripped = re.sub("<|>|tensor", "", str(tensor_type))
            *dim_tokens, _element_type = stripped.split("x")
            shape = [int(token) for token in dim_tokens]
            if tensor_type == out_type:
                array = np.zeros(shape, np.float64)
            else:
                # The 1000 is forwarded as-is; presumably the number of
                # non-zero entries (verify in common.create_sparse_np_tensor).
                array = create_sparse_np_tensor(shape, 1000)
            invoke_args.append(as_invoke_arg(array))

        # One-element int64 buffer, passed last; its first element is read
        # back after the call as the measured time.
        timer_buffer = np.array([0], dtype=np.int64)
        invoke_args.append(as_invoke_arg(timer_buffer))

        engine_invoke("main", *invoke_args)
        return int(timer_buffer[0])

    return compiler, runner
107
108
def benchmark_np_matrix_multiplication(rows=1000, inner=1500, cols=2000):
    """Benchmark for numpy dense matrix multiplication.

    Because it's a pure Python benchmark there is no compilation phase, so
    ``None`` is returned in place of a ``compiler`` function.

    Args:
        rows: Number of rows of the left operand.
        inner: Number of columns of the left operand, which is also the
            number of rows of the right operand.
        cols: Number of columns of the right operand.

        The defaults multiply a 1000x1500 matrix by a 1500x2000 matrix.

    Returns:
        A ``(None, runner)`` tuple. ``runner()`` multiplies two freshly
        generated random matrices and returns the elapsed time of the
        multiplication in nanoseconds as an ``int``.
    """
    def runner():
        # Operand generation happens before the clock starts, so only the
        # multiplication itself is timed.
        lhs = np.random.uniform(low=0.0, high=100.0, size=(rows, inner))
        rhs = np.random.uniform(low=0.0, high=100.0, size=(inner, cols))
        # perf_counter_ns is monotonic; time_ns follows the wall clock,
        # which can be adjusted mid-run and produce skewed or even negative
        # durations.
        start_ns = time.perf_counter_ns()
        np.matmul(lhs, rhs)
        return time.perf_counter_ns() - start_ns

    return None, runner
122