# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys
from typing import List, Callable

import numpy as np

import mlir.all_passes_registration

from mlir import ir
from mlir import runtime as rt
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

from mlir.dialects import builtin
from mlir.dialects import std
from mlir.dialects import sparse_tensor as st

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
      (np.float64, ir.F64Type.get(context=context)),
      (np.float32, ir.F32Type.get(context=context)),
      (np.int64, ir.IntegerType.get_signless(64, context=context)),
      (np.int32, ir.IntegerType.get_signless(32, context=context)),
      (np.int16, ir.IntegerType.get_signless(16, context=context)),
      (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    self._ir2sc = {irtype: sctype for sctype, irtype in types_list}

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype."""
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type."""
    if sctype in self._sc2ir:
      return self._sc2ir[sctype]
    else:
      raise KeyError(f'Unknown sctype: {sctype}')

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type."""
    return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type."""
    if tp in self._ir2sc:
      return self._ir2sc[tp]
    else:
      raise KeyError(f'Unknown ir.Type: {tp}')

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

# ===----------------------------------------------------------------------=== #

class StressTest:
  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None
    self._module = None
    self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    assert self._roundtripTp is not None, \
        'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound:    {tp}\n")

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module.  The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type.  The roundtrip type must be a dense tensor."""
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      tp0 = types.pop(0)
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      types.append(tp0)
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = builtin.FuncOp(name='main', type=funcTp)
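      # Request the C-interface wrapper (_mlir_ciface_main) so that the
      # ExecutionEngine can invoke @main with memref-descriptor pointers.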
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(types.pop(0), arg0)
        for tp in types:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          st.ReleaseOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        std.ReturnOp(v)
    return self

  def writeTo(self, filename):
    """Writes the ir.Module to the given file.  If the file already exists,
    then raises an error.  If the filename is None, then this is a no-op."""
    assert self._module is not None, \
        'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    if os.path.exists(filename):
      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
    with open(filename, 'w') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler: Callable[[ir.Module], ExecutionEngine]):
    """Compiles the ir.Module into an ExecutionEngine."""
    assert self._module is not None, \
        'StressTest: must call build() before compile()'
    assert self._engine is None, \
        'StressTest: must not call compile() repeatedly'
    self._engine = compiler(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array."""
    assert self._engine is not None, \
        'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
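    # Wrap both arrays as ranked memref descriptors; invoke() takes pointers
    # to pointers to the descriptors, with the output passed before the input.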
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class SparseCompiler:
  """Sparse compiler passes."""

  def __init__(self, sparsification_options: str, support_lib: str):
    self._support_lib = support_lib
    self._pipeline = (
        f'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),'
        f'sparsification{{{sparsification_options}}},'
        f'sparse-tensor-conversion,'
        f'builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),'
        f'convert-scf-to-std,'
        f'func-bufferize,'
        f'tensor-constant-bufferize,'
        f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
        f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
        f'lower-affine,'
        f'convert-memref-to-llvm,'
        f'convert-std-to-llvm,'
        f'reconcile-unrealized-casts')
    # Must be in the scope of a `with ir.Context():`
    self._passmanager = PassManager.parse(self._pipeline)

  def __call__(self, module: ir.Module) -> ExecutionEngine:
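    # Lower the module in place, then JIT it against the sparse support library.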
    self._passmanager.run(module)
    return ExecutionEngine(module, opt_level=0, shared_libs=[self._support_lib])

# ===----------------------------------------------------------------------=== #

def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library.  There are two optional
  arguments, for debugging purposes.  The first argument specifies where
  to write out the raw/generated ir.Module.  The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
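    # Sparsification options: run sequentially (no parallelization) with
    # vectorization disabled.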
    par = 0
    vec = 0
    vl = 1
    e = False
    sparsification_options = (
        f'parallelization-strategy={par} '
        f'vectorization-strategy={vec} '
        f'vl={vl} '
        f'enable-simd-index32={e}')
    compiler = SparseCompiler(sparsification_options, support_lib)
    f64 = ir.F64Type.get()
    # Be careful about increasing the rank because
    #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
    shape = range(2, 6)
    rank = len(shape)
    # All combinations of dense/compressed dimension levels.
    levels = list(itertools.product(*itertools.repeat(
      [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
    # All permutations of the dimension ordering.
    orderings = list(map(ir.AffineMap.get_permutation,
      itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
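    # With rank = 4 and a single bitwidth, this yields 1 + 2^4 * 4! = 385 types.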
    #
    # For exhaustiveness we should have one or more StressTest, such
    # that their paths cover all 2*n*(n-1) directed pairwise combinations
    # of the `types` set.  However, since n is already superexponential,
    # such exhaustiveness would be prohibitive for a test that runs on
    # every commit.  So for now we'll just pick one particular path that
    # at least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = 1
    for d in shape:
      size *= d
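    # Fill the input with distinct values so a mis-ordered roundtrip
    # would not compare equal to the original.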
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    np_out = (
        StressTest(tyconv)
        .build(types)
        .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
        .compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
        .run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')

if __name__ == '__main__':
  main()