1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
import ctypes
import gc, sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
8
9# Log everything to stderr and flush so that we have a unified stream to match
10# errors/info emitted by MLIR to stderr.
def log(*args):
  """Write ``args`` to stderr, space-separated, and flush immediately.

  Everything the tests report goes to stderr so it lands in the same stream
  as the errors/info MLIR emits there; the RUN line merges it with ``2>&1``.
  """
  print(*args, file=sys.stderr, flush=True)
14
def run(f):
  """Run test function ``f`` under a ``TEST: <name>`` label and check for leaks.

  Logs the label (the text each test's label directive anchors on), calls
  ``f``, forces a garbage collection, and then asserts that
  ``Context._get_live_count()`` is zero, i.e. no Context created by the test
  is still alive.

  Returns ``f`` unchanged so this can also be applied as a decorator.
  """
  log("\nTEST:", f.__name__)
  f()
  # Collect first so objects only kept alive by reference cycles are freed
  # before counting.
  gc.collect()
  live_contexts = Context._get_live_count()
  # Include the test name and count so a leak is diagnosable from the log.
  assert live_contexts == 0, (
      "{0} leaked {1} live Context(s)".format(f.__name__, live_contexts))
  return f
20
21# Verify capsule interop.
22# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip an ExecutionEngine through its C API capsule.

  Builds an engine from a trivial LLVM-dialect module, logs its ``_CAPIPtr``
  capsule, calls ``_testing_release`` on the original wrapper, and then
  rebuilds an engine object from the same capsule with ``_CAPICreate``.
  """
  llvm_ir = r"""
llvm.func @none() {
  llvm.return
}
    """
  with Context():
    ir_module = Module.parse(llvm_ir)
    engine = ExecutionEngine(ir_module)
    capsule = engine._CAPIPtr
    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
    log(repr(capsule))
    # NOTE(review): presumably detaches the original wrapper from the
    # underlying engine so the capsule can be re-wrapped below; confirm.
    engine._testing_release()
    rewrapped = ExecutionEngine._CAPICreate(capsule)
    # CHECK: _mlirExecutionEngine.ExecutionEngine
    log(repr(rewrapped))

run(testCapsule)
40
41# Test invalid ExecutionEngine creation
42# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
  """Creating an ExecutionEngine from an unlowered module must raise.

  The module still contains a builtin ``func`` (not ``llvm.func``), and engine
  creation is expected to fail with a RuntimeError, which is logged.
  """
  with Context():
    # Builtin function
    module = Module.parse(r"""
    func @foo() { return }
    """)
    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
    try:
      ExecutionEngine(module)
    except RuntimeError as e:
      log("Got RuntimeError: ", e)
    else:
      # Make an unexpected success explicit in the output instead of leaving
      # the test to fail only through a missing match.
      log("Unexpectedly created an ExecutionEngine for an unlowered module")

run(testInvalidModule)
56
def lowerToLLVM(module):
  """Run the memref/std-to-LLVM conversion passes on ``module`` in place.

  Returns the same ``module`` object so the call can be nested directly
  inside ``ExecutionEngine(...)``.
  """
  # Imported for its side effects only; presumably this registers the
  # conversion passes named below so the pipeline string parses — confirm.
  import mlir.conversions
  pipeline = "convert-memref-to-llvm,convert-std-to-llvm"
  PassManager.parse(pipeline).run(module)
  return module
62
63# Test simple ExecutionEngine execution
64# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
  """Invoke a function that takes no arguments and returns nothing.

  There is no result to inspect: the test passes as long as ``invoke`` does
  not raise.
  """
  source = r"""
func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
  with Context():
    mod = Module.parse(source)
    engine = ExecutionEngine(lowerToLLVM(mod))
    engine.invoke("void")

run(testInvokeVoid)
77
78
79# Test argument passing and result with a simple float addition.
80# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
  """Pass two f32 values into a compiled function and read back their sum."""
  source = r"""
func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = std.addf %arg0, %arg1 : f32
  return %add : f32
}
    """
  with Context():
    mod = Module.parse(source)
    engine = ExecutionEngine(lowerToLLVM(mod))
    # Both inputs and the result slot go to `invoke` by pointer; a
    # one-element ctypes array is used as each pointer.
    FloatCell = ctypes.c_float * 1
    lhs = FloatCell(42.)
    rhs = FloatCell(2.)
    out = FloatCell(-1.)
    engine.invoke("add", lhs, rhs, out)
    # CHECK: 42.0 + 2.0 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))

run(testInvokeFloatAdd)
101
102
# Test calling back into Python from JIT-compiled code.
104# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
  """Call from compiled code back into a Python function.

  ``@add`` forwards its (f32, i32) arguments to the external symbol
  ``some_callback_into_python``, which is registered as a ctypes callback.
  The callback returns half of the sum, so the logged result is doubled to
  show the full sum.
  """
  # ctypes signature: (float, int) -> float.
  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
  def half_sum(a, b):
    return a/2 + b/2

  source = r"""
func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """
  with Context():
    mod = Module.parse(source)
    engine = ExecutionEngine(lowerToLLVM(mod))
    engine.register_runtime("some_callback_into_python", half_sum)

    # Arguments: one float input, one int input, and a float result slot,
    # each passed by pointer via a one-element ctypes array.
    FloatCell = ctypes.c_float * 1
    IntCell = ctypes.c_int * 1
    x = FloatCell(42.)
    n = IntCell(2)
    out = FloatCell(-1.)
    engine.invoke("add", x, n, out)
    # CHECK: 42.0 + 2 = 44.0
    log("{0} + {1} = {2}".format(x[0], n[0], out[0]*2))

