# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *

# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print ``args`` space-separated to stderr, then flush immediately.

  Flushing after every call keeps this script's output interleaved in order
  with diagnostics written to stderr, since the whole stream (2>&1) is piped
  into FileCheck.
  """
  print(*args, file=sys.stderr)
  sys.stderr.flush()

def run(f):
  """Run test function ``f`` and verify it leaked no MLIR Context.

  Logs a ``TEST: <name>`` label, calls ``f()``, forces a garbage collection and
  then checks that ``Context._get_live_count()`` is zero.

  Raises:
    AssertionError: if any Context is still alive after ``f`` returns. Raised
      explicitly (instead of via ``assert``) so the leak check is not stripped
      when Python runs with ``-O``, and so the message names the test.
  """
  log("\nTEST:", f.__name__)
  f()
  gc.collect()
  live_count = Context._get_live_count()
  if live_count != 0:
    raise AssertionError(
        "{0}: {1} Context object(s) still alive after gc".format(
            f.__name__, live_count))

# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip an ExecutionEngine through its C-API capsule (``_CAPIPtr``)."""
  with Context():
    module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
    execution_engine = ExecutionEngine(module)
    execution_engine_capsule = execution_engine._CAPIPtr
    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
    log(repr(execution_engine_capsule))
    # NOTE(review): presumably releases this wrapper's ownership of the native
    # engine so the object re-created from the capsule below owns it — confirm
    # against the bindings.
    execution_engine._testing_release()
    execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
    # CHECK: _mlirExecutionEngine.ExecutionEngine
    log(repr(execution_engine1))

run(testCapsule)

# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
  """Creating an ExecutionEngine from an un-lowered module must fail.

  The module below uses a builtin ``func`` that has not been converted to the
  LLVM dialect (unlike the other tests, it is not passed through lowerToLLVM).
  """
  with Context():
    # Builtin function
    module = Module.parse(r"""
    func @foo() { return }
    """)
    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
    try:
      ExecutionEngine(module)
    except RuntimeError as e:
      log("Got RuntimeError: ", e)
    else:
      # Make an unexpected success visible in the log rather than silently
      # producing no output for this test.
      log("Unexpected: ExecutionEngine creation succeeded")

run(testInvalidModule)

def lowerToLLVM(module):
  """Run the memref-to-LLVM and std-to-LLVM conversions on ``module``.

  The pass manager mutates ``module`` in place; the same object is returned so
  calls can be chained straight into ``ExecutionEngine(...)``.
  """
  # Not referenced directly. Presumably imported for its side effect of making
  # the conversion passes parseable by PassManager.parse — confirm.
  import mlir.conversions  # noqa: F401
  pm = PassManager.parse("convert-memref-to-llvm,convert-std-to-llvm")
  pm.run(module)
  return module

# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
  """Invoke a JIT-compiled function that takes no arguments and returns nothing."""
  with Context():
    module = Module.parse(r"""
func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
    execution_engine = ExecutionEngine(lowerToLLVM(module))
    # Nothing to check other than no exception thrown here.
    execution_engine.invoke("void")

run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
  """Pass two f32 scalars to a JIT-compiled add and read the sum back."""
  with Context():
    module = Module.parse(r"""
func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = std.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
    execution_engine = ExecutionEngine(lowerToLLVM(module))
    # Prepare arguments: two input floats and one result.
    # Arguments must be passed as pointers. A length-1 ctypes array (an array
    # type, not ctypes.POINTER) gives each value an addressable slot, and the
    # result can be read back through index [0].
    c_float_array = ctypes.c_float * 1
    arg0 = c_float_array(42.)
    arg1 = c_float_array(2.)
    res = c_float_array(-1.)
    execution_engine.invoke("add", arg0, arg1, res)
    # CHECK: 42.0 + 2.0 = 44.0
    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))

run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
  """Call from JIT-compiled code back into a Python ctypes function."""
  # Define a callback function that takes a float and an integer and returns a
  # float. It returns half of the sum, so the caller doubles the result below
  # to recover a + b.
  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
  def callback(a, b):
    return a/2 + b/2

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
    execution_engine = ExecutionEngine(lowerToLLVM(module))
    execution_engine.register_runtime("some_callback_into_python", callback)

    # Prepare arguments: one input float, one input int and one float result.
    # Arguments must be passed as pointers; length-1 ctypes arrays provide
    # addressable slots whose values are read back through index [0].
    c_float_array = ctypes.c_float * 1
    c_int_array = ctypes.c_int * 1
    arg0 = c_float_array(42.)
    arg1 = c_int_array(2)
    res = c_float_array(-1.)
    execution_engine.invoke("add", arg0, arg1, res)
    # CHECK: 42.0 + 2 = 44.0
    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))

