# RUN: SUPPORT_LIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
# RUN:   %PYTHON %s | FileCheck %s
34748cc69Swren romano
44748cc69Swren romanoimport ctypes
54748cc69Swren romanoimport errno
64748cc69Swren romanoimport itertools
74748cc69Swren romanoimport os
84748cc69Swren romanoimport sys
98b83b8f1SAart Bik
104748cc69Swren romanofrom typing import List, Callable
114748cc69Swren romano
124748cc69Swren romanoimport numpy as np
134748cc69Swren romano
144748cc69Swren romanofrom mlir import ir
154748cc69Swren romanofrom mlir import runtime as rt
164748cc69Swren romano
17*27a431f5SMatthias Springerfrom mlir.dialects import bufferization
184748cc69Swren romanofrom mlir.dialects import builtin
1923aa5a74SRiver Riddlefrom mlir.dialects import func
204748cc69Swren romanofrom mlir.dialects import sparse_tensor as st
214748cc69Swren romano
# Append this script's own directory to the module search path before
# importing `tools`.  The import below must stay after the append.
# NOTE(review): presumably the `tools` package lives next to this file --
# confirm against the directory layout.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
258b83b8f1SAart Bik
# ===----------------------------------------------------------------------=== #

# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Two-way lookup between NumPy scalar types and MLIR element types."""

  def __init__(self, context: ir.Context):
    # The NumPy side of each pair is a "scalar type" (one of the values of
    # np.sctypeDict), not an instance of the np.dtype class.
    #
    # Every MLIR type below is created in `context`.  Types handed to
    # irtype_to_sctype() or irtype_to_dtype() must come from that same
    # context; otherwise those methods will raise a KeyError.
    def signless(width: int) -> ir.Type:
      return ir.IntegerType.get_signless(width, context=context)

    pairs = (
      (np.float64, ir.F64Type.get(context=context)),
      (np.float32, ir.F32Type.get(context=context)),
      (np.int64, signless(64)),
      (np.int32, signless(32)),
      (np.int16, signless(16)),
      (np.int8, signless(8)),
    )
    self._sc2ir = {sctype: irtype for sctype, irtype in pairs}
    self._ir2sc = {irtype: sctype for sctype, irtype in pairs}

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR type matching a NumPy dtype.

    Raises a KeyError (chained to the underlying scalar-type KeyError)
    when the dtype has no registered MLIR equivalent."""
    try:
      irtype = self.sctype_to_irtype(dtype.type)
    except KeyError as err:
      raise KeyError(f'Unknown dtype: {dtype}') from err
    return irtype

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR type matching a NumPy scalar type, or raises a
    KeyError when the scalar type is not registered."""
    if sctype not in self._sc2ir:
      raise KeyError(f'Unknown sctype: {sctype}')
    return self._sc2ir[sctype]

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype matching an MLIR type."""
    sctype = self.irtype_to_sctype(tp)
    return np.dtype(sctype)

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar type matching an MLIR type, or raises a
    KeyError when the MLIR type is not registered."""
    if tp not in self._ir2sc:
      raise KeyError(f'Unknown ir.Type: {tp}')
    return self._ir2sc[tp]

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType describing a NumPy array: same
    shape, with the element type obtained from the array's dtype.  Note
    that NumPy arrays can only be converted to/from dense tensors, not
    sparse tensors."""
    # TODO: handle strides as well?
    dims = nparray.shape
    element_tp = self.dtype_to_irtype(nparray.dtype)
    return ir.RankedTensorType.get(dims, element_tp)
# ===----------------------------------------------------------------------=== #
844748cc69Swren romano
class StressTest:
  """Builds, compiles, and runs a chain of sparse-tensor conversions.

  The intended call order is
  build() -> [writeTo()] -> compile() -> [writeTo()] -> run().
  build(), writeTo(), and compile() return `self` so the calls can be
  chained; the ordering is enforced with assertions."""

  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    # Argument and result type of @main; set by build().
    self._roundtripTp = None
    # The ir.Module containing @main; set by build().
    self._module = None
    # Whatever compiler.compile_and_jit() returns; set by compile().
    self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    """Raises an AssertionError unless `tp` equals the roundtrip type
    recorded by build()."""
    assert self._roundtripTp is not None, \
        'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound:    {tp}\n")

  def build(self, types: List[ir.Type]) -> 'StressTest':
    """Builds the ir.Module.  The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type.  The roundtrip type must be a dense tensor.

    Args:
      types: the roundtrip type, followed by the types to convert through
        in order.  Must be non-empty.  The list itself is not modified.

    Returns:
      self, to allow chaining.

    Raises:
      IndexError: if `types` is empty.
    """
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    # Work on a private copy so the caller's list is left untouched, and
    # fail on empty input before any state on `self` has been changed.
    path = list(types)
    tp0 = path.pop(0)
    # The conversion chain ends by converting back to the initial type.
    path.append(tp0)
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = func.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        # The first conversion consumes the function argument, which is
        # never deallocated here; every later conversion consumes (and
        # then releases) the previous intermediate result.
        v = st.ConvertOp(path[0], arg0)
        for tp in path[1:]:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          bufferization.DeallocTensorOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        func.ReturnOp(v)
    return self

  def writeTo(self, filename) -> 'StressTest':
    """Write the ir.Module to the given file.  If the file already exists,
    then raises an error.  If the filename is None, then is a no-op.

    Returns:
      self, to allow chaining.

    Raises:
      FileExistsError: if `filename` already exists.
    """
    assert self._module is not None, \
        'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    # Exclusive-create mode makes "fail if it exists" a single atomic step,
    # rather than a separate existence check followed by a truncating open.
    with open(filename, 'x') as f:
      f.write(str(self._module))
    return self

  def compile(self, compiler) -> 'StressTest':
    """Compile the ir.Module by passing it to `compiler.compile_and_jit()`
    and keep the returned engine for run().

    Returns:
      self, to allow chaining.
    """
    assert self._module is not None, \
        'StressTest: must call build() before compile()'
    assert self._engine is None, \
        'StressTest: must not call compile() repeatedly'
    self._engine = compiler.compile_and_jit(self._module)
    return self

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array.

    Raises:
      AssertionError: if the shape/dtype of `np_arg0` does not correspond
        to the roundtrip type recorded by build().
    """
    assert self._engine is not None, \
        'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    # Result placeholder with the same shape and dtype as the input.
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
    # Both the argument and the result slot are passed as a pointer to a
    # pointer to a ranked memref descriptor.
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    # Read the result back through whatever descriptor the result slot
    # points at after the call.
    return rt.ranked_memref_to_numpy(mem_out[0])
