# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s
"""Writes sparse tensors of several encodings via sparse_tensor.out and checks
that the emitted file is identical for every encoding."""

import ctypes
import errno
import os
import sys
import tempfile

from mlir import execution_engine
from mlir import ir
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr) -> str:
  """Returns boilerplate main method.

  The generated `main` materializes a constant 10x10 f64 tensor with five
  nonzeros, converts it to the sparse encoding `attr`, and hands it to
  `sparse_tensor.out` together with `%p`, the C string naming the output file.
  """
  return f"""
func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr<i8>
  return
}}
"""


def expected() -> str:
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order. The header gives rank and number of nonzeros, then
  the dimension sizes; each entry line holds 1-based indices and the value
  (e.g. element [0, 9] = 3.0 of the constant appears as "1 10 3").
  """
  # Plain string: there is nothing to interpolate.
  return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""


def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
                                 compiler) -> None:
  """Builds, compiles, and runs the kernel for `attr`, then checks its output.

  Args:
    attr: sparse tensor encoding instantiated in the kernel.
    support_lib: path of the runtime support library loaded by the engine.
    compiler: callable invoked as `compiler(module)` to lower the module.

  Exits the process with message 'FAILURE' (status 1) when the file written
  by the kernel differs from `expected()`.
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  compiler(module)
  engine = execution_engine.ExecutionEngine(
      module, opt_level=0, shared_libs=[support_lib])

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    # create_string_buffer yields a NUL-terminated char array for the
    # `!llvm.ptr<i8>` argument. The extra pointer level is there because
    # ExecutionEngine.invoke presumably takes every argument by pointer; verify
    # against the mlir Python API.
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)

    # Read before the temporary directory is removed; `with` closes the file.
    with open(out, encoding='utf-8') as f:
      actual = f.read()

  if actual != expected():
    # sys.exit is always available, unlike the site-provided quit().
    sys.exit('FAILURE')


def main():
  """Runs the output test for every combination of encoding parameters."""
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  # The context and location only need to be active; neither is referenced.
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    for level in levels:
      for ordering in orderings:
        for bwidth in bitwidths:
          # The same bitwidth is used for both pointers and indices.
          attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
          compiler = sparse_compiler.SparseCompiler(options='')
          build_compile_and_run_output(attr, support_lib, compiler)
          count += 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')


if __name__ == '__main__':
  main()