14748cc69Swren romano# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
24748cc69Swren romano# RUN:   %PYTHON %s | FileCheck %s
3286248dbSwren romano
4286248dbSwren romanoimport ctypes
5286248dbSwren romanoimport numpy as np
6286248dbSwren romanoimport os
78b83b8f1SAart Bikimport sys
8286248dbSwren romano
9286248dbSwren romanofrom mlir import ir
10286248dbSwren romanofrom mlir import runtime as rt
11286248dbSwren romano
12286248dbSwren romanofrom mlir.dialects import sparse_tensor as st
13286248dbSwren romanofrom mlir.dialects import builtin
1436550692SRiver Riddlefrom mlir.dialects import func
15286248dbSwren romanofrom mlir.dialects.linalg.opdsl import lang as dsl
16286248dbSwren romano
# Append this script's own directory to the module search path so that the
# `tools` import below can be resolved relative to this file.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
20286248dbSwren romano
# Matrix multiplication expressed in the linalg OpDSL. For symbolic sizes
# M, K and N it accumulates C(m, n) += A(m, k) * B(k, n); all three operands
# are declared with the same type variable `dsl.T`, and C is the output.
# Documented with `#` comments (not a docstring) so that nothing inside the
# decorated function changes.
@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
27286248dbSwren romano
28286248dbSwren romano
def build_SpMM(attr: st.EncodingAttr):
    """Build the SpMM kernel module.

    Creates a fresh module that holds a single function `spMxM`. That
    function applies the `matmul_dsl` structured op, i.e. it computes
    C(i,j) += A(i,k) * B(k,j), to a 3x4 f64 matrix A annotated with the
    encoding `attr`, a 4x2 f64 matrix B and a 3x2 f64 matrix C. Everything
    is constructed through the Python API rather than parsed from text.
    """
    kernel_module = ir.Module.create()
    elem_ty = ir.F64Type.get()
    operand_types = [
        ir.RankedTensorType.get([3, 4], elem_ty, attr),  # A, annotated.
        ir.RankedTensorType.get([4, 2], elem_ty),  # B, no encoding.
        ir.RankedTensorType.get([3, 2], elem_ty),  # C, no encoding.
    ]
    with ir.InsertionPoint(kernel_module.body):

        # The textual `main` produced by boilerplate() calls `@spMxM`, so
        # this function name must stay in sync with it.
        @func.FuncOp.from_py_func(*operand_types)
        def spMxM(*args):
            lhs, rhs, acc = args[0], args[1], args[2]
            return matmul_dsl(lhs, rhs, outs=[acc])

    return kernel_module
49286248dbSwren romano
50286248dbSwren romano
def boilerplate(attr: st.EncodingAttr):
    """Return the textual `main` wrapper around the SpMM kernel.

    The generated function takes three tensors (a, b, c), converts the
    first tensor into a sparse tensor with encoding `attr`, and then calls
    the sparse matrix multiplication kernel `@spMxM`. For convenience, this
    part is written purely as string input.
    """
    # The annotated operand type appears twice in the text; spell it once.
    sparse_a = f"tensor<3x4xf64, {attr}>"
    return f"""
func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
  attributes {{ llvm.emit_c_interface }} {{
  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to {sparse_a}
  %0 = call @spMxM(%a, %b, %c) : ({sparse_a},
                                  tensor<4x2xf64>,
                                  tensor<3x2xf64>) -> tensor<3x2xf64>
  return %0 : tensor<3x2xf64>
}}
"""
69286248dbSwren romano
70286248dbSwren romano
def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
    """Build, compile and run the SpMM kernel for one sparse encoding.

    Args:
      attr: encoding used to annotate the first operand A of the kernel.
      compiler: object whose `compile_and_jit(module)` returns an engine
        that offers `invoke(name, *args)`.

    Exits the process with the message 'FAILURE' when the computed result
    is not close to numpy's reference product.
    """
    # Build: take the text of the generated kernel function and parse it
    # together with the textual `main` wrapper into a single module.
    kernel = build_SpMM(attr)
    kernel_text = str(
        kernel.operation.regions[0].blocks[0].operations[0].operation)
    module = ir.Module.parse(kernel_text + boilerplate(attr))

    # Compile.
    engine = compiler.compile_and_jit(module)

    # Set up numpy input and buffer for output.
    a = np.array(
        [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
        np.float64)
    b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
    c = np.zeros((3, 2), np.float64)

    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    # Allocate a MemRefDescriptor to receive the output tensor.
    # The buffer itself is allocated inside the MLIR code generation.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))

    # Invoke the kernel and get numpy output.
    # Built-in bufferization uses in-out buffers.
    # TODO: replace with inplace comprehensive bufferization.
    engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

    # Sanity check on computed result. The result gets its own name so the
    # input buffer `c` is not rebound.
    expected = np.matmul(a, b)
    actual = rt.ranked_memref_to_numpy(mem_out[0])
    if not np.allclose(actual, expected):
        # sys.exit rather than quit(): quit is injected by the `site` module
        # for interactive use and is not guaranteed to exist.
        sys.exit('FAILURE')
107286248dbSwren romano
108286248dbSwren romano
def main():
    """Run the SpMM kernel over several sparse encodings of matrix A.

    Reads the path of the runtime support library from the SUPPORT_LIB
    environment variable, then builds, compiles and runs the kernel once
    for every combination of dimension level types, dimension ordering and
    bitwidths (pwidth, iwidth), and finally prints how many runs completed.

    Raises:
      AssertionError: if SUPPORT_LIB is not set.
      FileNotFoundError: if SUPPORT_LIB names a path that does not exist.
    """
    # Imported locally: `errno` is required to build the FileNotFoundError
    # below and is not among this file's top-level imports.
    import errno
    import itertools

    support_lib = os.getenv('SUPPORT_LIB')
    assert support_lib is not None, 'SUPPORT_LIB is undefined'
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

    # CHECK-LABEL: TEST: testSpMM
    print('\nTEST: testSpMM')
    with ir.Context(), ir.Location.unknown():
        # Loop over various ways to compile and annotate the SpMM kernel with
        # a *single* sparse tensor. Note that we deliberately do not
        # exhaustively search the full state space to reduce runtime of the
        # test. It is straightforward to adapt the code below to explore more
        # combinations.
        par = 0
        vec = 0
        vl = 1
        e = False
        opt = (f'parallelization-strategy={par} '
               f'vectorization-strategy={vec} '
               f'vl={vl} enable-simd-index32={e}')
        levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
                  [st.DimLevelType.dense, st.DimLevelType.compressed],
                  [st.DimLevelType.compressed, st.DimLevelType.dense],
                  [st.DimLevelType.compressed, st.DimLevelType.compressed]]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0])
        ]
        bitwidths = [0]
        compiler = sparse_compiler.SparseCompiler(
            options=opt, opt_level=0, shared_libs=[support_lib])
        count = 0
        # product() varies the rightmost factor fastest, i.e. the same order
        # as nesting the four loops by hand.
        for level, ordering, pwidth, iwidth in itertools.product(
                levels, orderings, bitwidths, bitwidths):
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            build_compile_and_run_SpMM(attr, compiler)
            count += 1
        # CHECK: Passed 8 tests
        print('Passed ', count, 'tests')
1504748cc69Swren romano
151312c5140SAart Bik
# Run the test only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
154