192c1c63dSAart Bik# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
292c1c63dSAart Bik# RUN:   %PYTHON %s | FileCheck %s
392c1c63dSAart Bik
492c1c63dSAart Bikimport ctypes
592c1c63dSAart Bikimport os
68b83b8f1SAart Bikimport sys
792c1c63dSAart Bikimport tempfile
892c1c63dSAart Bik
992c1c63dSAart Bikfrom mlir import ir
1092c1c63dSAart Bikfrom mlir import runtime as rt
1192c1c63dSAart Bik
1292c1c63dSAart Bikfrom mlir.dialects import builtin
1392c1c63dSAart Bikfrom mlir.dialects import sparse_tensor as st
1492c1c63dSAart Bik
# Append this script's own directory to sys.path so that the `tools`
# package living next to it can be imported below; the import must come
# after the path append for that to work.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
# Provides `SparseCompiler`, used in main() to compile and JIT the kernels.
from tools import sparse_compiler
1892c1c63dSAart Bik
1992c1c63dSAart Bik# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
    """Returns the MLIR source text of the test's `main` function.

    The generated function converts a constant 10x10 f64 tensor holding five
    nonzeros into a tensor carrying the sparse encoding `attr`, then hands it
    to `sparse_tensor.out` together with its `!llvm.ptr<i8>` argument (the
    output file name supplied by the caller).
    """
    # Both the convert result and the out operand use the same encoded type.
    encoded_type = f"tensor<10x10xf64, {attr}>"
    source_lines = [
        "",
        "func.func @main(%p : !llvm.ptr<i8>) -> () attributes { llvm.emit_c_interface } {",
        "  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],",
        "                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>",
        f"  %a = sparse_tensor.convert %d : tensor<10x10xf64> to {encoded_type}",
        f"  sparse_tensor.out %a, %p : {encoded_type}, !llvm.ptr<i8>",
        "  return",
        "}",
        "",
    ]
    # Leading/trailing empty entries reproduce the surrounding newlines.
    return "\n".join(source_lines)
3192c1c63dSAart Bik
3292c1c63dSAart Bik
def expected():
    """Returns expected contents of output.

    Regardless of the dimension ordering, compression, and bitwidths that are
    used in the sparse tensor, the output is always lexicographically sorted
    by natural index order.

    The entries correspond to the five nonzeros of the constant 10x10 matrix
    built in `boilerplate`, written with 1-based coordinates (e.g. element
    [0, 9] = 3.0 appears as "1 10 3").
    """
    # Plain string literal: there is nothing to interpolate, so an f-string
    # prefix would only be misleading (ruff F541).
    return """; extended FROSTT format
2 5
10 10
1 1 1
1 10 3
2 2 2
5 5 5
10 1 4
"""
4992c1c63dSAart Bik
5092c1c63dSAart Bik
def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
    """Builds, compiles, and runs the output kernel for one sparse encoding.

    Args:
      attr: sparse tensor encoding substituted into the generated kernel.
      compiler: object providing `compile_and_jit(module)`, which returns an
        engine whose `invoke('main', ...)` runs the compiled kernel.

    Exits the process with the message 'FAILURE' if the file written by the
    kernel does not exactly match `expected()`.
    """
    # Build and Compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)

    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, 'out.tns')
        buf = out.encode('utf-8')
        # The inner pointer is the `!llvm.ptr<i8>` file-name argument; the
        # outer level is presumably because `engine.invoke` takes each
        # argument by reference — confirm against the ExecutionEngine API.
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke('main', mem_a)

        # Read and close the file before the temporary directory is removed.
        with open(out, encoding='utf-8') as out_file:
            actual = out_file.read()
        if actual != expected():
            # `quit` is injected by the `site` module for interactive use and
            # is unavailable under `python -S`; `sys.exit` always works.
            sys.exit('FAILURE')
6692c1c63dSAart Bik
6792c1c63dSAart Bik
def main():
    """Runs the sparse output test over a grid of sparse encodings.

    Every combination of level types (dense+compressed and
    compressed+compressed), dimension orderings (identity and transposed),
    and bitwidths (8, 16, 32, 64) is compiled and executed, and the number
    of completed runs is printed.

    Raises:
      AssertionError: if the SUPPORT_LIB environment variable is unset.
      FileNotFoundError: if SUPPORT_LIB does not name an existing path.
    """
    # Needed for the ENOENT code below; it is not imported at file level.
    import errno

    support_lib = os.getenv('SUPPORT_LIB')
    assert support_lib is not None, 'SUPPORT_LIB is undefined'
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                support_lib)

    # CHECK-LABEL: TEST: test_output
    print('\nTEST: test_output')
    count = 0
    # The context and location only need to be active, not referenced.
    with ir.Context(), ir.Location.unknown():
        # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
        levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
                  [st.DimLevelType.compressed, st.DimLevelType.compressed]]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0])
        ]
        bitwidths = [8, 16, 32, 64]
        # One compiler instance is reused for every encoding.
        compiler = sparse_compiler.SparseCompiler(
            options='', opt_level=2, shared_libs=[support_lib])
        for level in levels:
            for ordering in orderings:
                for bwidth in bitwidths:
                    # The same width is passed twice — presumably for the
                    # pointer and index bitwidths; verify against the API.
                    attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
                    build_compile_and_run_output(attr, compiler)
                    count = count + 1

    # CHECK: Passed 16 tests
    print('Passed', count, 'tests')
9892c1c63dSAart Bik
9992c1c63dSAart Bik
if __name__ == "__main__":
    # Run the test driver only when executed directly as a script.
    main()
102