19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s 2>&1 | FileCheck %s
2*b630bafbSRainer Orth# REQUIRES: host-supports-jit
39f3f6d7bSStella Laurenzoimport gc, sys
49f3f6d7bSStella Laurenzofrom mlir.ir import *
59f3f6d7bSStella Laurenzofrom mlir.passmanager import *
69f3f6d7bSStella Laurenzofrom mlir.execution_engine import *
79f3f6d7bSStella Laurenzofrom mlir.runtime import *
89f3f6d7bSStella Laurenzo
9a54f4eaeSMogball
# MLIR reports its errors and info on stderr, so test output goes there too and
# is flushed eagerly; that keeps a single, correctly interleaved stream for
# FileCheck to match.
def log(*args):
  """Print `args` (space separated, like print()) to stderr and flush it."""
  stream = sys.stderr
  print(*args, file=stream, flush=True)
159f3f6d7bSStella Laurenzo
16a54f4eaeSMogball
def run(f):
  """Run the test function `f` and verify it left no MLIR Context alive.

  Logs a "TEST: <name>" header (the anchor for the CHECK-LABEL lines in this
  file), calls `f`, forces a garbage collection and then checks the number of
  live `Context` objects.

  Raises:
    AssertionError: if any `Context` is still alive after `f` returns.
  """
  log("\nTEST:", f.__name__)
  f()
  # Collect first so objects reachable only through reference cycles are freed
  # before counting.
  gc.collect()
  # An explicit raise (instead of `assert`) keeps the leak check active under
  # `python -O` and reports how many contexts leaked.
  live_count = Context._get_live_count()
  if live_count != 0:
    raise AssertionError("{0}: {1} MLIR Context(s) still alive".format(
        f.__name__, live_count))
229f3f6d7bSStella Laurenzo
23a54f4eaeSMogball
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip an ExecutionEngine through its C API capsule."""
  with Context():
    source = r"""
