1"""Common utilities that are useful for all the benchmarks."""
2import numpy as np
3
4import mlir.all_passes_registration
5
6from mlir import ir
7from mlir.dialects import arith
8from mlir.dialects import func
9from mlir.dialects import memref
10from mlir.dialects import scf
11from mlir.passmanager import PassManager
12
13
def setup_passes(mlir_module):
    """Run the sparse-compiler pass pipeline over `mlir_module`.

    The pipeline is parsed with a fixed set of benchmark options: both the
    parallelization and vectorization strategies set to 0, a vector length
    (`vl`) of 1, and `enable-simd-index32` turned off. The pass manager
    operates directly on the given module; nothing is returned.
    """
    option_list = [
        "parallelization-strategy=0",
        "vectorization-strategy=0",
        "vl=1",
        "enable-simd-index32=False",
    ]
    options = " ".join(option_list)
    pass_manager = PassManager.parse("sparse-compiler{" + options + "}")
    pass_manager.run(mlir_module)
23
24
def create_sparse_np_tensor(dimensions, number_of_elements):
    """Construct a float64 numpy tensor with exactly `number_of_elements` nonzeros.

    Args:
      dimensions: shape of the tensor (any shape accepted by `np.zeros`).
      number_of_elements: how many entries to make nonzero. Positions are
        drawn uniformly at random *without replacement*, so the result has
        exactly this many nonzero entries.

    Returns:
      A `np.ndarray` of shape `dimensions` and dtype float64. The chosen
      entries hold values drawn from `np.random.uniform(1, 100)`; all other
      entries are 0. The global `np.random` state is used, so seeding it with
      `np.random.seed` makes the result reproducible.

    Raises:
      ValueError: if `number_of_elements` is negative or larger than the
        total number of entries in the tensor.
    """
    tensor = np.zeros(dimensions, np.float64)
    if not 0 <= number_of_elements <= tensor.size:
        raise ValueError(
            f"number_of_elements must be in [0, {tensor.size}], "
            f"got {number_of_elements}")
    if number_of_elements == 0:
        return tensor
    # Sample distinct flat positions: independent per-axis draws can pick the
    # same entry twice, which would leave fewer nonzeros than requested.
    flat_positions = np.random.choice(
        tensor.size, number_of_elements, replace=False)
    # uniform(1, 100) yields values in [1, 100), so every chosen entry is
    # genuinely nonzero.
    tensor[np.unravel_index(flat_positions, tensor.shape)] = (
        np.random.uniform(1, 100, size=number_of_elements))
    return tensor
41
42
def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
    """Return the single operation nested inside `module`.

    The module's operation is expected to have exactly one region, that
    region exactly one block, and that block exactly one operation (the
    kernel function). Any other layout fails an assertion. These checks are
    plain `assert` statements, so they are skipped when Python runs with -O.
    """
    regions = module.operation.regions
    assert len(regions) == 1, (
        "Expected kernel module to have only one region")
    blocks = regions[0].blocks
    assert len(blocks) == 1, (
        "Expected kernel module to have only one block")
    operations = blocks[0].operations
    assert len(operations) == 1, (
        "Expected kernel module to have only one operation")
    return operations[0]
55
56
def emit_timer_func() -> func.FuncOp:
    """Build a private declaration of the external `nanoTime` function.

    The declared signature takes no arguments and returns a signless i64, and
    the op carries the `llvm.emit_c_interface` unit attribute. Programs that
    call `nanoTime` must include `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS`.
    NOTE(review): like any op builder, this presumably needs an active MLIR
    Context/Location (and insertion point) when called.
    """
    result_types = [ir.IntegerType.get_signless(64)]
    timer_decl = func.FuncOp("nanoTime", ([], result_types),
                             visibility="private")
    timer_decl.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return timer_decl
66
67
def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
    """Build a public `main` function that repeatedly calls and times a kernel.

    The returned `main` takes all of `kernel_func`'s arguments plus one extra
    trailing argument, a `memref<?xi64>` timing buffer. It returns the same
    result types as `kernel_func`. Its body is an `scf.for` loop from 0 to the
    size of dimension 0 of the timing buffer, i.e. the buffer's length sets
    the number of iterations (not the number of kernel results). Each
    iteration:
      1. calls `timer_func` to get a start value,
      2. calls `kernel_func`,
      3. calls `timer_func` again and stores `end - start` into the timing
         buffer at the current iteration index.

    The last `len(kernel_func.type.results)` kernel arguments, the ones just
    before the timing buffer, are carried through the loop as iteration
    arguments. Each iteration passes the previous iteration's kernel results
    in their place, and `main` returns the values yielded by the final
    iteration. If the buffer is empty the loop does not run and those initial
    arguments are returned. With zero kernel results, nothing is
    loop-carried and all kernel arguments are passed through unchanged.
    NOTE(review): this assumes those trailing arguments match the kernel's
    results one-to-one in type; confirm for each kernel.

    Args:
      kernel_func: the `func.FuncOp` to benchmark.
      timer_func: a `func.FuncOp` taking no arguments and returning an i64
        (for example the declaration built by `emit_timer_func`).

    Returns:
      The new `func.FuncOp` named "main", with public visibility and the
      `llvm.emit_c_interface` attribute.
    """
    i64_type = ir.IntegerType.get_signless(64)
    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
    wrapped_func = func.FuncOp(
        # Kernel argument types plus a trailing dynamically sized i64 buffer
        # that receives one timing per loop iteration; same results as the
        # kernel.
        "main",
        (kernel_func.arguments.types + [memref_of_i64_type],
         kernel_func.type.results),
        visibility="public"
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    num_results = len(kernel_func.type.results)
    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # Trip count = length of the timing buffer.
        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # The `num_results` kernel arguments immediately before the timing
        # buffer become loop-carried values (empty when num_results == 0).
        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = func.CallOp(timer_func, [])
            # Leading (non-carried) kernel arguments, followed by the current
            # values of the loop-carried ones.
            call = func.CallOp(
                kernel_func,
                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
            )
            end = func.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            # Feed this call's results into the next iteration.
            scf.YieldOp(list(call.results))
        # Return the loop's results: the values yielded by the last iteration.
        func.ReturnOp(loop)

    return wrapped_func
108