1*9f3f6d7bSStella Laurenzo# RUN: %PYTHON %s 2>&1 | FileCheck %s
2*9f3f6d7bSStella Laurenzo
import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
8*9f3f6d7bSStella Laurenzo
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` space-separated to stderr, flushing right away."""
    print(*args, file=sys.stderr, flush=True)
14*9f3f6d7bSStella Laurenzo
def run(f):
    """Execute test function ``f`` and assert no MLIR Context outlives it.

    A ``TEST: <name>`` banner is logged first so the FileCheck labels in this
    file can anchor on each test's output.
    """
    log("\nTEST:", f.__name__)
    f()
    # Force a collection so objects kept alive only by reference cycles are
    # freed before counting; presumably _get_live_count() reports how many
    # Context objects are still alive.
    gc.collect()
    leaked = Context._get_live_count()
    assert leaked == 0
20*9f3f6d7bSStella Laurenzo
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C-API capsule."""
    with Context():
        module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
        engine = ExecutionEngine(module)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        # Presumably detaches the native engine from the original wrapper so
        # that a new wrapper can adopt it from the capsule below.
        engine._testing_release()
        engine_from_capsule = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlir.execution_engine.ExecutionEngine
        log(repr(engine_from_capsule))


run(testCapsule)
40*9f3f6d7bSStella Laurenzo
# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an ExecutionEngine from an un-lowered module must fail.

    Unlike the other tests, this module is not passed through lowerToLLVM,
    so construction is expected to raise RuntimeError.
    """
    with Context():
        # Builtin function that has not been lowered to the LLVM dialect.
        module = Module.parse(r"""
    func @foo() { return }
    """)
        # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)
        else:
            # Make an unexpected success visible in the log instead of
            # silently producing no output.
            log("Expected RuntimeError, but ExecutionEngine creation succeeded")


run(testInvalidModule)
56*9f3f6d7bSStella Laurenzo
def lowerToLLVM(module):
    """Run the ``convert-std-to-llvm`` pass on ``module`` and return it.

    The same module object that the pass manager ran on is returned, so the
    call can be nested directly inside ``ExecutionEngine(...)``.
    """
    # Imported for its side effect; presumably this registers the conversion
    # passes so that PassManager.parse can resolve the pipeline name.
    import mlir.conversions
    pass_manager = PassManager.parse("convert-std-to-llvm")
    pass_manager.run(module)
    return module
62*9f3f6d7bSStella Laurenzo
# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    with Context():
        module = Module.parse(r"""
func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
        lowered = lowerToLLVM(module)
        engine = ExecutionEngine(lowered)
        # The only expectation here is that the invocation does not raise.
        engine.invoke("void")


run(testInvokeVoid)
77*9f3f6d7bSStella Laurenzo
78*9f3f6d7bSStella Laurenzo
# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Pass two f32 arguments in and read an f32 result back out."""
    with Context():
        module = Module.parse(r"""
func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = std.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Arguments and the result slot must be passed as pointers; a
        # one-element ctypes array provides an addressable f32 cell.
        float_cell = ctypes.c_float * 1
        lhs, rhs, out = float_cell(42.), float_cell(2.), float_cell(-1.)
        engine.invoke("add", lhs, rhs, out)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))


run(testInvokeFloatAdd)
101*9f3f6d7bSStella Laurenzo
102*9f3f6d7bSStella Laurenzo
# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call from JIT-compiled code back into a Python function via ctypes."""
    # Define a callback function that takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        # Returns half of the sum; the log line below doubles it back.
        return a/2 + b/2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Bind the external symbol declared in the module to the Python callback.
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input int and one result.
        # Arguments must be passed as pointers; one-element ctypes arrays
        # provide addressable cells for that.
        float_cell = ctypes.c_float * 1
        int_cell = ctypes.c_int * 1
        arg0 = float_cell(42.)
        arg1 = int_cell(2)
        res = float_cell(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback halves the sum, so double the result to recover 42 + 2.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))


run(testBasicCallback)
135*9f3f6d7bSStella Laurenzo
# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked memrefs (contiguous and strided) to a Python callback."""

    # The callback receives a pointer to an unranked memref descriptor,
    # converts it to a float32 numpy array and logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        def invoke_with(array):
            # The memref argument is passed as a pointer to a pointer to its
            # unranked descriptor.
            descriptor = get_unranked_memref_descriptor(array)
            engine.invoke(
                "callback_memref", ctypes.pointer(ctypes.pointer(descriptor))
            )

        contiguous = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}:  [3. 4.]]
        invoke_with(contiguous)

        # A zero column stride repeats each element of the base vector across
        # its whole row, giving a non-contiguous 3x4 view.
        base = np.array([5, 6, 7], dtype=np.float32)
        repeated = np.lib.stride_tricks.as_strided(
            base, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}:  [6. 6. 6. 6.]
        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
        invoke_with(repeated)


run(testUnrankedMemRefCallback)
183*9f3f6d7bSStella Laurenzo
# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a statically shaped 2x2 f32 memref to a Python callback."""
    # Descriptor type for a rank-2 memref whose elements are ctypes float32.
    descriptor_type = make_nd_memref_descriptor(
        2, np.ctypeslib.as_ctypes_type(np.float32)
    )

    # The callback converts the ranked memref to a numpy array and logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)
        values = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        descriptor = get_ranked_memref_descriptor(values)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}:  [6. 7.]]
        engine.invoke(
            "callback_memref", ctypes.pointer(ctypes.pointer(descriptor))
        )


run(testRankedMemRefCallback)
221*9f3f6d7bSStella Laurenzo
# Test addition of two memrefs
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a 1-element f32 memref and a 0-d f32 memref into a 1-element result."""
    with Context():
        module = Module.parse(
            """
      module  {
      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
     } """
        )
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        def as_memref_arg(array):
            # Each memref is passed as a pointer to a pointer to a ranked
            # descriptor wrapping the numpy buffer.
            return ctypes.pointer(
                ctypes.pointer(get_ranked_memref_descriptor(array))
            )

        lhs_arg, rhs_arg, out_arg = (as_memref_arg(a) for a in (lhs, rhs, out))

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_arg, rhs_arg, out_arg)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, out))


run(testMemrefAdd)
255*9f3f6d7bSStella Laurenzo
# Test element-wise addition of two 2-D memrefs, one with dynamic shape
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Add a static 2x2 memref and a dynamically shaped ?x? memref."""
    with Context():
        module = Module.parse(
            """
      module  {
        func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = constant 0 : index
          %c2 = constant 2 : index
          %c1 = constant 1 : index
          br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = cmpi slt, %0, %c2 : index
          cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = constant 0 : index
          %c2_1 = constant 2 : index
          %c1_2 = constant 1 : index
          br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = cmpi slt, %2, %c2_1 : index
          cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = addi %2, %c1_2 : index
          br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = addi %0, %c1 : index
          br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
        )
        # Draw the two inputs and the (overwritten) output buffer in order.
        lhs, rhs, out = (
            np.random.randn(2, 2).astype(np.float32) for _ in range(3)
        )

        def as_memref_arg(array):
            # Each memref is passed as a pointer to a pointer to a ranked
            # descriptor wrapping the numpy buffer.
            return ctypes.pointer(
                ctypes.pointer(get_ranked_memref_descriptor(array))
            )

        lhs_arg, rhs_arg, out_arg = (as_memref_arg(a) for a in (lhs, rhs, out))

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_arg, rhs_arg, out_arg)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)
311