# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN: %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys
from typing import List, Callable

import numpy as np

import mlir.all_passes_registration

from mlir import ir
from mlir import runtime as rt
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

from mlir.dialects import builtin
from mlir.dialects import std
from mlir.dialects import sparse_tensor as st

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
      (np.float64, ir.F64Type.get(context=context)),
      (np.float32, ir.F32Type.get(context=context)),
      (np.int64, ir.IntegerType.get_signless(64, context=context)),
      (np.int32, ir.IntegerType.get_signless(32, context=context)),
      (np.int16, ir.IntegerType.get_signless(16, context=context)),
      (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    # Use a distinct loop-variable name so the `ir` module is not shadowed
    # by an ir.Type instance inside the comprehension.
    self._ir2sc = {ir_tp: sc for sc, ir_tp in types_list}

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype.

    Raises:
      KeyError: if the dtype's scalar type has no registered MLIR type.
    """
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type.

    Raises:
      KeyError: if `sctype` is not one of the registered scalar types.
    """
    if sctype in self._sc2ir:
      return self._sc2ir[sctype]
    else:
      raise KeyError(f'Unknown sctype: {sctype}')

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type.

    Raises:
      KeyError: if `tp` is not one of the registered MLIR types.
    """
    return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type.

    Raises:
      KeyError: if `tp` is not one of the registered MLIR types (including
        an otherwise-equal type created in a different context).
    """
    if tp in self._ir2sc:
      return self._ir2sc[tp]
    else:
      raise KeyError(f'Unknown ir.Type: {tp}')

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

# ===----------------------------------------------------------------------=== #

class StressTest:
  """Builds, compiles and runs a module that converts a tensor through a
  chain of tensor types and back to the original (roundtrip) type.

  The methods are meant to be chained in order:
  build() -> [writeTo()] -> compile() -> [writeTo()] -> run().
  """

  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None
    self._module = None
    self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    """Raises AssertionError unless `tp` equals the roundtrip type."""
    assert self._roundtripTp is not None, \
      'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
        f"Type is not equal to the roundtrip type.\n"
        f"\tExpected: {self._roundtripTp}\n"
        f"\tFound: {tp}\n")

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module. The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type. The roundtrip type must be a dense tensor.

    The caller's `types` list is not modified.

    Raises:
      ValueError: if `types` is empty.
    """
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    if not types:
      raise ValueError('StressTest: build() requires at least one type')
    # Unpack into fresh locals instead of popping from the caller's list.
    tp0, *rest = types
    # TODO: assert dense? assert element type is recognised by the TypeConverter?
    self._roundtripTp = tp0
    # Visit every type after the first, then come back to the first.
    path = rest + [tp0]
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = builtin.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(path[0], arg0)
        for tp in path[1:]:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          st.ReleaseOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        std.ReturnOp(v)
    return self

  def writeTo(self, filename):
    """Writes the ir.Module to the given file. If the file already exists,
    then raises FileExistsError. If the filename is None, then this is a
    no-op."""
    assert self._module is not None, \
      'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    # Mode 'x' creates the file exclusively and raises FileExistsError if it
    # already exists; unlike a separate exists() check, this is race-free.
    with open(filename, 'x') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler: Callable[[ir.Module], ExecutionEngine]):
    """Compiles the ir.Module with `compiler` and stores the resulting
    ExecutionEngine for run()."""
    assert self._module is not None, \
      'StressTest: must call build() before compile()'
    assert self._engine is None, \
      'StressTest: must not call compile() repeatedly'
    self._engine = compiler(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array.

    Raises:
      AssertionError: if the array's tensor type differs from the
        roundtrip type.
    """
    assert self._engine is not None, \
      'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class SparseCompiler:
  """Sparse compiler passes."""

  def __init__(self, sparsification_options: str, support_lib: str):
    """Parses the sparse-compiler pass pipeline.

    Args:
      sparsification_options: option string spliced verbatim into the
        `sparsification{...}` pipeline entry.
      support_lib: path to the shared library handed to the ExecutionEngine.
    """
    self._support_lib = support_lib
    # One entry per top-level pipeline element; joined with ',' below.
    passes = [
      'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops)',
      f'sparsification{{{sparsification_options}}}',
      'sparse-tensor-conversion',
      'builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)',
      'convert-scf-to-std',
      'func-bufferize',
      'tensor-constant-bufferize',
      'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)',
      'convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1}',
      'lower-affine',
      'convert-memref-to-llvm',
      'convert-std-to-llvm',
      'reconcile-unrealized-casts',
    ]
    self._pipeline = ','.join(passes)
    # Must be in the scope of a `with ir.Context():`
    self._passmanager = PassManager.parse(self._pipeline)

  def __call__(self, module: ir.Module) -> ExecutionEngine:
    """Runs the pass pipeline on `module` and returns an ExecutionEngine
    for it (opt_level=0) that loads the support library."""
    self._passmanager.run(module)
    return ExecutionEngine(module, opt_level=0, shared_libs=[self._support_lib])

# ===----------------------------------------------------------------------=== #

def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library. There are two optional
  arguments, for debugging purposes. The first argument specifies where
  to write out the raw/generated ir.Module. The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if support_lib is None:
    raise RuntimeError('SUPPORT_LIB is undefined')
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    par = 0
    vec = 0
    vl = 1
    e = False
    sparsification_options = (
        f'parallelization-strategy={par} '
        f'vectorization-strategy={vec} '
        f'vl={vl} '
        f'enable-simd-index32={e}')
    compiler = SparseCompiler(sparsification_options, support_lib)
    f64 = ir.F64Type.get()
    # Be careful about increasing this because
    # len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
    shape = range(2, 6)
    rank = len(shape)
    # All combinations.
    levels = list(itertools.product(
        [st.DimLevelType.dense, st.DimLevelType.compressed], repeat=rank))
    # All permutations.
    orderings = list(map(ir.AffineMap.get_permutation,
                         itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    # Same order as nested loops over level, ordering, pwidth, iwidth.
    for level, ordering, pwidth, iwidth in itertools.product(
        levels, orderings, bitwidths, bitwidths):
      attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
      types.append(ir.RankedTensorType.get(shape, f64, attr))
    #
    # For exhaustiveness we should have one or more StressTest, such
    # that their paths cover all n*(n-1) directed pairwise combinations
    # of the `types` set. However, since n grows superexponentially with
    # the rank, such exhaustiveness would be prohibitive for a test that
    # runs on every commit. So for now we'll just pick one particular path
    # that at least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = int(np.prod(shape))
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    np_out = (
        StressTest(tyconv)
        .build(types)
        .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
        .compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
        .run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')


if __name__ == '__main__':
  main()