run(testBasicCallback)
135
136# Test callback with an unranked memref
137# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked f32 memrefs from compiled code back into Python.

    ``@callback_memref`` forwards its ``memref<*xf32>`` argument to the
    external ``some_callback_into_python``. That symbol is registered as a
    ctypes callback that converts the descriptor to a NumPy array and logs it.
    The test uses a contiguous 2x2 array and a 3x4 strided view whose column
    stride is zero.
    """
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def print_unranked(descriptor):
        as_array = unranked_memref_to_numpy(descriptor, np.float32)
        log("Inside callback: ")
        log(as_array)

    def invoke_with(engine, array):
        # The descriptor is handed to `invoke` behind two levels of pointer.
        descriptor = get_unranked_memref_descriptor(array)
        engine.invoke(
            "callback_memref", ctypes.pointer(ctypes.pointer(descriptor))
        )

    source = r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        mod = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", print_unranked)

        contiguous = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}:  [3. 4.]]
        invoke_with(engine, contiguous)

        # Row stride of 4 bytes (one float32) and column stride of 0, so each
        # row repeats a single element of `column` four times.
        column = np.array([5, 6, 7], dtype=np.float32)
        repeated = np.lib.stride_tricks.as_strided(
            column, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}:  [6. 6. 6. 6.]
        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
        invoke_with(engine, repeated)

run(testUnrankedMemRefCallback)
183
184# Test callback with a ranked memref.
185# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a ranked 2x2 f32 memref from compiled code back into Python.

    The registered ctypes callback receives a pointer to a rank-2 f32
    descriptor, converts it to a NumPy array, and logs it.
    """
    descriptor_type = make_nd_memref_descriptor(
        2, np.ctypeslib.as_ctypes_type(np.float32)
    )

    @ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))
    def print_ranked(descriptor):
        as_array = ranked_memref_to_numpy(descriptor)
        log("Inside Callback: ")
        log(as_array)

    source = r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
    with Context():
        mod = Module.parse(source)
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", print_ranked)
        values = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}:  [6. 7.]]
        descriptor_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(values))
        )
        engine.invoke("callback_memref", descriptor_ptr)

run(testRankedMemRefCallback)
221
# Test addition of two memrefs.
223# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add element 0 of a rank-1 memref to a rank-0 memref.

    The sum is stored into element 0 of a third, rank-1 memref and the three
    NumPy arrays are logged.
    """
    ir = """
      module  {
      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
     } """
    with Context():
        mod = Module.parse(ir)
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        # Each ranked descriptor is passed behind two levels of pointer.
        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
            for a in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

run(testMemrefAdd)
255
# Test addition of two 2-D memrefs, one of which has dynamic dimensions.
257# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Element-wise add two 2x2 f32 buffers, one typed with dynamic extents.

    ``%arg1`` is declared ``memref<?x?xf32>`` while the other operands are
    ``memref<2x2xf32>``, and the loop nest is written as an explicit CFG. The
    inputs are random, so the result is compared against NumPy's own sum.
    """
    ir = """
      module  {
        func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = constant 0 : index
          %c2 = constant 2 : index
          %c1 = constant 1 : index
          br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = cmpi slt, %0, %c2 : index
          cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = constant 0 : index
          %c2_1 = constant 2 : index
          %c1_2 = constant 1 : index
          br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = cmpi slt, %2, %c2_1 : index
          cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = addi %2, %c1_2 : index
          br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = addi %0, %c1 : index
          br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """
    with Context():
        mod = Module.parse(ir)
        # Drawn in the order: first input, second input, output buffer.
        lhs, rhs, out = (
            np.random.randn(2, 2).astype(np.float32) for _ in range(3)
        )

        lhs_ptr, rhs_ptr, out_ptr = (
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
            for a in (lhs, rhs, out)
        )

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))

run(testDynamicMemrefAdd2D)
311
# Test loading of shared libraries.
313# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Resolve an external function from shared libraries given to the engine.

    ``@main`` stores 42.0 into its argument and calls ``print_memref_f32``.
    That function is only declared in the module and must be provided by the
    runner-utils libraries passed through ``shared_libs``.
    """
    with Context():
        module = Module.parse(
            """
      module  {
      func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = constant 0 : index
        %cst42 = constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @print_memref_f32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """
        )
        arg0 = np.array([0.0]).astype(np.float32)

        arg0_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg0)))

        # Paths are relative to the current working directory. Shared-library
        # file names differ per platform, so a hard-coded ".so" only works on
        # Linux-like systems.
        if sys.platform == "win32":
            # NOTE(review): assumes Windows builds place these DLLs in bin/
            # without a "lib" prefix (usual CMake/MSVC layout) — confirm.
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
        elif sys.platform == "darwin":
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib",
            ]
        else:
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.so",
                "../../../../lib/libmlir_c_runner_utils.so",
            ]

        execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3,
                shared_libs=shared_libs)
        execution_engine.invoke("main", arg0_memref_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]

run(testSharedLibLoad)
342