1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5# This file contains the utilities to support testing.
6
7import numpy as np
8
9
def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
  """Compares sparse tensor actual output file with expected output file.

  This routine assumes the input files are in FROSTT format. See
  http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.

  Both files are assumed to start with one comment line (which is not
  compared), followed by two metadata lines, followed by the entries.

  Args:
    expected: Path to the file holding the expected output.
    actual: Path to the file holding the actual output.
    rtol: Relative tolerance used when comparing the entries numerically.

  Returns:
    True if the two metadata lines match token by token and all entries have
    the same shape and are close within `rtol`; False otherwise.
  """
  with open(actual, "r") as actual_f, open(expected, "r") as expected_f:
    # Skip the first comment line; it is not part of the comparison.
    actual_f.readline()
    expected_f.readline()

    # Compare the two lines of metadata. Comparing whitespace-separated
    # tokens avoids spurious mismatches from spacing differences.
    for _ in range(2):
      if actual_f.readline().split() != expected_f.readline().split():
        return False

    # Both handles now sit at the first entry line, so the entries can be
    # parsed directly without reopening the files.
    actual_data = np.loadtxt(actual_f, np.float64)
    expected_data = np.loadtxt(expected_f, np.float64)

  # A differing number of entries would otherwise make np.allclose raise a
  # broadcasting error (or silently broadcast) instead of reporting False.
  if actual_data.shape != expected_data.shape:
    return False
  return bool(np.allclose(actual_data, expected_data, rtol=rtol))
33
34
def file_as_string(file: str) -> str:
  """Reads the file at path `file` in text mode and returns its full text."""
  with open(file, "r") as handle:
    contents = handle.read()
  return contents
39
40
def run_test(f):
  """Decorator: prints the test's name, invokes it once, and returns it.

  Returning `f` unchanged keeps the decorated name bound to the function.
  """
  test_name = f.__name__
  print(test_name)
  f()
  return f
46