# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s

import ctypes
import errno
import itertools
import os
import sys

from typing import List, Callable

import numpy as np

from mlir import ir
from mlir import runtime as rt

from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import sparse_tensor as st

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler

# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
      (np.float64, ir.F64Type.get(context=context)),
      (np.float32, ir.F32Type.get(context=context)),
      (np.int64, ir.IntegerType.get_signless(64, context=context)),
      (np.int32, ir.IntegerType.get_signless(32, context=context)),
      (np.int16, ir.IntegerType.get_signless(16, context=context)),
      (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    self._ir2sc = dict((irtp, sc) for sc, irtp in types_list)

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype."""
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type."""
    if sctype in self._sc2ir:
      return self._sc2ir[sctype]
    else:
      raise KeyError(f'Unknown sctype: {sctype}')

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type."""
    return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type."""
    if tp in self._ir2sc:
      return self._ir2sc[tp]
    else:
      raise KeyError(f'Unknown ir.Type: {tp}')

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

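# A minimal usage sketch of TypeConverter (illustrative only, not executed by
# this test); it assumes an already-created ir.Context named `ctx`:
#
#   tyconv = TypeConverter(ctx)
#   f64 = ir.F64Type.get(context=ctx)
#   assert tyconv.irtype_to_dtype(f64) == np.dtype(np.float64)
#   assert (tyconv.dtype_to_irtype(np.dtype(np.int32))
#           == ir.IntegerType.get_signless(32, context=ctx))
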
# ===----------------------------------------------------------------------=== #

class StressTest:
  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None
    self._module = None
    self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    assert self._roundtripTp is not None, \
        'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound:    {tp}\n")

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module.  The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type.  The roundtrip type must be a dense tensor."""
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      tp0 = types.pop(0)
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      types.append(tp0)
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = func.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(types.pop(0), arg0)
        for tp in types:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          st.ReleaseOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        func.ReturnOp(v)
    return self

  def writeTo(self, filename):
    """Writes the ir.Module to the given file.  If the file already exists,
    then raises an error.  If the filename is None, then this is a no-op."""
    assert self._module is not None, \
        'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    if os.path.exists(filename):
      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
    with open(filename, 'w') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler):
    """Compile the ir.Module."""
    assert self._module is not None, \
        'StressTest: must call build() before compile()'
    assert self._engine is None, \
        'StressTest: must not call compile() repeatedly'
    self._engine = compiler.compile_and_jit(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array."""
    assert self._engine is not None, \
        'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
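    # Wrap each array in a ranked memref descriptor passed by pointer, which
    # is the calling convention of the llvm.emit_c_interface wrapper that
    # build() attached to @main; the output descriptor goes first in invoke().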
    mem_arg0 = ctypes.pointer(
        ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(
        ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #

def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library.  There are two optional
  arguments, for debugging purposes.  The first argument specifies where
  to write out the raw/generated ir.Module.  The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    par = 0
    vec = 0
    vl = 1
    e = False
    # Disable direct sparse2sparse conversion, because it doubles the time!
    # TODO: While direct s2s is far too slow for per-commit testing,
    # we should have some framework ensure that we run this test with
    # `s2s=0` on a regular basis, to ensure that it does continue to work.
    s2s = 1
    sparsification_options = (
        f'parallelization-strategy={par} '
        f'vectorization-strategy={vec} '
        f'vl={vl} '
        f'enable-simd-index32={e} '
        f's2s-strategy={s2s}')
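    # With the values chosen above, this evaluates to (wrapped here for width):
    #   parallelization-strategy=0 vectorization-strategy=0 vl=1
    #   enable-simd-index32=False s2s-strategy=1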
    compiler = sparse_compiler.SparseCompiler(
        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
    f64 = ir.F64Type.get()
    # Be careful about increasing the rank (i.e., the length of `shape` below),
    # because len(types) = 1 + 2^rank * rank! * len(bitwidths)^2.
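    # For example, with shape = range(2, 6) (rank 4) and a single bitwidth,
    # that is 1 + 2^4 * 4! * 1^2 = 385 types in the roundtrip chain.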
    shape = range(2, 6)
    rank = len(shape)
    # All combinations.
    levels = list(itertools.product(*itertools.repeat(
      [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
    # All permutations.
    orderings = list(map(ir.AffineMap.get_permutation,
      itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
    #
    # For exhaustiveness we should have one or more StressTests, such
    # that their paths cover all n*(n-1) directed pairwise combinations
    # of the `types` set.  However, since n already grows superexponentially
    # in the rank, such exhaustiveness would be prohibitive for a test that
    # runs on every commit.  So for now we'll just pick one particular path
    # that at least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = 1
    for d in shape:
      size *= d
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
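    # With shape = range(2, 6), np_arg0 is just the 2*3*4*5 = 120 consecutive
    # doubles 0.0, 1.0, ..., 119.0, reshaped to (2, 3, 4, 5).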
    np_out = (
        StressTest(tyconv).build(types).writeTo(
            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')

if __name__ == '__main__':
  main()