# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN: %PYTHON %s | FileCheck %s

import ctypes
import errno
import numpy as np
import os
import sys

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import sparse_tensor as st
from mlir.dialects import func
from mlir.dialects.linalg.opdsl import lang as dsl

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


@dsl.linalg_structured_op
def sddmm_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
  C[dsl.D.m,
    dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]


def build_SDDMM(attr: st.EncodingAttr):
  """Builds the SDDMM kernel.

  This method generates a linalg op for sampled dense-dense matrix
  multiplication using just the Python API. Effectively, a generic linalg
  op is constructed that computes C(i,j) += S(i,j) * SUM_k A(i,k) * B(k,j)
  for sparse S.
  """
  module = ir.Module.create()
  f64 = ir.F64Type.get()
  a = ir.RankedTensorType.get([8, 8], f64)
  b = ir.RankedTensorType.get([8, 8], f64)
  c = ir.RankedTensorType.get([8, 8], f64)
  s = ir.RankedTensorType.get([8, 8], f64, attr)
  arguments = [a, b, s, c]
  with ir.InsertionPoint(module.body):

    @func.FuncOp.from_py_func(*arguments)
    def sddmm(*args):
      return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])

  return module


def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate code for main driver."""
  return f"""
func.func @main(%a: tensor<8x8xf64>,
                %b: tensor<8x8xf64>,
                %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
  %t = arith.constant sparse<[[0,0], [0,2], [4,1]], [1.0, 2.0, 3.0]> : tensor<8x8xf64>
  %s = sparse_tensor.convert %t : tensor<8x8xf64> to tensor<8x8xf64, {attr}>
  %0 = call @sddmm(%a, %b, %s, %c) : (tensor<8x8xf64>,
                                      tensor<8x8xf64>,
                                      tensor<8x8xf64, {attr}>,
                                      tensor<8x8xf64>) -> tensor<8x8xf64>
  return %0 : tensor<8x8xf64>
}}
"""
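

# For reference only: a minimal NumPy sketch of the computation that the
# kernel above implements, assuming fully dense operands. The helper name
# `sddmm_reference` is illustrative and is not used by the test itself, but
# it makes the sampled-output check further below easy to reason about.
def sddmm_reference(a, b, s, c):
  # C(i,j) += S(i,j) * SUM_k A(i,k) * B(k,j), i.e. a dense matmul whose
  # result is sampled (elementwise-scaled) by the sparse matrix S.
  return c + s * np.matmul(a, b)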


def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
  # Build.
  module = build_SDDMM(attr)
  # Use a distinct name to avoid shadowing the imported func dialect.
  func_str = str(module.operation.regions[0].blocks[0].operations[0].operation)
  module = ir.Module.parse(func_str + boilerplate(attr))

  # Compile.
  engine = compiler.compile_and_jit(module)

  # Set up numpy input and buffer for output.
  a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
                [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
                [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
                [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
                [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
                [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
                [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
                [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
  b = np.ones((8, 8), np.float64)
  c = np.zeros((8, 8), np.float64)

  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))

  # Allocate a MemRefDescriptor to receive the output tensor.
  # The buffer itself is allocated inside the MLIR code generation.
  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
  mem_out = ctypes.pointer(ctypes.pointer(ref_out))

  # Invoke the kernel and get numpy output.
  # Built-in bufferization uses in-out buffers.
  # TODO: replace with inplace comprehensive bufferization.
  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

  # Sanity check on computed result. Only a few elements
  # are sampled from the full dense matrix multiplication.
  full_matmul = np.matmul(a, b)
  expected = np.zeros((8, 8), np.float64)
  expected[0, 0] = 1.0 * full_matmul[0, 0]
  expected[0, 2] = 2.0 * full_matmul[0, 2]
  expected[4, 1] = 3.0 * full_matmul[4, 1]
  c = rt.ranked_memref_to_numpy(mem_out[0])
  if not np.allclose(c, expected):
    quit('FAILURE')
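

# For illustration: how a single sparse annotation is assembled from the
# pieces that main() below loops over. This sketch builds the classic CSR
# encoding with 32-bit pointer and index widths; the helper name
# `example_csr_attr` is illustrative and is not called by the test. Like the
# loop in main(), it must run under an active ir.Context.
def example_csr_attr() -> st.EncodingAttr:
  # Dense outer dimension, compressed inner dimension, identity ordering.
  levels = [st.DimLevelType.dense, st.DimLevelType.compressed]
  ordering = ir.AffineMap.get_permutation([0, 1])
  return st.EncodingAttr.get(levels, ordering, 32, 32)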


def main():
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: testSDDMMM
  print('\nTEST: testSDDMMM')
  with ir.Context() as ctx, ir.Location.unknown():
    count = 0
    # Loop over various ways to compile and annotate the SDDMM kernel with
    # a *single* sparse tensor. Note that we deliberately do not exhaustively
    # search the full state space to reduce the runtime of the test. It is
    # straightforward to adapt the code below to explore more combinations.
    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
              [st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.dense],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    for level in levels:
      for ordering in orderings:
        for pwidth in [32]:
          for iwidth in [32]:
            for par in [0]:
              for vec in [0, 1]:
                for e in [True]:
                  vl = 1 if vec == 0 else 16
                  attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
                  opt = (f'parallelization-strategy={par} '
                         f'vectorization-strategy={vec} '
                         f'vl={vl} enable-simd-index32={e}')
                  compiler = sparse_compiler.SparseCompiler(
                      options=opt, opt_level=0, shared_libs=[support_lib])
                  build_compile_and_run_SDDMMM(attr, compiler)
                  count = count + 1
    # CHECK: Passed 16 tests
    print('Passed', count, 'tests')


if __name__ == '__main__':
  main()