# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys

from typing import List, Callable

import numpy as np

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import bufferization
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    """Builds the NumPy <-> MLIR lookup tables, creating the MLIR types in
    the given `context`."""
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
        (np.float64, ir.F64Type.get(context=context)),
        (np.float32, ir.F32Type.get(context=context)),
        (np.int64, ir.IntegerType.get_signless(64, context=context)),
        (np.int32, ir.IntegerType.get_signless(32, context=context)),
        (np.int16, ir.IntegerType.get_signless(16, context=context)),
        (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    # Loop variables are deliberately not named `ir`, so that the `ir`
    # module is not shadowed inside the comprehension.
    self._ir2sc = {irtp: sctp for sctp, irtp in types_list}

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype.

    Raises:
      KeyError: if the dtype's scalar type has no registered MLIR type.
    """
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type.

    Raises:
      KeyError: if `sctype` is not one of the registered scalar types.
    """
    try:
      return self._sc2ir[sctype]
    except KeyError:
      raise KeyError(f'Unknown sctype: {sctype}') from None

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type.

    Raises:
      KeyError: if `tp` is not one of the registered MLIR types.
    """
    return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type.

    Raises:
      KeyError: if `tp` is not one of the registered MLIR types (including
        an equal-looking type created in a different context).
    """
    try:
      return self._ir2sc[tp]
    except KeyError:
      raise KeyError(f'Unknown ir.Type: {tp}') from None

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

# ===----------------------------------------------------------------------=== #

class StressTest:
  """Builds, compiles and runs a module that round-trips a dense tensor
  through a chain of `sparse_tensor.convert` operations.

  The methods are meant to be called in order (build -> compile -> run,
  with optional writeTo calls in between), and each returns `self` so the
  calls can be chained.
  """

  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None  # Set by build(); the dense input/output type.
    self._module = None       # Set by build().
    self._engine = None       # Set by compile().

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    """Raises AssertionError unless `tp` equals the roundtrip type."""
    assert self._roundtripTp is not None, \
      'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound: {tp}\n")

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module.

    The module has only the @main function, which converts its input
    through each type in `types[1:]` and then back to `types[0]`.  The
    roundtrip type `types[0]` must be a dense tensor.  The caller's list is
    not modified.
    """
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    assert types, 'StressTest: build() needs at least one type'
    # Work on a private copy: the conversion path is built by popping and
    # appending, which must not leak into the caller's list.
    types = list(types)
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      tp0 = types.pop(0)
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      # Close the path by converting back to the roundtrip type at the end.
      types.append(tp0)
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = func.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(types.pop(0), arg0)
        for tp in types:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          bufferization.DeallocTensorOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        func.ReturnOp(v)
    return self

  def writeTo(self, filename):
    """Writes the ir.Module to the given file.

    If the file already exists, raises FileExistsError.  If `filename` is
    None, this is a no-op.
    """
    assert self._module is not None, \
      'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    # Exclusive-create mode raises FileExistsError atomically, avoiding the
    # race between a separate existence check and the open().
    with open(filename, 'x') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler):
    """Compiles the ir.Module and JITs it with the given compiler."""
    assert self._module is not None, \
      'StressTest: must call build() before compile()'
    assert self._engine is None, \
      'StressTest: must not call compile() repeatedly'
    self._engine = compiler.compile_and_jit(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array."""
    assert self._engine is not None, \
      'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    # With the C interface, the result memref is passed first, as an
    # out-parameter, followed by the input.
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #

def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library.  There are two optional
  arguments, for debugging purposes.  The first argument specifies where
  to write out the raw/generated ir.Module.  The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    par = 0
    vec = 0
    vl = 1
    e = False
    # Disable direct sparse2sparse conversion, because it doubles the time!
    # TODO: While direct s2s is far too slow for per-commit testing,
    # we should have some framework ensure that we run this test with
    # `s2s=0` on a regular basis, to ensure that it does continue to work.
    s2s = 1
    sparsification_options = (
        f'parallelization-strategy={par} '
        f'vectorization-strategy={vec} '
        f'vl={vl} '
        f'enable-simd-index32={e} '
        f's2s-strategy={s2s}')
    compiler = sparse_compiler.SparseCompiler(
        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
    f64 = ir.F64Type.get()
    # Be careful about increasing this because
    #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
    shape = range(2, 6)
    rank = len(shape)
    # All combinations.
    levels = list(itertools.product(*itertools.repeat(
        [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
    # All permutations.
    orderings = list(map(ir.AffineMap.get_permutation,
                         itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
    #
    # For exhaustiveness we should have one or more StressTest, such
    # that their paths cover all n*(n-1) directed pairwise combinations
    # of the `types` set.  However, since n is already superexponential,
    # such exhaustiveness would be prohibitive for a test that runs on
    # every commit.  So for now we'll just pick one particular path that
    # at least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = int(np.prod(shape))
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    np_out = (
        StressTest(tyconv).build(types).writeTo(
            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')

if __name__ == '__main__':
  main()