1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2# REQUIRES: native
3import gc, sys
4from mlir.ir import *
5from mlir.passmanager import *
6from mlir.execution_engine import *
7from mlir.runtime import *
8
9
10# Log everything to stderr and flush so that we have a unified stream to match
11# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print `args` to stderr and flush the stream right away.

  The arguments are forwarded to `print` unchanged, so they are separated by
  single spaces and followed by a newline.
  """
  print(*args, file=sys.stderr, flush=True)
15
16
def run(f):
  """Execute test function `f` and assert that no Context is left alive.

  A banner with the function's name is logged first; the FileCheck label
  directives in this file match on that banner.

  Args:
    f: A zero-argument callable; its `__name__` is used in the banner.

  Raises:
    AssertionError: If `Context._get_live_count()` is non-zero after `f` has
      returned and a garbage collection has been forced.
  """
  test_name = f.__name__
  log("\nTEST:", test_name)
  f()
  # Collect first so that objects which are merely awaiting collection are not
  # counted as live.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
22
23
24# Verify capsule interop.
25# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip an ExecutionEngine through its `_CAPIPtr` capsule."""
  with Context():
    llvm_module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
    engine = ExecutionEngine(llvm_module)
    capsule = engine._CAPIPtr
    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
    log(repr(capsule))
    # NOTE(review): presumably this makes the first wrapper give up ownership
    # so the engine is only owned by the wrapper created below; confirm.
    engine._testing_release()
    restored_engine = ExecutionEngine._CAPICreate(capsule)
    # CHECK: _mlirExecutionEngine.ExecutionEngine
    log(repr(restored_engine))


run(testCapsule)
44
45
46# Test invalid ExecutionEngine creation
47# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
  """Log the RuntimeError raised when an engine is built from a module that
  was not lowered with `lowerToLLVM`."""
  with Context():
    # Builtin function
    unlowered_module = Module.parse(r"""
    func.func @foo() { return }
    """)
    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
    try:
      execution_engine = ExecutionEngine(unlowered_module)
    except RuntimeError as err:
      # The error is passed as a second argument, so `print` puts a second
      # space between the prefix and the message.
      log("Got RuntimeError: ", err)


run(testInvalidModule)
62
63
def lowerToLLVM(module):
  """Run the LLVM-lowering pass pipeline over `module` and return it.

  Args:
    module: The module to run the pipeline on.

  Returns:
    The same `module` object that was passed in.
  """
  # NOTE(review): imported only for its side effects -- presumably it makes the
  # conversion passes named below available to the pass manager; confirm.
  import mlir.conversions
  pipeline = "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts"
  PassManager.parse(pipeline).run(module)
  return module
