1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2# REQUIRES: host-supports-jit
3import gc, sys
4from mlir.ir import *
5from mlir.passmanager import *
6from mlir.execution_engine import *
7from mlir.runtime import *
8
9
10# Log everything to stderr and flush so that we have a unified stream to match
11# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print `args` to stderr and flush immediately.

  Everything goes to stderr so that the script's own output and the
  diagnostics MLIR emits to stderr form one ordered stream to match against.
  """
  print(*args, file=sys.stderr, flush=True)
15
16
def run(f):
  """Execute one test function and verify that no MLIR Context outlives it.

  A "TEST: <name>" banner is logged first so the output of each test can be
  located in the combined stderr stream.

  Args:
    f: Zero-argument callable implementing the test.

  Raises:
    AssertionError: if any Context is still alive after garbage collection.
  """
  test_name = f.__name__
  log("\nTEST:", test_name)
  f()
  # Force a collection so objects the test dropped are finalized before the
  # live-context count is inspected.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
22
23
24# Verify capsule interop.
25# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip an ExecutionEngine through its C API capsule.

  An engine is built for a trivial LLVM-dialect module, its capsule is
  extracted and logged, and a second Python wrapper is recreated from that
  capsule and logged as well.
  """
  with Context():
    module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
    engine = ExecutionEngine(module)
    capsule = engine._CAPIPtr
    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
    log(repr(capsule))
    # NOTE(review): presumably this makes the first wrapper give up ownership
    # so the wrapper recreated below can own the engine -- confirm.
    engine._testing_release()
    engine_from_capsule = ExecutionEngine._CAPICreate(capsule)
    # CHECK: _mlirExecutionEngine.ExecutionEngine
    log(repr(engine_from_capsule))
41
42
43run(testCapsule)
44
45
46# Test invalid ExecutionEngine creation
47# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
  """Check that building an engine from a non-lowered module is reported.

  The module still contains a builtin `func.func` (it is never passed through
  lowerToLLVM), and the RuntimeError raised by the ExecutionEngine constructor
  is caught and logged.
  """
  with Context():
    # Builtin function
    module = Module.parse(r"""
    func.func @foo() { return }
    """)
    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
    try:
      engine = ExecutionEngine(module)
    except RuntimeError as err:
      log("Got RuntimeError: ", err)
59
60
61run(testInvalidModule)
62
63
def lowerToLLVM(module):
  """Run the conversion passes to the LLVM dialect on `module`, in place.

  Args:
    module: The module to lower; the pass manager is run directly on it.

  Returns:
    The same `module` object that was passed in, so the call can be nested
    inside an ExecutionEngine constructor.
  """
  pipeline = "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts"
  pass_manager = PassManager.parse(pipeline)
  pass_manager.run(module)
  return module
