# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
"""Integration tests for reading and writing tensors with mlir_pytaco_io.

Each test prints its results to stdout; the FileCheck annotations next to the
tests describe the expected output. Tests must not print anything beyond what
is being verified, because FileCheck matches substrings and stray output can
satisfy an expectation by accident.
"""

import os
import sys
import tempfile
from string import Template

import numpy as np

# Make the sibling `tools` package importable regardless of the directory the
# test is launched from.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco
from tools import mlir_pytaco_io
from tools import mlir_pytaco_utils as pytaco_utils
from tools import testing_utils

# Define the aliases to shorten the code.
_COMPRESSED = mlir_pytaco.ModeFormat.COMPRESSED
_DENSE = mlir_pytaco.ModeFormat.DENSE

# Storage format used for every tensor read in these tests.
_FORMAT = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])

# Matrix Market text for a 3x3 matrix with three stored entries (1-based
# coordinates). `$general_or_symmetry` fills in the symmetry field of the
# header.
_MTX_DATA_TEMPLATE = Template(
    """%%MatrixMarket matrix coordinate real $general_or_symmetry
3 3 3
3 1 3
1 2 2
3 2 4
""")

# .tns text holding the same three entries as the matrix above. The two
# header lines are presumably `rank nnz` and the dimension sizes -- confirm
# against mlir_pytaco_io.
_TNS_DATA = """2 3
3 2
3 1 3
1 2 2
3 2 4
"""


def _get_mtx_data(value):
    """Returns the Matrix Market test text with `value` as the symmetry field."""
    return _MTX_DATA_TEMPLATE.substitute(general_or_symmetry=value)


def _read_from_text(text, file_name):
    """Writes `text` to a temporary `file_name` and reads it with `_FORMAT`.

    Args:
      text: The file contents to write.
      file_name: Base name of the file; its extension is preserved because the
        reader presumably selects the parser from it.

    Returns:
      The tensor returned by `mlir_pytaco_io.read`.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        path = os.path.join(test_dir, file_name)
        with open(path, "w") as file:
            file.write(text)
        return mlir_pytaco_io.read(path, _FORMAT)


def _count_passed_checks(tensor, expected_coords, expected_values):
    """Returns how many of four checks on a freshly read tensor succeed.

    The checks are: the tensor starts out packed, `unpack()` leaves it
    unpacked, and the unpacked coordinates and values match the expectations.
    Note that this calls `tensor.unpack()`.
    """
    passed = 0
    # The value of a freshly read tensor is stored as an MLIR sparse tensor.
    passed += (not tensor.is_unpacked())
    tensor.unpack()
    passed += tensor.is_unpacked()
    coords, values = tensor.get_coordinates_and_values()
    passed += np.array_equal(coords, expected_coords)
    passed += np.allclose(values, expected_values)
    return passed


# CHECK-LABEL: test_read_mtx_matrix_general
@testing_utils.run_test
def test_read_mtx_matrix_general():
    a = _read_from_text(_get_mtx_data("general"), "data.mtx")
    passed = _count_passed_checks(a, [[0, 1], [2, 0], [2, 1]],
                                  [2.0, 3.0, 4.0])
    # CHECK: 4
    print(passed)


# CHECK-LABEL: test_read_mtx_matrix_symmetry
@testing_utils.run_test
def test_read_mtx_matrix_symmetry():
    a = _read_from_text(_get_mtx_data("symmetric"), "data.mtx")
    # Each stored off-diagonal entry is expected to appear mirrored as well.
    # Do not print the coordinates or values here: the printed values contain
    # a "4", which would satisfy the expectation below even if a comparison
    # failed.
    passed = _count_passed_checks(
        a, [[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]],
        [2.0, 3.0, 2.0, 4.0, 3.0, 4.0])
    # CHECK: 4
    print(passed)


# CHECK-LABEL: test_read_tns
@testing_utils.run_test
def test_read_tns():
    a = _read_from_text(_TNS_DATA, "data.tns")
    passed = _count_passed_checks(a, [[0, 1], [2, 0], [2, 1]],
                                  [2.0, 3.0, 4.0])
    # CHECK: 4
    print(passed)


# CHECK-LABEL: test_write_unpacked_tns
@testing_utils.run_test
def test_write_unpacked_tns():
    a = mlir_pytaco.Tensor([2, 3])
    a.insert([0, 1], 10)
    a.insert([1, 2], 40)
    a.insert([0, 0], 20)
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        try:
            mlir_pytaco_io.write(file_name, a)
        except ValueError as e:
            # CHECK: Writing unpacked sparse tensors to file is not supported
            print(e)


# CHECK-LABEL: test_write_packed_tns
@testing_utils.run_test
def test_write_packed_tns():
    a = mlir_pytaco.Tensor([2, 3])
    a.insert([0, 1], 10)
    a.insert([1, 2], 40)
    a.insert([0, 0], 20)
    b = mlir_pytaco.Tensor([2, 3])
    i, j = mlir_pytaco.get_index_vars(2)
    # Evaluating the expression produces a packed tensor that can be written.
    b[i, j] = a[i, j] + a[i, j]
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        mlir_pytaco_io.write(file_name, b)
        with open(file_name, "r") as file:
            lines = file.readlines()
    # Skip the comment line in the output.
    expected = ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]
    passed = 1 if lines[1:] == expected else 0
    # CHECK: 1
    print(passed)