# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import numpy as np
import os
import sys

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import sparse_tensor as st
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects.linalg.opdsl import lang as dsl

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
  C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]


def build_SpMM(attr: st.EncodingAttr):
  """Build SpMM kernel.

  This method generates a linalg op for matrix multiplication using
  just the Python API. Effectively, a generic linalg op is constructed
  that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.

  Args:
    attr: sparse encoding attached to the 3x4 operand A.

  Returns:
    A module containing a single function `spMxM(A, B, C)` that returns
    the 3x2 result tensor.
  """
  module = ir.Module.create()
  f64 = ir.F64Type.get()
  a = ir.RankedTensorType.get([3, 4], f64, attr)
  b = ir.RankedTensorType.get([4, 2], f64)
  c = ir.RankedTensorType.get([3, 2], f64)
  arguments = [a, b, c]
  with ir.InsertionPoint(module.body):

    @func.FuncOp.from_py_func(*arguments)
    def spMxM(*args):
      return matmul_dsl(args[0], args[1], outs=[args[2]])

  return module


def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method.

  This method sets up a boilerplate main method that takes three tensors
  (a, b, c), converts the first tensor a into a sparse tensor, and then
  calls the sparse kernel for matrix multiplication. For convenience,
  this part is purely done as string input.

  Args:
    attr: sparse encoding used for the converted operand; it must match
      the encoding passed to `build_SpMM` so the call type-checks.

  Returns:
    MLIR source text for `@main` (literal braces are escaped as `{{ }}`).
  """
  return f"""
func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
  attributes {{ llvm.emit_c_interface }} {{
  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
  %0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
                                  tensor<4x2xf64>,
                                  tensor<3x2xf64>) -> tensor<3x2xf64>
  return %0 : tensor<3x2xf64>
}}
"""


def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
  """Builds, JIT-compiles and runs SpMM for one encoding; exits on mismatch.

  Args:
    attr: sparse encoding applied to the A operand.
    compiler: object providing `compile_and_jit(module)` that returns an
      execution engine with an `invoke` method.
  """
  # Build: print the generated kernel and re-parse it together with the
  # string-based `@main` driver into a single module.
  module = build_SpMM(attr)
  # Named `kernel_asm` (not `func`) to avoid shadowing the imported
  # `mlir.dialects.func` module.
  kernel_asm = str(
      module.operation.regions[0].blocks[0].operations[0].operation)
  module = ir.Module.parse(kernel_asm + boilerplate(attr))

  # Compile.
  engine = compiler.compile_and_jit(module)

  # Set up numpy input and buffer for output.
  a = np.array(
      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
      np.float64)
  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
  c = np.zeros((3, 2), np.float64)

  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
  # Allocate a MemRefDescriptor to receive the output tensor.
  # The buffer itself is allocated inside the MLIR code generation.
  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
  mem_out = ctypes.pointer(ctypes.pointer(ref_out))

  # Invoke the kernel; with the C interface the result descriptor is
  # passed as the first argument.
  # TODO: replace with inplace comprehensive bufferization.
  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

  # Sanity check on computed result against a dense numpy reference.
  expected = np.matmul(a, b)
  result = rt.ranked_memref_to_numpy(mem_out[0])
  if not np.allclose(result, expected):
    sys.exit('FAILURE')


def main():
  """Runs the SpMM kernel over a selection of sparse encodings."""
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: testSpMM
  print('\nTEST: testSpMM')
  with ir.Context(), ir.Location.unknown():
    count = 0
    # Loop over various ways to compile and annotate the SpMM kernel with
    # a *single* sparse tensor. Note that we deliberately do not exhaustively
    # search the full state space to reduce runtime of the test. It is
    # straightforward to adapt the code below to explore more combinations.
    par = 0
    vec = 0
    vl = 1
    e = False
    opt = (f'parallelization-strategy={par} '
           f'vectorization-strategy={vec} '
           f'vl={vl} enable-simd-index32={e}')
    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
              [st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.dense],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [0]
    compiler = sparse_compiler.SparseCompiler(
        options=opt, opt_level=0, shared_libs=[support_lib])
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            build_compile_and_run_SpMM(attr, compiler)
            count += 1
    # CHECK: Passed 8 tests
    print('Passed ', count, 'tests')


if __name__ == '__main__':
  main()