69
70
71# Test simple ExecutionEngine execution
72# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
  """Invoke a JIT-compiled function with no arguments and no result."""
  with Context():
    lowered = lowerToLLVM(Module.parse(r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """))
    engine = ExecutionEngine(lowered)
    # Nothing to check other than no exception thrown here.
    engine.invoke("void")
83
84
85run(testInvokeVoid)
86
87
88# Test argument passing and result with a simple float addition.
89# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
  """Add two f32 scalars through the engine and log the result.

  Both inputs and the result slot are one-element ctypes float arrays; the
  result slot is passed as the last argument to `invoke`.
  """
  with Context():
    module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
    engine = ExecutionEngine(lowerToLLVM(module))
    # Arguments must be passed as pointers, so each scalar lives in its own
    # one-element array. The result slot starts at -1 so that a run which
    # never writes it is visible in the logged sum.
    FloatSlot = ctypes.c_float * 1
    lhs = FloatSlot(42.)
    rhs = FloatSlot(2.)
    out = FloatSlot(-1.)
    engine.invoke("add", lhs, rhs, out)
    # CHECK: 42.0 + 2.0 = 44.0
    message = "{0} + {1} = {2}".format(lhs[0], rhs[0], out[0])
    log(message)
108
109
110run(testInvokeFloatAdd)
111
112
113# Test callback
114# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
  """Call back into Python from JIT-compiled code with scalar arguments.

  The compiled `@add` forwards its (f32, i32) arguments to a Python function
  registered under the symbol "some_callback_into_python" and returns that
  function's result.
  """

  def half_sum(a, b):
    return a / 2 + b / 2

  # C signature of the callback: float(float, int).
  callback = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float,
                              ctypes.c_int)(half_sum)

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    # Prepare arguments: one input float, one input int and one float result.
    # Arguments must be passed as pointers.
    FloatSlot = ctypes.c_float * 1
    IntSlot = ctypes.c_int * 1
    lhs = FloatSlot(42.)
    rhs = IntSlot(2)
    out = FloatSlot(-1.)
    engine.invoke("add", lhs, rhs, out)
    # The callback halves both operands, so the result is doubled before it
    # is logged as a sum.
    # CHECK: 42.0 + 2 = 44.0
    log("{0} + {1} = {2}".format(lhs[0], rhs[0], out[0] * 2))
143
144
145run(testBasicCallback)
146
147
148# Test callback with an unranked memref
149# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
  """Hand unranked f32 memrefs from JIT-compiled code to a Python callback.

  The callback turns the descriptor it receives into a numpy array and logs
  it. It is exercised with a plain 2x2 array and with a 3x4 view made by
  `as_strided` that uses a zero stride along its second axis.
  """

  def print_memref(descriptor_ptr):
    as_numpy = unranked_memref_to_numpy(descriptor_ptr, np.float32)
    log("Inside callback: ")
    log(as_numpy)

  # C signature of the callback: void(UnrankedMemRefDescriptor *).
  callback = ctypes.CFUNCTYPE(
      None, ctypes.POINTER(UnrankedMemRefDescriptor))(print_memref)

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)

    dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
    dense_arg = ctypes.pointer(
        ctypes.pointer(get_unranked_memref_descriptor(dense)))
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[1. 2.]
    # CHECK{LITERAL}:  [3. 4.]]
    engine.invoke("callback_memref", dense_arg)

    base = np.array([5, 6, 7], dtype=np.float32)
    # Zero stride on the second axis: every row repeats one element of `base`.
    broadcast_view = np.lib.stride_tricks.as_strided(
        base, strides=(4, 0), shape=(3, 4))
    strided_arg = ctypes.pointer(
        ctypes.pointer(get_unranked_memref_descriptor(broadcast_view)))
    # CHECK: Inside callback:
    # CHECK{LITERAL}: [[5. 5. 5. 5.]
    # CHECK{LITERAL}:  [6. 6. 6. 6.]
    # CHECK{LITERAL}:  [7. 7. 7. 7.]]
    engine.invoke("callback_memref", strided_arg)
189
190
191run(testUnrankedMemRefCallback)
192
193
194# Test callback with a ranked memref.
195# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
  """Hand a ranked 2x2 f32 memref from JIT-compiled code to a Python callback.

  The callback converts the rank-2 descriptor it receives into a numpy array
  and logs it.
  """

  def print_memref(descriptor_ptr):
    as_numpy = ranked_memref_to_numpy(descriptor_ptr)
    log("Inside Callback: ")
    log(as_numpy)

  # C signature of the callback: void(<rank-2 f32 memref descriptor> *).
  element_ctype = np.ctypeslib.as_ctypes_type(np.float32)
  descriptor_type = make_nd_memref_descriptor(2, element_ctype)
  callback = ctypes.CFUNCTYPE(None,
                              ctypes.POINTER(descriptor_type))(print_memref)

  with Context():
    # The module just forwards to a runtime function known as "some_callback_into_python".
    module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
    engine = ExecutionEngine(lowerToLLVM(module))
    engine.register_runtime("some_callback_into_python", callback)
    matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
    matrix_arg = ctypes.pointer(
        ctypes.pointer(get_ranked_memref_descriptor(matrix)))
    # CHECK: Inside Callback:
    # CHECK{LITERAL}: [[1. 5.]
    # CHECK{LITERAL}:  [6. 7.]]
    engine.invoke("callback_memref", matrix_arg)
227
228
229run(testRankedMemRefCallback)
230
231
232#  Test addition of two memrefs.
233# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
  """Add a one-element f32 memref and a rank-0 f32 memref via the engine.

  The compiled function stores the sum into a third, one-element memref whose
  backing numpy array is logged afterwards.
  """

  def as_memref_arg(array):
    # Wrap the array's ranked descriptor in a pointer-to-pointer, the form in
    # which it is handed to `invoke`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
    vector_in = np.array([32.5]).astype(np.float32)
    scalar_in = np.array(6).astype(np.float32)
    out = np.array([0]).astype(np.float32)

    vector_arg = as_memref_arg(vector_in)
    scalar_arg = as_memref_arg(scalar_in)
    out_arg = as_memref_arg(out)

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", vector_arg, scalar_arg, out_arg)
    # CHECK: [32.5] + 6.0 = [38.5]
    log("{0} + {1} = {2}".format(vector_in, scalar_in, out))
263
264
265run(testMemrefAdd)
266
267
268# Test addition of two f16 memrefs
269# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
  """Add two one-element f16 memrefs via the engine.

  The sum is logged twice: once through the numpy array backing the output
  memref, and once through `ranked_memref_to_numpy` applied to the output
  descriptor.
  """

  def as_memref_arg(array):
    # Wrap the array's ranked descriptor in a pointer-to-pointer, the form in
    # which it is handed to `invoke`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

    lhs = np.array([11.]).astype(np.float16)
    rhs = np.array([22.]).astype(np.float16)
    out = np.array([0.]).astype(np.float16)

    lhs_arg = as_memref_arg(lhs)
    rhs_arg = as_memref_arg(rhs)
    out_arg = as_memref_arg(out)

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_arg, rhs_arg, out_arg)
    # CHECK: [11.] + [22.] = [33.]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility
    # CHECK: [33.]
    round_tripped = ranked_memref_to_numpy(out_arg[0])
    log(round_tripped)
307
308
309run(testF16MemrefAdd)
310
311
312# Test addition of two complex memrefs
313# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
  """Add two one-element complex<f64> memrefs via the engine.

  The sum is logged twice: once through the numpy array backing the output
  memref, and once through `ranked_memref_to_numpy` applied to the output
  descriptor.
  """

  def as_memref_arg(array):
    # Wrap the array's ranked descriptor in a pointer-to-pointer, the form in
    # which it is handed to `invoke`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

    lhs = np.array([1.+2.j]).astype(np.complex128)
    rhs = np.array([3.+4.j]).astype(np.complex128)
    out = np.array([0.+0.j]).astype(np.complex128)

    lhs_arg = as_memref_arg(lhs)
    rhs_arg = as_memref_arg(rhs)
    out_arg = as_memref_arg(out)

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_arg, rhs_arg, out_arg)
    # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility
    # CHECK: [4.+6.j]
    round_tripped = ranked_memref_to_numpy(out_arg[0])
    log(round_tripped)
353
354
355run(testComplexMemrefAdd)
356
357
358# Test addition of two complex unranked memrefs
359# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
  """Add two complex<f32> values passed as unranked memrefs.

  The compiled function casts each unranked argument to a one-element ranked
  memref before loading/storing. The sum is logged through the numpy array
  backing the output and again through `unranked_memref_to_numpy`.
  """

  def as_memref_arg(array):
    # Wrap the array's unranked descriptor in a pointer-to-pointer, the form
    # in which it is handed to `invoke`.
    return ctypes.pointer(
        ctypes.pointer(get_unranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

    lhs = np.array([5.+6.j]).astype(np.complex64)
    rhs = np.array([7.+8.j]).astype(np.complex64)
    out = np.array([0.+0.j]).astype(np.complex64)

    lhs_arg = as_memref_arg(lhs)
    rhs_arg = as_memref_arg(rhs)
    out_arg = as_memref_arg(out)

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("main", lhs_arg, rhs_arg, out_arg)
    # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
    log("{0} + {1} = {2}".format(lhs, rhs, out))

    # test to-numpy utility
    # CHECK: [12.+14.j]
    element_dtype = np.dtype(np.complex64)
    round_tripped = unranked_memref_to_numpy(out_arg[0], element_dtype)
    log(round_tripped)
403
404
405run(testComplexUnrankedMemrefAdd)
406
407
# Test element-wise addition of two 2D memrefs (one with dynamic dimensions).
409# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
  """Element-wise add two 2x2 f32 memrefs with an explicit CFG loop nest.

  The second operand is typed `memref<?x?xf32>` in the compiled function while
  the other two are static 2x2. The result is compared against numpy's own
  sum and the boolean outcome is logged.
  """

  def as_memref_arg(array):
    # Wrap the array's ranked descriptor in a pointer-to-pointer, the form in
    # which it is handed to `invoke`.
    return ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))

  with Context():
    module = Module.parse("""
      module  {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2:  // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4:  // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5:  // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6:  // pred: ^bb1
          return
        }
      }
        """)
    # The output buffer starts with random contents too, so a match below
    # means every element was overwritten by the compiled function.
    lhs = np.random.randn(2, 2).astype(np.float32)
    rhs = np.random.randn(2, 2).astype(np.float32)
    out = np.random.randn(2, 2).astype(np.float32)

    lhs_arg = as_memref_arg(lhs)
    rhs_arg = as_memref_arg(rhs)
    out_arg = as_memref_arg(out)

    engine = ExecutionEngine(lowerToLLVM(module))
    engine.invoke("memref_add_2d", lhs_arg, rhs_arg, out_arg)
    # CHECK: True
    expected = lhs + rhs
    log(np.allclose(expected, out))
461
462
463run(testDynamicMemrefAdd2D)
464
465
466#  Test loading of shared libraries.
467# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
  """Load the runner-utils shared libraries into the engine and call into them.

  The compiled function stores 42.0 into its argument and prints it through
  `printMemrefF32`, which is only declared in the module and so has to come
  from the libraries passed as `shared_libs`.
  """
  with Context():
    module = Module.parse("""
      module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """)
    buffer = np.array([0.0]).astype(np.float32)

    buffer_arg = ctypes.pointer(
        ctypes.pointer(get_ranked_memref_descriptor(buffer)))

    # Library file names and directories differ per platform; any platform
    # not listed here falls back to the ELF-style `.so` names.
    libs_by_platform = {
        'win32': [
            "../../../../bin/mlir_runner_utils.dll",
            "../../../../bin/mlir_c_runner_utils.dll"
        ],
        'darwin': [
            "../../../../lib/libmlir_runner_utils.dylib",
            "../../../../lib/libmlir_c_runner_utils.dylib"
        ],
    }
    fallback_libs = [
        "../../../../lib/libmlir_runner_utils.so",
        "../../../../lib/libmlir_c_runner_utils.so"
    ]
    shared_libs = libs_by_platform.get(sys.platform, fallback_libs)

    engine = ExecutionEngine(
        lowerToLLVM(module), opt_level=3, shared_libs=shared_libs)
    engine.invoke("main", buffer_arg)
    # CHECK: Unranked Memref
    # CHECK-NEXT: [42]
510
511
512run(testSharedLibLoad)
513
514
515#  Test that nano time clock is available.
516# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
  """Check that the `nanoTime` runtime function can be called from JIT code.

  The compiled function calls `nanoTime`, stores the returned i64 into a
  one-element memref and prints it through `printMemrefI64`. Both functions
  are only declared in the module, so they have to come from the runner-utils
  libraries passed as `shared_libs`. The printed value is not predictable, so
  only the shape of the output is matched.
  """
  with Context():
    module = Module.parse("""
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

    # Pick the runner libraries for the host platform. The macOS branch
    # mirrors testSharedLibLoad: without it this test would request `.so`
    # files on darwin, while the sibling test loads `.dylib` ones there.
    if sys.platform == 'win32':
      shared_libs = [
          "../../../../bin/mlir_runner_utils.dll",
          "../../../../bin/mlir_c_runner_utils.dll"
      ]
    elif sys.platform == 'darwin':
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.dylib",
          "../../../../lib/libmlir_c_runner_utils.dylib"
      ]
    else:
      shared_libs = [
          "../../../../lib/libmlir_runner_utils.so",
          "../../../../lib/libmlir_c_runner_utils.so"
      ]

    execution_engine = ExecutionEngine(
        lowerToLLVM(module),
        opt_level=3,
        shared_libs=shared_libs)
    execution_engine.invoke("main")
    # CHECK: Unranked Memref
    # CHECK: [{{.*}}]
552
553
554run(testNanoTime)
555