70
71
72# Test simple ExecutionEngine execution
73# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
  """Invoke a function that takes no arguments and returns nothing."""
  void_func_asm = r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """
  with Context():
    module = Module.parse(void_func_asm)
    engine = ExecutionEngine(lowerToLLVM(module))
    # Nothing to check other than no exception thrown here.
    engine.invoke("void")


run(testInvokeVoid)
87
88
89# Test argument passing and result with a simple float addition.
90# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
  """Pass two floats to a compiled `add` function and log the sum."""
  with Context():
    add_module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
    engine = ExecutionEngine(lowerToLLVM(add_module))
    # Arguments must be passed as pointers, so each value (and the slot that
    # receives the result) is stored in a one-element ctypes float array.
    float_cell = ctypes.c_float * 1
    lhs, rhs, out = float_cell(42.), float_cell(2.), float_cell(-1.)
    engine.invoke("add", lhs, rhs, out)
    # CHECK: 42.0 + 2.0 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0]))


run(testInvokeFloatAdd)
112
113
114# Test callback
115# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
  """Call back into Python from a compiled function with scalar arguments."""

  # The callback takes a float and an integer and returns a float.
  callback_type = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float,
                                   ctypes.c_int)

  def half_sum(a, b):
    return a / 2 + b / 2

  callback = callback_type(half_sum)

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    # Prepare arguments: one input float, one input integer and one float
    # result. Arguments must be passed as pointers.
    float_cell = ctypes.c_float * 1
    int_cell = ctypes.c_int * 1
    lhs = float_cell(42.)
    rhs = int_cell(2)
    out = float_cell(-1.)
    engine.invoke("add", lhs, rhs, out)
    # The callback returns half of the sum, so the result is doubled before it
    # is printed.
    # CHECK: 42.0 + 2 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0] * 2))


run(testBasicCallback)
147
148
149# Test callback with an unranked memref
150# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
  """Hand unranked memrefs from compiled code to a Python callback."""

  # The callback takes a pointer to an unranked memref descriptor, converts it
  # to a float32 numpy array and logs it.
  def print_memref(a):
    arr = unranked_memref_to_numpy(a, np.float32)
    log("Inside callback: ")
    log(arr)

  callback_type = ctypes.CFUNCTYPE(None,
                                   ctypes.POINTER(UnrankedMemRefDescriptor))
  callback = callback_type(print_memref)

  def as_invoke_arg(array):
    # Pointer to a pointer to the unranked descriptor built for `array`.
    descriptor = get_unranked_memref_descriptor(array)
    return ctypes.pointer(ctypes.pointer(descriptor))

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    # First input: a plain 2x2 array.
    dense_input = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[1. 2.]
    # CHECK{LITERAL}:  [3. 4.]]
    engine.invoke("callback_memref", as_invoke_arg(dense_input))

    # Second input: a 3x4 view of a 3-element array whose second stride is 0,
    # so every row repeats a single source element.
    base = np.array([5, 6, 7], dtype=np.float32)
    strided_input = np.lib.stride_tricks.as_strided(
        base, shape=(3, 4), strides=(4, 0))
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[5. 5. 5. 5.]
    # CHECK{LITERAL}:  [6. 6. 6. 6.]
    # CHECK{LITERAL}:  [7. 7. 7. 7.]]
    engine.invoke("callback_memref", as_invoke_arg(strided_input))


run(testUnrankedMemRefCallback)
193
194
195# Test callback with a ranked memref.
196# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
  """Hand a ranked 2x2 memref from compiled code to a Python callback."""

  # The callback takes a pointer to a rank-2 float32 memref descriptor,
  # converts it to a numpy array and logs it.
  descriptor_type = make_nd_memref_descriptor(
      2, np.ctypeslib.as_ctypes_type(np.float32))
  callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))

  def print_memref(a):
    arr = ranked_memref_to_numpy(a)
    log("Inside Callback: ")
    log(arr)

  callback = callback_type(print_memref)

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)
    matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
    # CHECK: Inside Callback:
    # CHECK{LITERAL}: [[1. 5.]
    # CHECK{LITERAL}:  [6. 7.]]
    engine.invoke(
        "callback_memref",
        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(matrix))))


run(testRankedMemRefCallback)
231
232
233#  Test addition of two memrefs.
234# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
  """Add a one-element memref and a zero-rank memref into a result memref."""

  def as_invoke_arg(array):
    # Pointer to a pointer to the ranked descriptor built for `array`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
    # The second operand is a zero-dimensional array, matching `memref<f32>`.
    arg1 = np.array([32.5]).astype(np.float32)
    arg2 = np.array(6).astype(np.float32)
    res = np.array([0]).astype(np.float32)

    invoke_args = [as_invoke_arg(array) for array in (arg1, arg2, res)]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", *invoke_args)
    # CHECK: [32.5] + 6.0 = [38.5]
    log("{0} + {1} = {2}".format(arg1, arg2, res))


run(testMemrefAdd)
267
268
269# Test addition of two f16 memrefs
270# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
  """Add two one-element f16 memrefs and read the result back two ways."""

  def as_invoke_arg(array):
    # Pointer to a pointer to the ranked descriptor built for `array`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

    lhs = np.array([11.]).astype(np.float16)
    rhs = np.array([22.]).astype(np.float16)
    out = np.array([0.]).astype(np.float16)

    invoke_args = [as_invoke_arg(array) for array in (lhs, rhs, out)]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", *invoke_args)
    # CHECK: [11.] + [22.] = [33.]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility: convert the descriptor passed as the third
    # argument back into a numpy array.
    # CHECK: [33.]
    out_ptr = invoke_args[2]
    npout = ranked_memref_to_numpy(out_ptr[0])
    log(npout)


run(testF16MemrefAdd)
311
312
313# Test addition of two complex memrefs
314# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
  """Add two one-element complex<f64> memrefs and read the result back."""

  def as_invoke_arg(array):
    # Pointer to a pointer to the ranked descriptor built for `array`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

    lhs = np.array([1.+2.j]).astype(np.complex128)
    rhs = np.array([3.+4.j]).astype(np.complex128)
    out = np.array([0.+0.j]).astype(np.complex128)

    invoke_args = [as_invoke_arg(array) for array in (lhs, rhs, out)]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", *invoke_args)
    # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility: convert the descriptor passed as the third
    # argument back into a numpy array.
    # CHECK: [4.+6.j]
    out_ptr = invoke_args[2]
    npout = ranked_memref_to_numpy(out_ptr[0])
    log(npout)


run(testComplexMemrefAdd)
357
358
359# Test addition of two complex unranked memrefs
360# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
  """Add two complex<f32> values passed as unranked memrefs."""

  def as_invoke_arg(array):
    # Pointer to a pointer to the unranked descriptor built for `array`.
    descriptor = get_unranked_memref_descriptor(array)
    return ctypes.pointer(ctypes.pointer(descriptor))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

    lhs = np.array([5.+6.j]).astype(np.complex64)
    rhs = np.array([7.+8.j]).astype(np.complex64)
    out = np.array([0.+0.j]).astype(np.complex64)

    invoke_args = [as_invoke_arg(array) for array in (lhs, rhs, out)]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", *invoke_args)
    # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility: the element type is passed explicitly when
    # converting the unranked descriptor back into a numpy array.
    # CHECK: [12.+14.j]
    out_ptr = invoke_args[2]
    npout = unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64))
    log(npout)


run(testComplexUnrankedMemrefAdd)
407
408
# Test addition of two 2D memrefs.
410# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
  """Add two random 2x2 float32 arrays element-wise in compiled code.

  The second operand is declared `memref<?x?xf32>`; the other operand and the
  result are `memref<2x2xf32>`. The result is compared against numpy's sum.
  """

  def as_invoke_arg(array):
    # Pointer to a pointer to the ranked descriptor built for `array`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
      module  {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """)
    # Three random draws in this order: left operand, right operand, and the
    # initial contents of the result buffer.
    lhs = np.random.randn(2, 2).astype(np.float32)
    rhs = np.random.randn(2, 2).astype(np.float32)
    out = np.random.randn(2, 2).astype(np.float32)

    invoke_args = [as_invoke_arg(array) for array in (lhs, rhs, out)]

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("memref_add_2d", *invoke_args)
    # CHECK: True
    matches_numpy = np.allclose(lhs + rhs, out)
    log(matches_numpy)


