1# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
2
3from string import Template
4
5import numpy as np
6import os
7import sys
8import tempfile
9
10_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
11sys.path.append(_SCRIPT_PATH)
12from tools import mlir_pytaco
13from tools import mlir_pytaco_io
14from tools import mlir_pytaco_utils as pytaco_utils
15from tools import testing_utils as testing_utils
16
17
# Short module-private aliases for the PyTaco mode formats used below.
_COMPRESSED, _DENSE = (mlir_pytaco.ModeFormat.COMPRESSED,
                       mlir_pytaco.ModeFormat.DENSE)
21
22
# Storage format used by the read tests: a 2-D tensor whose two modes are
# both compressed.
_FORMAT = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
# MatrixMarket test input. The size line `3 3 3` declares a 3 x 3 matrix with
# three entries, followed by 1-based "row col value" entry lines (the tests
# expect them back as 0-based coordinates). `$general_or_symmetry` fills in
# the symmetry qualifier of the banner (e.g. "general" or "symmetric"); see
# _get_mtx_data. string.Template only treats `$` specially, so the
# `%%MatrixMarket` banner is emitted literally, as the format requires.
_MTX_DATA_TEMPLATE = Template(
    """%%MatrixMarket matrix coordinate real $general_or_symmetry
3 3 3
3 1 3
1 2 2
3 2 4
""")
31
32
def _get_mtx_data(value):
  """Returns the MatrixMarket test text with `value` as its symmetry qualifier.

  Args:
    value: The banner qualifier to substitute, e.g. "general" or "symmetric".

  Returns:
    The complete MatrixMarket file contents as a string.
  """
  return _MTX_DATA_TEMPLATE.substitute(general_or_symmetry=value)
36
37
38# CHECK-LABEL: test_read_mtx_matrix_general
@testing_utils.run_test
def test_read_mtx_matrix_general():
  """Reads a general MatrixMarket file and verifies its three entries.

  Prints how many of the four checks succeeded.
  """
  with tempfile.TemporaryDirectory() as tmp_dir:
    mtx_path = os.path.join(tmp_dir, "data.mtx")
    with open(mtx_path, "w") as out:
      out.write(_get_mtx_data("general"))
    tensor = mlir_pytaco_io.read(mtx_path, _FORMAT)
  # Right after reading, the tensor is held as an MLIR sparse tensor and is
  # expected not to be unpacked yet.
  results = [not tensor.is_unpacked()]
  tensor.unpack()
  results.append(tensor.is_unpacked())
  coords, values = tensor.get_coordinates_and_values()
  results.append(np.array_equal(coords, [[0, 1], [2, 0], [2, 1]]))
  results.append(np.allclose(values, [2.0, 3.0, 4.0]))
  # CHECK: 4
  print(sum(results))
56
57
58# CHECK-LABEL: test_read_mtx_matrix_symmetry
@testing_utils.run_test
def test_read_mtx_matrix_symmetry():
  """Reads a symmetric MatrixMarket file and checks the expanded entries.

  The file stores only three entries. The expected result mirrors each of them
  across the diagonal, giving six coordinates. Prints the number of passed
  checks (4 when everything succeeds).
  """
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.mtx")
    with open(file_name, "w") as file:
      file.write(_get_mtx_data("symmetric"))
    a = mlir_pytaco_io.read(file_name, _FORMAT)
  passed = 0
  # The value of a is stored as an MLIR sparse tensor.
  passed += (not a.is_unpacked())
  a.unpack()
  passed += (a.is_unpacked())
  coords, values = a.get_coordinates_and_values()
  # Do not print coords/values here: the printed values array contains "4",
  # which would satisfy the FileCheck pattern below even if checks failed.
  passed += np.array_equal(coords,
                           [[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]])
  passed += np.allclose(values, [2.0, 3.0, 2.0, 4.0, 3.0, 4.0])
  # CHECK: 4
  print(passed)
79
80
# .tns test input holding the same three entries as the general MatrixMarket
# data (1-based "i j value" lines). NOTE(review): the two header lines appear
# to be "rank nnz" (2 3) and the dimension sizes (3 2) -- confirm against the
# reader implementation.
_TNS_DATA = """2 3
3 2
3 1 3
1 2 2
3 2 4
"""
87
88
89# CHECK-LABEL: test_read_tns
@testing_utils.run_test
def test_read_tns():
  """Reads a .tns file and verifies the resulting coordinates and values.

  Prints how many of the four checks succeeded.
  """
  expected_coords = [[0, 1], [2, 0], [2, 1]]
  expected_values = [2.0, 3.0, 4.0]
  with tempfile.TemporaryDirectory() as tmp_dir:
    tns_path = os.path.join(tmp_dir, "data.tns")
    with open(tns_path, "w") as out:
      out.write(_TNS_DATA)
    tensor = mlir_pytaco_io.read(tns_path, _FORMAT)
  score = 0
  # Freshly read data is held as an MLIR sparse tensor, i.e. still packed.
  score += (not tensor.is_unpacked())
  tensor.unpack()
  score += (tensor.is_unpacked())
  coords, values = tensor.get_coordinates_and_values()
  score += np.array_equal(coords, expected_coords)
  score += np.allclose(values, expected_values)
  # CHECK: 4
  print(score)
107
108
109# CHECK-LABEL: test_write_unpacked_tns
@testing_utils.run_test
def test_write_unpacked_tns():
  """Verifies that writing a tensor built only via insert() is rejected.

  The ValueError message raised by mlir_pytaco_io.write is printed.
  """
  tensor = mlir_pytaco.Tensor([2, 3])
  for coord, val in (([0, 1], 10), ([1, 2], 40), ([0, 0], 20)):
    tensor.insert(coord, val)
  with tempfile.TemporaryDirectory() as tmp_dir:
    out_path = os.path.join(tmp_dir, "data.tns")
    try:
      mlir_pytaco_io.write(out_path, tensor)
    except ValueError as err:
      # CHECK: Writing unpacked sparse tensors to file is not supported
      print(err)
123
124
125# CHECK-LABEL: test_write_packed_tns
@testing_utils.run_test
def test_write_packed_tns():
  """Writes a computed (packed) tensor to .tns and compares the file text.

  Prints 1 when the written lines match the expected contents, else 0.
  """
  src = mlir_pytaco.Tensor([2, 3])
  for coord, val in (([0, 1], 10), ([1, 2], 40), ([0, 0], 20)):
    src.insert(coord, val)
  dst = mlir_pytaco.Tensor([2, 3])
  i, j = mlir_pytaco.get_index_vars(2)
  dst[i, j] = src[i, j] + src[i, j]
  with tempfile.TemporaryDirectory() as tmp_dir:
    out_path = os.path.join(tmp_dir, "data.tns")
    mlir_pytaco_io.write(out_path, dst)
    with open(out_path, "r") as handle:
      written = handle.readlines()
  expected = ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]
  # The first written line is a comment, so compare starting at line two.
  passed = int(written[1:] == expected)
  # CHECK: 1
  print(passed)
146