1fa90c9d5SSaurabh Jha"""Common utilities that are useful for all the benchmarks."""
2fa90c9d5SSaurabh Jhaimport numpy as np
3fa90c9d5SSaurabh Jha
4fa90c9d5SSaurabh Jhafrom mlir import ir
5fa90c9d5SSaurabh Jhafrom mlir.dialects import arith
623aa5a74SRiver Riddlefrom mlir.dialects import func
7fa90c9d5SSaurabh Jhafrom mlir.dialects import memref
8fa90c9d5SSaurabh Jhafrom mlir.dialects import scf
9fa90c9d5SSaurabh Jhafrom mlir.passmanager import PassManager
10fa90c9d5SSaurabh Jha
11fa90c9d5SSaurabh Jha
def setup_passes(mlir_module):
    """Run the `sparse-compiler` pass pipeline over `mlir_module` in place.

    The pipeline is configured with parallelization-strategy=0,
    vectorization-strategy=0, vl=1 and enable-simd-index32=False.
    """
    pipeline_options = " ".join([
        "parallelization-strategy=0",
        "vectorization-strategy=0",
        "vl=1",
        "enable-simd-index32=False",
    ])
    pass_manager = PassManager.parse("sparse-compiler{" + pipeline_options + "}")
    pass_manager.run(mlir_module)
21fa90c9d5SSaurabh Jha
22fa90c9d5SSaurabh Jha
def create_sparse_np_tensor(dimensions, number_of_elements):
    """Constructs a numpy tensor of dimensions `dimensions` that has exactly
    `number_of_elements` nonzero elements.

    The nonzero positions are drawn without replacement, so no two draws land
    on the same entry. Each nonzero value is drawn uniformly from [1, 100).
    Both draws use numpy's global random state (``np.random``), so seeding it
    with ``np.random.seed`` makes the result reproducible.

    Args:
        dimensions: shape of the tensor, as accepted by ``np.zeros``.
        number_of_elements: how many entries to make nonzero.

    Returns:
        A float64 ``np.ndarray`` of shape ``dimensions``.

    Raises:
        ValueError: if ``number_of_elements`` is negative or exceeds the total
            number of entries in the tensor.
    """
    tensor = np.zeros(dimensions, np.float64)
    if not 0 <= number_of_elements <= tensor.size:
        raise ValueError(
            f"Cannot place {number_of_elements} nonzero elements in a tensor "
            f"with {tensor.size} entries"
        )
    # Sampling distinct flat positions (instead of one random index per
    # dimension, which can repeat) guarantees exactly `number_of_elements`
    # nonzeros.
    flat_positions = np.random.choice(
        tensor.size, number_of_elements, replace=False)
    # Values lie in [1, 100), so none of them is zero.
    tensor.flat[flat_positions] = np.random.uniform(
        1, 100, size=number_of_elements)
    return tensor
39fa90c9d5SSaurabh Jha
40fa90c9d5SSaurabh Jha
def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
    """Takes an mlir module object and extracts the function object out of it.

    This function only works for a module with one region, one block, and one
    operation; that single top-level operation is returned.

    Args:
        module: the kernel module to unpack.

    Returns:
        The only operation in the module's only block.

    Raises:
        ValueError: if the module does not have exactly one region, exactly
            one block in that region, and exactly one operation in that block.
    """
    # Explicit raises (rather than `assert`) keep these checks active under
    # `python -O`.
    regions = module.operation.regions
    if len(regions) != 1:
        raise ValueError("Expected kernel module to have only one region")
    blocks = regions[0].blocks
    if len(blocks) != 1:
        raise ValueError("Expected kernel module to have only one block")
    operations = blocks[0].operations
    if len(operations) != 1:
        raise ValueError("Expected kernel module to have only one operation")
    return operations[0]
53fa90c9d5SSaurabh Jha
54fa90c9d5SSaurabh Jha
def emit_timer_func() -> func.FuncOp:
    """Create and return a private, body-less declaration of ``nanoTime``.

    The declaration takes no arguments, returns a single signless i64, and
    carries the ``llvm.emit_c_interface`` unit attribute. Whenever it is
    called, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` libraries must
    be included.
    """
    timer_decl = func.FuncOp(
        "nanoTime",
        ([], [ir.IntegerType.get_signless(64)]),
        visibility="private",
    )
    timer_decl.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return timer_decl
64fa90c9d5SSaurabh Jha
65fa90c9d5SSaurabh Jha
def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
    """Build a public ``main`` function that times repeated calls to a kernel.

    Both `kernel_func` and `timer_func` are FuncOp objects. The generated
    function takes all of `kernel_func`'s arguments followed by one extra
    ``memref<?xi64>`` argument (the timer buffer), and returns the same result
    types as `kernel_func`.

    Its body is an ``scf.for`` loop that runs once per element of the timer
    buffer (the upper bound is ``memref.dim`` of the buffer along dimension
    0). Each iteration does the following:
      1. calls `timer_func`;
      2. calls `kernel_func`;
      3. calls `timer_func` again;
      4. stores the difference of the two timer results into the timer
         buffer at the current loop index.

    The kernel's results are threaded through the loop as loop-carried
    values. The last ``len(kernel_func.type.results)`` kernel arguments supply
    the initial values. Each iteration passes the current values in place of
    those arguments. The loop's final values are returned; if the timer
    buffer is empty, these are the initial arguments. Because ``scf.for``
    yields the call results as the next loop-carried values, those trailing
    argument types must match the kernel's result types.

    This function can be used to create a "time measuring" variant of a
    function.

    Args:
        kernel_func: the kernel FuncOp to benchmark.
        timer_func: a FuncOp taking no arguments and returning one i64
            (e.g. the declaration built by `emit_timer_func`).

    Returns:
        The newly created ``main`` FuncOp.

    Raises:
        ValueError: if `kernel_func` has fewer arguments than results, so
            there are not enough trailing arguments to seed the loop-carried
            values.
    """
    num_results = len(kernel_func.type.results)
    # Validate before creating any op so no partial IR is left behind.
    if num_results > len(kernel_func.arguments.types):
        raise ValueError(
            "Expected kernel function to take at least as many arguments as "
            f"it returns results ({num_results})"
        )

    i64_type = ir.IntegerType.get_signless(64)
    # memref<?xi64>. NOTE(review): -1 is the dynamic-size marker in the MLIR
    # bindings this was written against; newer bindings expect
    # ir.ShapedType.get_dynamic_size() instead -- confirm on upgrade.
    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
    wrapped_func = func.FuncOp(
        # Same signature plus an extra i64 buffer that receives the timings.
        "main",
        (kernel_func.arguments.types + [memref_of_i64_type],
         kernel_func.type.results),
        visibility="public"
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # One loop iteration per timer-buffer slot.
        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # Initial loop-carried values: the last `num_results` kernel
        # arguments, i.e. the entries just before the timer buffer.
        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = func.CallOp(timer_func, [])
            call = func.CallOp(
                kernel_func,
                # Leading kernel arguments are passed through unchanged; the
                # trailing ones are replaced by the current loop-carried values.
                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
            )
            end = func.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            scf.YieldOp(list(call.results))
        func.ReturnOp(loop)

    return wrapped_func
106