# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the utilities to support testing.

import numpy as np


def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
    """Compares sparse tensor actual output file with expected output file.

    This routine assumes the input files are in FROSTT format. See
    http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.

    It also assumes the first line in the output file is a comment line.
    That comment line is ignored. The next two lines are treated as metadata
    and must match exactly (character for character) between the two files.
    All remaining lines are parsed as float64 values and compared numerically.

    Args:
      expected: Path to the file holding the expected output.
      actual: Path to the file holding the actual output.
      rtol: Relative tolerance forwarded to `np.allclose` for the numeric data.

    Returns:
      True if both metadata lines are identical and the numeric data has the
      same shape and is element-wise close within `rtol`; False otherwise.
    """
    with open(actual, "r") as actual_f, open(expected, "r") as expected_f:
        # Skip the first comment line; it is allowed to differ.
        actual_f.readline()
        expected_f.readline()

        # Compare the two lines of meta data, stopping at the first mismatch.
        for _ in range(2):
            if actual_f.readline() != expected_f.readline():
                return False

    actual_data = np.loadtxt(actual, np.float64, skiprows=3)
    expected_data = np.loadtxt(expected, np.float64, skiprows=3)
    # np.allclose broadcasts its operands, so require identical shapes first.
    # Otherwise a truncated or differently-shaped file could compare equal
    # (or raise instead of returning False).
    if actual_data.shape != expected_data.shape:
        return False
    return bool(np.allclose(actual_data, expected_data, rtol=rtol))


def file_as_string(file: str) -> str:
    """Returns the full contents of the file at path `file` as a string."""
    with open(file, "r") as f:
        return f.read()


def run_test(f):
    """Prints the test name and runs the test.

    Prints `f.__name__`, calls `f()` with no arguments, and returns `f`
    unchanged. Because it returns `f`, it can be used as a decorator that
    runs the test at definition time.
    """
    print(f.__name__)
    f()
    return f