# ===----------------------------------------------------------------------=== #
1674748cc69Swren romano
def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  Requires the environment variable SUPPORT_LIB to hold the path of the
  libmlir_c_runner_utils shared library.  Both command-line arguments
  are optional and exist for debugging only: the first names a file to
  receive the raw/generated ir.Module, the second names a file to receive
  the compiled version of that ir.Module.
  """
  support_lib = os.environ.get('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # Optional debugging outputs; None turns StressTest.writeTo() into a no-op.
  raw_module_path = sys.argv[1] if len(sys.argv) > 1 else None
  compiled_module_path = sys.argv[2] if len(sys.argv) > 2 else None

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    par = 0
    vec = 0
    vl = 1
    e = False
    # Disable direct sparse2sparse conversion, because it doubles the time!
    # TODO: While direct s2s is far too slow for per-commit testing,
    # we should have some framework ensure that we run this test with
    # `s2s=0` on a regular basis, to ensure that it does continue to work.
    s2s = 1
    sparsification_options = ' '.join((
        f'parallelization-strategy={par}',
        f'vectorization-strategy={vec}',
        f'vl={vl}',
        f'enable-simd-index32={e}',
        f's2s-strategy={s2s}',
    ))
    compiler = sparse_compiler.SparseCompiler(
        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
    f64 = ir.F64Type.get()
    # Be careful about increasing this because
    #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
    shape = range(2, 6)
    rank = len(shape)
    # Every assignment of dense/compressed to the dimensions.
    levels = list(itertools.product(
        [st.DimLevelType.dense, st.DimLevelType.compressed], repeat=rank))
    # Every dimension ordering.
    orderings = [ir.AffineMap.get_permutation(perm)
                 for perm in itertools.permutations(range(rank))]
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    # product() varies its last argument fastest, so the encodings are
    # generated in the same order as four nested loops would produce.
    types.extend(
        ir.RankedTensorType.get(
            shape, f64, st.EncodingAttr.get(level, ordering, pwidth, iwidth))
        for level, ordering, pwidth, iwidth in itertools.product(
            levels, orderings, bitwidths, bitwidths))
    #
    # An exhaustive test would need one or more StressTest whose paths
    # cover all directed pairwise combinations of the `types` set.  Since
    # the size of that set is already superexponential, that would be
    # prohibitive for a test that runs on every commit.  So for now we
    # just pick one particular path that at least hits every element of
    # the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = 1
    for extent in shape:
      size *= extent
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    # Each step below returns the StressTest itself; the steps are spelled
    # out one per line in the order they must run.
    test = StressTest(tyconv).build(types)
    test.writeTo(raw_module_path)
    test.compile(compiler)
    test.writeTo(compiled_module_path)
    np_out = test.run(np_arg0)
    # CHECK: Passed
    if not np.allclose(np_out, np_arg0):
      sys.exit('FAILURE')
    print('Passed')
# Run the stress test only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
  main()
248