run(testDynamicMemrefAdd2D)
465
466
467#  Test loading of shared libraries.
468# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
  """Load the runner utility shared libraries and call into one from JIT code.

  The compiled `main` stores 42.0 into its argument and prints it through
  `printMemrefF32`, which is declared in the module without a body and so has
  to come from the libraries passed as `shared_libs`.
  """
  with Context():
    module = Module.parse("""
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """)
    arg0 = np.array([0.0]).astype(np.float32)

    arg0_memref_ptr = ctypes.pointer(
        ctypes.pointer(get_ranked_memref_descriptor(arg0)))

    # Directory, prefix and extension of the libraries differ per platform:
    # DLLs under bin/ on Windows, .dylib on macOS, .so everywhere else.
    if sys.platform == 'win32':
      shared_libs = [
          "../../../../bin/mlir_runner_utils.dll",
          "../../../../bin/mlir_c_runner_utils.dll"
      ]
    elif sys.platform == 'darwin':
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.dylib",
          "../../../../lib/libmlir_c_runner_utils.dylib"
      ]
    else:
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.so",
          "../../../../lib/libmlir_c_runner_utils.so"
      ]

    execution_engine = ExecutionEngine(
        lowerToLLVM(module),
        opt_level=3,
        shared_libs=shared_libs)
    execution_engine.invoke("main", arg0_memref_ptr)
    # CHECK: Unranked Memref
    # CHECK-NEXT: [42]


run(testSharedLibLoad)
509
510
511#  Test that nano time clock is available.
512# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
  """Call `nanoTime` from JIT code and print the value it returns.

  Both `nanoTime` and `printMemrefI64` are declared in the module without a
  body and so have to come from the libraries passed as `shared_libs`. The
  printed value is not fixed, so the FileCheck pattern only matches its
  surrounding brackets.
  """
  with Context():
    module = Module.parse("""
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

    # Directory, prefix and extension of the libraries differ per platform:
    # DLLs under bin/ on Windows, .dylib on macOS, .so everywhere else.
    if sys.platform == 'win32':
      shared_libs = [
          "../../../../bin/mlir_runner_utils.dll",
          "../../../../bin/mlir_c_runner_utils.dll"
      ]
    elif sys.platform == 'darwin':
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.dylib",
          "../../../../lib/libmlir_c_runner_utils.dylib"
      ]
    else:
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.so",
          "../../../../lib/libmlir_c_runner_utils.so"
      ]

    execution_engine = ExecutionEngine(
        lowerToLLVM(module),
        opt_level=3,
        shared_libs=shared_libs)
    execution_engine.invoke("main")
    # CHECK: Unranked Memref
    # CHECK: [{{.*}}]


run(testNanoTime)
551