1# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
2# RUN:   %PYTHON %s | FileCheck %s
3
4import ctypes
5import os
6import sys
7import tempfile
8
9from mlir import execution_engine
10from mlir import ir
11from mlir import runtime as rt
12
13from mlir.dialects import builtin
14from mlir.dialects import sparse_tensor as st
15
16_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
17sys.path.append(_SCRIPT_PATH)
18from tools import sparse_compiler
19
20# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
  """Returns the MLIR text of the test's main function.

  The generated function builds a constant 10x10 f64 tensor with five
  nonzero entries, converts it to a sparse tensor whose type carries `attr`
  as its encoding, and emits that sparse tensor through `sparse_tensor.out`
  using the pointer argument `%p`.

  Args:
    attr: Encoding attribute; its string form is spliced into the sparse
      tensor type in both places where that type is spelled out.

  Returns:
    The MLIR source text as a string.
  """
  # The sparse tensor type appears twice in the kernel, so spell it once.
  sparse_type = f"tensor<10x10xf64, {attr}>"
  return f"""
func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
  %a = sparse_tensor.convert %d : tensor<10x10xf64> to {sparse_type}
  sparse_tensor.out %a, %p : {sparse_type}, !llvm.ptr<i8>
  return
}}
"""
32
33
def expected():
  """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order.

  The text starts with a comment line naming the format, followed by the
  line `2 5` (consistent with rank 2 and 5 stored entries), the dimension
  sizes `10 10`, and one line per entry. Compared with the constant built in
  boilerplate(), the indices in the entry lines are one-based.

  Returns:
    The expected file contents as a string, including the trailing newline.
  """
  # A plain string literal: there is nothing to interpolate here, so no
  # f-string prefix is needed.
  return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
50
51
def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
                                 compiler):
  """Builds, compiles, and runs the output kernel, then checks its output.

  The module parsed from boilerplate(attr) is passed to `compiler`, executed
  through an ExecutionEngine that is given `support_lib` as a shared library,
  and the file written by the kernel is compared against expected().

  Args:
    attr: Sparse tensor encoding attribute spliced into the kernel text.
    support_lib: Path of the shared library handed to the execution engine.
    compiler: Callable invoked with the parsed module before execution; its
      return value is ignored.

  Raises:
    SystemExit: With the message 'FAILURE' when the contents of the output
      file differ from expected().
  """
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  compiler(module)
  engine = execution_engine.ExecutionEngine(
      module, opt_level=0, shared_libs=[support_lib])

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    # The kernel's single argument is the output file name, handed over as a
    # pointer to a pointer to a C string buffer holding the encoded path.
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)

    # Read through a context manager so the handle is closed before the
    # temporary directory is cleaned up.
    with open(out) as out_file:
      actual = out_file.read()
    if actual != expected():
      # sys.exit raises the same SystemExit as quit(), but does not depend on
      # the interactive helpers installed by the site module.
      sys.exit('FAILURE')
70
71
def main():
  """Runs the sparse tensor output test for every encoding combination.

  Reads the support library path from the SUPPORT_LIB environment variable,
  then builds, compiles, and runs the output kernel once for each combination
  of dimension level types, dimension orderings, and bitwidths. Finally prints
  how many combinations were run.

  Raises:
    AssertionError: If SUPPORT_LIB is not set.
    FileNotFoundError: If SUPPORT_LIB names a path that does not exist.
  """
  # Imported locally because errno is only needed to build the error below
  # and is not among the module-level imports of this file.
  import errno

  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context(), ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    for level in levels:
      for ordering in orderings:
        for bwidth in bitwidths:
          # The same bitwidth is passed for both trailing arguments of the
          # encoding.
          attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
          compiler = sparse_compiler.SparseCompiler(options='')
          build_compile_and_run_output(attr, support_lib, compiler)
          count += 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')
101
102
# Execute the test only when this file is run as a script, not when it is
# imported as a module.
if __name__ == '__main__':
  main()
105