# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN: %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys
import tempfile

import mlir.all_passes_registration

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import sparse_tensor as st


# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method.

  The generated `main` materializes a constant 10x10 f64 tensor with five
  nonzero entries, converts it to a sparse tensor with encoding `attr`, and
  writes that sparse tensor with `sparse_tensor.out` to the file whose name
  is passed in `%p`.
  """
  return f"""
func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to tensor<10x10xf64, {attr}>
  sparse_tensor.out %a, %p : tensor<10x10xf64, {attr}>, !llvm.ptr<i8>
  return
}}
"""


def expected():
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order.

  Layout: a header comment line, then "2 5" (rank and number of nonzeros),
  then the dimension sizes "10 10", then one "i j value" line per nonzero.
  The printed coordinates are the 0-based indices of the constant built in
  `boilerplate`, shifted to 1-based.
  """
  return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""


def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
                                 compiler):
  """Builds, compiles and runs `main` for `attr`, then checks its output.

  Args:
    attr: sparse tensor encoding used for the converted tensor.
    support_lib: path of the runtime support library given to the engine.
    compiler: callable that lowers an `ir.Module` in place before execution.

  Exits the process with a FAILURE message (non-zero status) when the file
  written by `main` differs from `expected()`.
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  compiler(module)
  engine = execution_engine.ExecutionEngine(
      module, opt_level=0, shared_libs=[support_lib])

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    # `main` takes a char* holding the file name; invoke() receives each
    # argument by pointer, hence the extra level of indirection.
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)

    # Close the handle deterministically (not via GC) before the temporary
    # directory is cleaned up.
    with open(out, encoding='utf-8') as f:
      actual = f.read()

  if actual != expected():
    # sys.exit instead of quit(): quit() is injected by the `site` module and
    # is not guaranteed to be available (e.g. under `python -S`).
    sys.exit(f'FAILURE: unexpected output for encoding {attr}')


class SparseCompiler:
  """Sparse compiler passes."""

  def __init__(self):
    # Pipeline options for `sparse-compiler` are given inside the braces.
    self.pipeline = (
        'sparse-compiler{reassociate-fp-reductions=1 enable-index-optimizations=1}')

  def __call__(self, module: ir.Module):
    """Runs the sparse compiler pipeline on `module`, modifying it in place."""
    passmanager.PassManager.parse(self.pipeline).run(module)


def main():
  """Checks `sparse_tensor.out` output for each tested sparse encoding."""
  support_lib = os.getenv('SUPPORT_LIB')
  if support_lib is None:
    # Explicit check rather than `assert`, which is stripped under `-O`.
    raise RuntimeError('SUPPORT_LIB is undefined')
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    # The pipeline does not depend on the encoding; build it once.
    compiler = SparseCompiler()
    # Same iteration order as nested loops over levels/orderings/bitwidths.
    for level, ordering, bwidth in itertools.product(levels, orderings,
                                                     bitwidths):
      attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
      build_compile_and_run_output(attr, support_lib, compiler)
      count += 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')


if __name__ == '__main__':
  main()