1312c5140SAart Bik# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
2312c5140SAart Bik# RUN:   %PYTHON %s | FileCheck %s
3312c5140SAart Bik
import ctypes
import errno
import itertools
import os
import sys

import numpy as np
8312c5140SAart Bik
9312c5140SAart Bikfrom mlir import ir
10312c5140SAart Bikfrom mlir import runtime as rt
11312c5140SAart Bik
12312c5140SAart Bikfrom mlir.dialects import sparse_tensor as st
13312c5140SAart Bikfrom mlir.dialects import builtin
1436550692SRiver Riddlefrom mlir.dialects import func
15312c5140SAart Bikfrom mlir.dialects.linalg.opdsl import lang as dsl
16312c5140SAart Bik
178b83b8f1SAart Bik_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
188b83b8f1SAart Biksys.path.append(_SCRIPT_PATH)
198b83b8f1SAart Bikfrom tools import sparse_compiler
20312c5140SAart Bik
@dsl.linalg_structured_op
def sddmm_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
    """Sampled dense-dense matrix multiplication.

    Emits a generic linalg op computing C(m,n) += S(m,n) * SUM_k A(m,k) * B(k,n),
    i.e. the dense product A*B sampled by S.
    """
    # Reduction over k is implied by the OpDSL += comprehension.
    C[dsl.D.m,
      dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
29312c5140SAart Bik
30312c5140SAart Bik
def build_SDDMM(attr: st.EncodingAttr):
    """Build SDDMM kernel.

    Constructs, using only the Python API, a module holding a generic linalg
    op that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
    """
    module = ir.Module.create()
    f64 = ir.F64Type.get()
    # All operands are 8x8 f64 tensors; only S carries the sparse encoding.
    dense_8x8 = ir.RankedTensorType.get([8, 8], f64)
    sparse_8x8 = ir.RankedTensorType.get([8, 8], f64, attr)
    with ir.InsertionPoint(module.body):

        # Argument order: A, B, S (sampling matrix), C (output).
        @func.FuncOp.from_py_func(dense_8x8, dense_8x8, sparse_8x8, dense_8x8)
        def sddmm(*args):
            return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])

    return module
52312c5140SAart Bik
53312c5140SAart Bik
def boilerplate(attr: st.EncodingAttr):
    """Returns boilerplate code for main driver.

    The driver materializes a small sparse sampling matrix S (three nonzeros),
    converts it to the encoding `attr`, and calls the @sddmm kernel with it.
    llvm.emit_c_interface enables ctypes-based invocation from Python.
    """
    return f"""
func.func @main(%a: tensor<8x8xf64>,
           %b: tensor<8x8xf64>,
           %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
  %t = arith.constant sparse<[[0,0], [0,2], [4,1]], [1.0, 2.0, 3.0]> : tensor<8x8xf64>
  %s = sparse_tensor.convert %t : tensor<8x8xf64> to tensor<8x8xf64, {attr}>
  %0 = call @sddmm(%a, %b, %s, %c) : (tensor<8x8xf64>,
                                      tensor<8x8xf64>,
                                      tensor<8x8xf64, {attr}>,
                                      tensor<8x8xf64>) -> tensor<8x8xf64>
  return %0 : tensor<8x8xf64>
}}
"""
69312c5140SAart Bik
70312c5140SAart Bik
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
    """Builds, compiles, runs, and verifies the SDDMM kernel for one encoding.

    Args:
      attr: sparse tensor encoding attribute for the sampling matrix S.
      compiler: sparse_compiler.SparseCompiler used to JIT the module.

    Exits the process with 'FAILURE' if the computed result is wrong.
    """
    # Build. (Do not name this local 'func': that would shadow the imported
    # mlir.dialects.func module.)
    module = build_SDDMM(attr)
    kernel_func = str(module.operation.regions[0].blocks[0].operations[0].operation)
    module = ir.Module.parse(kernel_func + boilerplate(attr))

    # Compile.
    engine = compiler.compile_and_jit(module)

    # Set up numpy input and buffer for output.
    a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
                  [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
                  [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
                  [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
                  [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
                  [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
                  [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
                  [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
    b = np.ones((8, 8), np.float64)
    c = np.zeros((8, 8), np.float64)

    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))

    # Allocate a MemRefDescriptor to receive the output tensor.
    # The buffer itself is allocated inside the MLIR code generation.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))

    # Invoke the kernel and get numpy output.
    # Built-in bufferization uses in-out buffers.
    # TODO: replace with inplace comprehensive bufferization.
    engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

    # Sanity check on computed result. Only the three sampled positions of S
    # (see boilerplate) can be nonzero in the output; each is the dense
    # matmul value scaled by the corresponding S entry.
    full_matmul = np.matmul(a, b)
    expected = np.zeros((8, 8), np.float64)
    expected[0, 0] = 1.0 * full_matmul[0, 0]
    expected[0, 2] = 2.0 * full_matmul[0, 2]
    expected[4, 1] = 3.0 * full_matmul[4, 1]
    c = rt.ranked_memref_to_numpy(mem_out[0])
    if not np.allclose(c, expected):
        # sys.exit (not the interactive-only quit()) raises SystemExit,
        # which lit/FileCheck reports as a test failure.
        sys.exit('FAILURE')
118312c5140SAart Bik
119312c5140SAart Bik
def main():
    """Drives the SDDMM test over several sparse annotations and options."""
    support_lib = os.getenv('SUPPORT_LIB')
    assert support_lib is not None, 'SUPPORT_LIB is undefined'
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                support_lib)

    # CHECK-LABEL: TEST: testSDDMMM
    print('\nTEST: testSDDMMM')
    with ir.Context() as ctx, ir.Location.unknown():
        count = 0
        # Loop over various ways to compile and annotate the SDDMM kernel with
        # a *single* sparse tensor. Note that we deliberately do not
        # exhaustively search the full state space to reduce runtime of the
        # test. It is straightforward to adapt the code below to explore more
        # combinations.
        levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
                  [st.DimLevelType.dense, st.DimLevelType.compressed],
                  [st.DimLevelType.compressed, st.DimLevelType.dense],
                  [st.DimLevelType.compressed, st.DimLevelType.compressed]]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0])
        ]
        # itertools.product flattens the former seven-deep loop nest and
        # preserves its iteration order (rightmost factor varies fastest).
        for level, ordering, pwidth, iwidth, par, vec, e in itertools.product(
                levels, orderings, [32], [32], [0], [0, 1], [True]):
            vl = 1 if vec == 0 else 16
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            opt = (f'parallelization-strategy={par} '
                   f'vectorization-strategy={vec} '
                   f'vl={vl} enable-simd-index32={e}')
            compiler = sparse_compiler.SparseCompiler(
                options=opt, opt_level=0, shared_libs=[support_lib])
            build_compile_and_run_SDDMMM(attr, compiler)
            count += 1
    # CHECK: Passed 16 tests
    print('Passed ', count, 'tests')
161312c5140SAart Bik
162312c5140SAart Bik
# Script entry point: run the test driver only when executed directly.
if __name__ == '__main__':
    main()
165