llvm.func @none() {
  llvm.return
}
    """
    module = Module.parse(source)
    engine = ExecutionEngine(module)
    capsule = engine._CAPIPtr
    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
    log(repr(capsule))
    # Testing hook: presumably drops `engine`'s ownership of the native engine
    # so the capsule can be re-adopted below without a double free -- TODO
    # confirm against the binding implementation.
    engine._testing_release()
    adopted = ExecutionEngine._CAPICreate(capsule)
    # CHECK: _mlirExecutionEngine.ExecutionEngine
    log(repr(adopted))


run(testCapsule)
449f3f6d7bSStella Laurenzo
45a54f4eaeSMogball
# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
  """Creating an engine from a module that was never lowered must fail."""
  with Context():
    # A func.func that has not been passed through lowerToLLVM; the CHECK
    # below expects ExecutionEngine construction to reject it.
    module = Module.parse(r"""
    func.func @foo() { return }
    """)
    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
    try:
      ExecutionEngine(module)
    except RuntimeError as e:
      log("Got RuntimeError: ", e)
    else:
      # Make an unexpected success visible in the test log instead of
      # silently producing no output.
      log("Unexpectedly created an ExecutionEngine")


run(testInvalidModule)
629f3f6d7bSStella Laurenzo
63a54f4eaeSMogball
def lowerToLLVM(module):
  """Lower `module` in place to the LLVM dialect and return it.

  Runs the complex, memref and func to LLVM conversions, then cleans up the
  unrealized casts they leave behind.
  """
  passes = [
      "convert-complex-to-llvm",
      "convert-memref-to-llvm",
      "convert-func-to-llvm",
      "reconcile-unrealized-casts",
  ]
  pass_manager = PassManager.parse(",".join(passes))
  pass_manager.run(module)
  return module
699f3f6d7bSStella Laurenzo
70a54f4eaeSMogball
# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
  """JIT-compile and call a function taking no arguments and returning none."""
  with Context():
    module = Module.parse(r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
    lowered = lowerToLLVM(module)
    engine = ExecutionEngine(lowered)
    # The only expectation is that this call does not raise.
    engine.invoke("void")


run(testInvokeVoid)
869f3f6d7bSStella Laurenzo
879f3f6d7bSStella Laurenzo
# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
  """Pass two f32 scalars to a JIT-compiled function and read back the sum."""
  with Context():
    module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
    engine = ExecutionEngine(lowerToLLVM(module))
    # invoke() receives every value -- the result slot included -- by pointer,
    # so each one lives in a single-element ctypes array.
    FloatCell = ctypes.c_float * 1
    lhs, rhs, out = FloatCell(42.), FloatCell(2.), FloatCell(-1.)
    engine.invoke("add", lhs, rhs, out)
    # CHECK: 42.0 + 2.0 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))


run(testInvokeFloatAdd)
1119f3f6d7bSStella Laurenzo
1129f3f6d7bSStella Laurenzo
# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
  """Call back from JIT-compiled code into a Python function."""

  # C-callable Python function of type (float, int) -> float that returns the
  # halved sum of its arguments.
  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
  def callback(a, b):
    return a / 2 + b / 2

  with Context():
    # @add only forwards its arguments to the external symbol
    # "some_callback_into_python", which is bound to `callback` below.
    module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    # Arguments: an f32 input, an i32 input and an f32 result slot, each
    # passed by pointer through a single-element ctypes array.
    lhs = (ctypes.c_float * 1)(42.)
    rhs = (ctypes.c_int * 1)(2)
    out = (ctypes.c_float * 1)(-1.)
    engine.invoke("add", lhs, rhs, out)
    # The callback halves the sum, so double the result to print the full sum.
    # CHECK: 42.0 + 2 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0] * 2))


run(testBasicCallback)
1469f3f6d7bSStella Laurenzo
147a54f4eaeSMogball
# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
  """Hand unranked f32 memrefs from JIT-compiled code to a Python callback."""

  # Receives a pointer to an unranked memref descriptor, converts it to a
  # float32 NumPy array and prints it.
  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
  def callback(a):
    arr = unranked_memref_to_numpy(a, np.float32)
    log("Inside callback: ")
    log(arr)

  def invoke_with(engine, array):
    # Wrap `array` in an unranked descriptor and pass it by double pointer.
    descriptor = get_unranked_memref_descriptor(array)
    engine.invoke("callback_memref", ctypes.pointer(ctypes.pointer(descriptor)))

  with Context():
    # @callback_memref only forwards its argument to the external symbol
    # "some_callback_into_python", which is bound to `callback` below.
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[1. 2.]
    # CHECK{LITERAL}:  [3. 4.]]
    invoke_with(engine, dense)

    # A zero stride on the second axis repeats each element of `column`
    # across its row, exercising a non-contiguous layout.
    column = np.array([5, 6, 7], dtype=np.float32)
    broadcast = np.lib.stride_tricks.as_strided(
        column, strides=(4, 0), shape=(3, 4))
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[5. 5. 5. 5.]
    # CHECK{LITERAL}:  [6. 6. 6. 6.]
    # CHECK{LITERAL}:  [7. 7. 7. 7.]]
    invoke_with(engine, broadcast)


run(testUnrankedMemRefCallback)
1929f3f6d7bSStella Laurenzo
193a54f4eaeSMogball
# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
  """Hand a ranked 2x2 f32 memref from JIT-compiled code to a Python callback."""
  # ctypes type describing a rank-2 memref with float32 elements.
  memref_2d_f32 = make_nd_memref_descriptor(
      2, np.ctypeslib.as_ctypes_type(np.float32))

  # Receives a pointer to the ranked descriptor, converts it to a NumPy array
  # and prints it.
  @ctypes.CFUNCTYPE(None, ctypes.POINTER(memref_2d_f32))
  def callback(a):
    arr = ranked_memref_to_numpy(a)
    log("Inside Callback: ")
    log(arr)

  with Context():
    # @callback_memref only forwards its argument to the external symbol
    # "some_callback_into_python", which is bound to `callback` below.
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)
    matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
    descriptor_ptr = ctypes.pointer(get_ranked_memref_descriptor(matrix))
    # CHECK: Inside Callback:
    # CHECK{LITERAL}: [[1. 5.]
    # CHECK{LITERAL}:  [6. 7.]]
    engine.invoke("callback_memref", ctypes.pointer(descriptor_ptr))


run(testRankedMemRefCallback)
2309f3f6d7bSStella Laurenzo
231a54f4eaeSMogball
# Test addition of a 1-element memref and a 0-d memref.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
  """Add element 0 of a memref<1xf32> to a memref<f32> scalar."""
  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
    lhs = np.array([32.5]).astype(np.float32)
    rhs = np.array(6).astype(np.float32)  # 0-d array -> memref<f32>
    out = np.array([0]).astype(np.float32)

    # Every memref argument is passed as a pointer to a pointer to its
    # ranked descriptor.
    lhs_ptr, rhs_ptr, out_ptr = [
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
        for array in (lhs, rhs, out)
    ]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
    # CHECK: [32.5] + 6.0 = [38.5]
    log("{0} + {1} = {2}".format(lhs, rhs, out))


run(testMemrefAdd)
2669f3f6d7bSStella Laurenzo
267a54f4eaeSMogball
# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
  """Add two memref<1xf16> values and read the result back via NumPy."""
  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

    lhs, rhs, out = [
        np.array([value]).astype(np.float16) for value in (11., 22., 0.)
    ]

    # Pointer-to-pointer-to-descriptor for each memref argument.
    lhs_ptr, rhs_ptr, out_ptr = [
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
        for array in (lhs, rhs, out)
    ]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
    # CHECK: [11.] + [22.] = [33.]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # Convert the result descriptor back into a NumPy array.
    # CHECK: [33.]
    log(ranked_memref_to_numpy(out_ptr[0]))


run(testF16MemrefAdd)
310f8b692ddSAart Bik
311f8b692ddSAart Bik
# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
  """Add two memref<1xcomplex<f64>> values and read the result back."""
  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

    lhs, rhs, out = [
        np.array([value]).astype(np.complex128)
        for value in (1.+2.j, 3.+4.j, 0.+0.j)
    ]

    # Pointer-to-pointer-to-descriptor for each memref argument.
    lhs_ptr, rhs_ptr, out_ptr = [
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
        for array in (lhs, rhs, out)
    ]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
    # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # Convert the result descriptor back into a NumPy array.
    # CHECK: [4.+6.j]
    log(ranked_memref_to_numpy(out_ptr[0]))


run(testComplexMemrefAdd)
356d6682189SAart Bik
357d6682189SAart Bik
# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
  """Add two unranked complex<f32> memrefs (cast to rank 1 inside MLIR)."""
  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

    lhs, rhs, out = [
        np.array([value]).astype(np.complex64)
        for value in (5.+6.j, 7.+8.j, 0.+0.j)
    ]

    # Pointer-to-pointer-to-unranked-descriptor for each memref argument.
    lhs_ptr, rhs_ptr, out_ptr = [
        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(array)))
        for array in (lhs, rhs, out)
    ]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
    # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # An unranked descriptor carries no element type, so it is supplied here.
    # CHECK: [12.+14.j]
    log(unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64)))


run(testComplexUnrankedMemrefAdd)
406d6682189SAart Bik
407d6682189SAart Bik
# Test addition of two 2-D memrefs, one of them with dynamic sizes.
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
  """Element-wise add a static 2x2 memref and a dynamic ?x? memref."""
  with Context():
    module = Module.parse("""
      module  {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """)
    # Two random inputs plus an output buffer that the kernel overwrites.
    lhs, rhs, out = [
        np.random.randn(2, 2).astype(np.float32) for _ in range(3)
    ]

    # Pointer-to-pointer-to-descriptor for each memref argument.
    lhs_ptr, rhs_ptr, out_ptr = [
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
        for array in (lhs, rhs, out)
    ]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
    # CHECK: True
    log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)
464c8b8e8e0SUday Bondhugula
465a54f4eaeSMogball
# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
  """Resolve printMemrefF32 from the runner-utils shared libraries."""
  with Context():
    module = Module.parse("""
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """)
    buffer = np.array([0.0]).astype(np.float32)

    buffer_ptr = ctypes.pointer(
        ctypes.pointer(get_ranked_memref_descriptor(buffer)))

    # Library paths per host platform; anything not listed uses the .so names.
    libs_by_platform = {
        'win32': [
            "../../../../bin/mlir_runner_utils.dll",
            "../../../../bin/mlir_c_runner_utils.dll"
        ],
        'darwin': [
            "../../../../lib/libmlir_runner_utils.dylib",
            "../../../../lib/libmlir_c_runner_utils.dylib"
        ],
    }
    fallback_libs = [
        "../../../../lib/libmlir_runner_utils.so",
        "../../../../lib/libmlir_c_runner_utils.so"
    ]
    shared_libs = libs_by_platform.get(sys.platform, fallback_libs)

    engine = ExecutionEngine(
        lowerToLLVM(module), opt_level=3, shared_libs=shared_libs)
    engine.invoke("main", buffer_ptr)
    # CHECK: Unranked Memref
    # CHECK-NEXT: [42]


run(testSharedLibLoad)
513aaea92e1SDenys Shabalin
514aaea92e1SDenys Shabalin
# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
  """Call the runtime's nanoTime() and print the value through printMemrefI64.

  The runner-utils libraries are chosen per host platform: DLLs on Windows,
  .dylib files on macOS, and .so files otherwise. This matches the selection
  in testSharedLibLoad.
  """
  with Context():
    module = Module.parse("""
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

    if sys.platform == 'win32':
      shared_libs = [
          "../../../../bin/mlir_runner_utils.dll",
          "../../../../bin/mlir_c_runner_utils.dll"
      ]
    elif sys.platform == 'darwin':
      # Without this branch macOS fell through to the .so names, which do not
      # exist there.
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.dylib",
          "../../../../lib/libmlir_c_runner_utils.dylib"
      ]
    else:
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.so",
          "../../../../lib/libmlir_c_runner_utils.so"
      ]

    execution_engine = ExecutionEngine(
        lowerToLLVM(module),
        opt_level=3,
        shared_libs=shared_libs)
    execution_engine.invoke("main")
    # CHECK: Unranked Memref
    # CHECK: [{{.*}}]


run(testNanoTime)
555