# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN: %PYTHON %s | FileCheck %s
"""Checks the text written by `sparse_tensor.out` for several encodings.

For every combination of level types, dimension ordering and bitwidth listed
in `main`, a constant 10x10 matrix with five nonzeros is converted to a sparse
tensor, written to a file by the JIT-compiled kernel, and the file contents
are compared against `expected()`.
"""

import ctypes
import errno
import itertools
import os
import sys
import tempfile

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method.

  The generated `@main` converts a constant 10x10 f64 matrix with five
  nonzeros into a sparse tensor with encoding `attr`, then passes it together
  with its `!llvm.ptr<i8>` argument to `sparse_tensor.out`.
  """
  return f"""
func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr<i8>
  return
}}
"""


def expected():
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order. Indices in the text are 1-based: the input entry
  at [9, 0] with value 4.0 appears as `10 1 4`.
  """
  return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""


def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
  """Compiles the kernel for `attr`, runs it, and checks the written file.

  Args:
    attr: sparse tensor encoding used for the converted matrix.
    compiler: object whose `compile_and_jit(module)` returns an engine with
      an `invoke(name, *args)` method (in this file, a
      `sparse_compiler.SparseCompiler`).

  Exits the process via `sys.exit('FAILURE')` (non-zero status) when the
  output does not match `expected()`.
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  engine = compiler.compile_and_jit(module)

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    # The output path is handed to `@main` as its `!llvm.ptr<i8>` argument.
    # The buffer is kept in a named local so it visibly outlives the call.
    path_buf = ctypes.create_string_buffer(out.encode('utf-8'))
    mem_a = ctypes.pointer(ctypes.pointer(path_buf))
    engine.invoke('main', mem_a)

    # Read while the temporary directory still exists, and close the file
    # before the directory is cleaned up.
    with open(out) as f:
      actual = f.read()
    if actual != expected():
      # sys.exit raises the same SystemExit as the site-provided quit(), but
      # is always available (quit() is missing under `python -S`).
      sys.exit('FAILURE')


def main():
  """Runs the output check for every tested encoding and prints the count."""
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    compiler = sparse_compiler.SparseCompiler(
        options='', opt_level=2, shared_libs=[support_lib])
    # Same iteration order as nested loops over levels, orderings, bitwidths.
    for level, ordering, bwidth in itertools.product(levels, orderings,
                                                     bitwidths):
      # The same width is used for both bitwidth arguments of the encoding.
      attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
      build_compile_and_run_output(attr, compiler)
      count += 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')


if __name__ == '__main__':
  main()