# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys

from typing import List

import numpy as np

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import bufferization
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
      (np.float64, ir.F64Type.get(context=context)),
      (np.float32, ir.F32Type.get(context=context)),
      (np.int64, ir.IntegerType.get_signless(64, context=context)),
      (np.int32, ir.IntegerType.get_signless(32, context=context)),
      (np.int16, ir.IntegerType.get_signless(16, context=context)),
      (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    self._ir2sc = dict((irtype, sc) for sc, irtype in types_list)

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype."""
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type."""
    if sctype in self._sc2ir:
      return self._sc2ir[sctype]
    else:
      raise KeyError(f'Unknown sctype: {sctype}')

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type."""
    return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type."""
    if tp in self._ir2sc:
      return self._ir2sc[tp]
    else:
      raise KeyError(f'Unknown ir.Type: {tp}')

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

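# A minimal usage sketch of TypeConverter (illustrative only; main() below
# exercises the same methods):
#
#   with ir.Context() as ctx:
#     tyconv = TypeConverter(ctx)
#     f64 = tyconv.sctype_to_irtype(np.float64)          # MLIR 'f64'
#     assert tyconv.irtype_to_sctype(f64) is np.float64  # roundtrips
#     tyconv.get_RankedTensorType_of_nparray(np.zeros((2, 3)))  # tensor<2x3xf64>
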
# ===----------------------------------------------------------------------=== #

class StressTest:
  """Builds, compiles, and runs a roundtrip chain of tensor-type conversions."""

  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None
    self._module = None
    self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    assert self._roundtripTp is not None, \
        'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound:    {tp}\n")

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module.  The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type.  The roundtrip type must be a dense tensor.
    Note: the given list of types is modified in place."""
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      tp0 = types.pop(0)
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      types.append(tp0)
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = func.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(types.pop(0), arg0)
        for tp in types:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          bufferization.DeallocTensorOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        func.ReturnOp(v)
    return self

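  # For a two-type chain [dense, sparse], build() emits roughly the following
  # @main (illustrative sketch; the actual shapes and encoding attribute come
  # from the `types` list):
  #
  #   func.func @main(%arg0: tensor<...xf64>) -> tensor<...xf64>
  #       attributes {llvm.emit_c_interface} {
  #     %0 = sparse_tensor.convert %arg0 : tensor<...xf64> to tensor<...xf64, #enc>
  #     %1 = sparse_tensor.convert %0 : tensor<...xf64, #enc> to tensor<...xf64>
  #     bufferization.dealloc_tensor %0 : tensor<...xf64, #enc>
  #     return %1 : tensor<...xf64>
  #   }
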
  def writeTo(self, filename):
    """Writes the ir.Module to the given file.  If the file already exists,
    then raises an error.  If the filename is None, then this is a no-op."""
    assert self._module is not None, \
        'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    if os.path.exists(filename):
      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
    with open(filename, 'w') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler):
    """Compiles the ir.Module."""
    assert self._module is not None, \
        'StressTest: must call build() before compile()'
    assert self._engine is None, \
        'StressTest: must not call compile() repeatedly'
    self._engine = compiler.compile_and_jit(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array."""
    assert self._engine is not None, \
        'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
    # The C interface wrapper takes pointers to ranked memref descriptors;
    # the result descriptor is passed first, as an out-parameter.
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #

def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library.  There are two optional
  arguments, for debugging purposes.  The first argument specifies where
  to write out the raw/generated ir.Module.  The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    par = 0
    vec = 0
    vl = 1
    e = False
    # Disable direct sparse2sparse conversion, because it doubles the time!
    # TODO: While direct s2s is far too slow for per-commit testing,
    # we should have some framework ensure that we run this test with
    # `s2s=0` on a regular basis, to ensure that it does continue to work.
    s2s = 1
    sparsification_options = (
        f'parallelization-strategy={par} '
        f'vectorization-strategy={vec} '
        f'vl={vl} '
        f'enable-simd-index32={e} '
        f's2s-strategy={s2s}')
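    # With the values above, sparsification_options expands to:
    #   'parallelization-strategy=0 vectorization-strategy=0 vl=1 '
    #   'enable-simd-index32=False s2s-strategy=1'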
    compiler = sparse_compiler.SparseCompiler(
        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
    f64 = ir.F64Type.get()
    # Be careful about increasing the size (and hence rank) of `shape`, because
    #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
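    # For the rank-4 shape below with a single bitwidth choice, that comes to
    # 1 + 2**4 * 24 * 1**2 == 385 tensor types.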
    shape = range(2, 6)
    rank = len(shape)
    # All combinations of dense/compressed dimension levels.
    levels = list(itertools.product(*itertools.repeat(
      [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
    # All permutations of the dimension ordering.
    orderings = list(map(ir.AffineMap.get_permutation,
      itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
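    # Each iteration of the loop below appends the same shape/f64 tensor type,
    # but with a #sparse_tensor.encoding attribute that fixes the per-dimension
    # level types (dense/compressed), the dimension ordering, and the
    # pointer/index bitwidths.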
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
    #
    # For exhaustiveness we should have one or more StressTest instances, such
    # that their paths cover all 2*n*(n-1) directed pairwise combinations of
    # the `types` set.  However, since n already grows superexponentially with
    # the rank, such exhaustiveness would be prohibitive for a test that runs
    # on every commit.  So for now we'll just pick one particular path that at
    # least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = 1
    for d in shape:
      size *= d
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    np_out = (
        StressTest(tyconv).build(types).writeTo(
            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')

if __name__ == '__main__':
  main()