run(testBasicCallback)

# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass numpy arrays as unranked memrefs through JIT code into a Python callback."""
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Contiguous 2x2 input.
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}:  [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )

        # Non-contiguous input: a 3x4 view over a 3-element buffer. The row
        # stride is one float32 (4 bytes) and the column stride is 0, so each
        # row repeats a single element. This exercises non-trivial strides.
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}:  [6. 6. 6. 6.]
        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
            ),
        )

run(testUnrankedMemRefCallback)

# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a numpy array as a ranked 2-D memref through JIT code into a Python callback."""
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    # The descriptor type is built for rank 2 / float32 to match memref<2x2xf32>.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}:  [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )

run(testRankedMemRefCallback)

# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a rank-1 memref element to a rank-0 memref and store into a third memref."""
    with Context():
        module = Module.parse(
            """
      module  {
      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
     } """
        )
        # Python arg1 / arg2 / res are passed as MLIR %arg0 / %arg1 / %arg2.
        arg1 = np.array([32.5]).astype(np.float32)
        # 0-d array, passed as the rank-0 memref<f32>.
        arg2 = np.array(6).astype(np.float32)
        res = np.array([0]).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))

run(testMemrefAdd)

# Test addition of two 2-D memrefs, one of them dynamically shaped.
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Add 2x2 arrays passed as memref<2x2xf32> and memref<?x?xf32>."""
    with Context():
        module = Module.parse(
      """
      module  {
        func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = constant 0 : index
          %c2 = constant 2 : index
          %c1 = constant 1 : index
          br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = cmpi slt, %0, %c2 : index
          cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = constant 0 : index
          %c2_1 = constant 2 : index
          %c1_2 = constant 1 : index
          br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = cmpi slt, %2, %c2_1 : index
          cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = addi %2, %c1_2 : index
          br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = addi %0, %c1 : index
          br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
        )
        # Seeded generator so a failure is reproducible from run to run.
        rng = np.random.default_rng(0)
        arg1 = rng.standard_normal((2, 2)).astype(np.float32)
        arg2 = rng.standard_normal((2, 2)).astype(np.float32)
        # Pre-filled with noise (not zeros) so a kernel that never writes the
        # output cannot accidentally pass the comparison below.
        res = rng.standard_normal((2, 2)).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: True
        log(np.allclose(arg1+arg2, res))

run(testDynamicMemrefAdd2D)

# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """JIT a function that calls print_memref_f32 from the runner-utils libraries.

    The library paths are relative to the current working directory, which is
    assumed to be this test's directory in the build tree — TODO confirm if the
    test is ever run from elsewhere.
    """
    with Context():
        module = Module.parse(
            """
      module  {
      func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = constant 0 : index
        %cst42 = constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @print_memref_f32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """
        )
        arg0 = np.array([0.0]).astype(np.float32)

        arg0_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg0)))

        # Shared-library file names are platform specific; LLVM's build places
        # Windows DLLs in bin/ rather than lib/. The Linux paths are unchanged.
        if sys.platform == "win32":
            shared_libs = ["../../../../bin/mlir_runner_utils.dll",
                           "../../../../bin/mlir_c_runner_utils.dll"]
        elif sys.platform == "darwin":
            shared_libs = ["../../../../lib/libmlir_runner_utils.dylib",
                           "../../../../lib/libmlir_c_runner_utils.dylib"]
        else:
            shared_libs = ["../../../../lib/libmlir_runner_utils.so",
                           "../../../../lib/libmlir_c_runner_utils.so"]

        execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3,
                                           shared_libs=shared_libs)
        # NOTE(review): print_memref_f32 writes to the C runtime's stdout, which
        # is presumably only flushed at process exit when piped; keep this test
        # last so its output still follows every earlier stderr log line.
        execution_engine.invoke("main", arg0_memref_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]

run(testSharedLibLoad)
