# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: native
# ctypes and numpy are used directly by the tests below, so they are imported
# explicitly instead of relying on names that leak through the star imports.
import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print `args` to stderr and flush immediately."""
    print(*args, file=sys.stderr)
    sys.stderr.flush()


def run(f):
    """Run test function `f` and verify that it leaked no MLIR context.

    Logs the test name first (the label the expectations key on), calls `f`,
    forces a garbage collection and asserts that no `Context` is still live.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    with Context():
        module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
        execution_engine = ExecutionEngine(module)
        execution_engine_capsule = execution_engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(execution_engine_capsule))
        # Release the original wrapper before re-creating one from the capsule.
        execution_engine._testing_release()
        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(execution_engine1))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    with Context():
        # Builtin function
        module = Module.parse(r"""
    func.func @foo() { return }
    """)
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            # The module is passed without being lowered first; construction
            # is expected to raise, so the result is deliberately not bound.
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)


run(testInvalidModule)


def lowerToLLVM(module):
    """Run the conversion pipeline below on `module` and return it.

    The passes are run on the module that was passed in; the same object is
    returned so the call can be nested inside `ExecutionEngine(...)`.
    """
    # NOTE(review): imported for its side effect (presumably registering the
    # conversion passes named in the pipeline) - confirm against the bindings.
    import mlir.conversions
    pm = PassManager.parse(
        "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
    pm.run(module)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    with Context():
        module = Module.parse(r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        execution_engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    with Context():
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        arg0 = c_float_p(42.)
        arg1 = c_float_p(2.)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))


run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    # Define a callback function that takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input int and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        c_int_p = ctypes.c_int * 1
        arg0 = c_float_p(42.)
        arg1 = c_int_p(2)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback halves both operands, so the result is doubled here to
        # print the plain sum.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        # A zero stride on the second axis repeats each element across a row,
        # which exercises a non-contiguous layout.
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
        )


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2,
                                      np.ctypeslib.as_ctypes_type(np.float32))),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))))


run(testRankedMemRefCallback)


# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
        arg1 = np.array([32.5]).astype(np.float32)
        # Zero-dimensional array, matching the rank-0 memref<f32> argument.
        arg2 = np.array(6).astype(np.float32)
        res = np.array([0]).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
                                res_memref_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))


run(testMemrefAdd)


def _ranked_arg(array):
    """Wrap `array` as a pointer to a pointer to its ranked memref descriptor."""
    descriptor = get_ranked_memref_descriptor(array)
    return ctypes.pointer(ctypes.pointer(descriptor))


def _unranked_arg(array):
    """Wrap `array` as a pointer to a pointer to its unranked memref descriptor."""
    descriptor = get_unranked_memref_descriptor(array)
    return ctypes.pointer(ctypes.pointer(descriptor))


# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

        lhs = np.array([11.]).astype(np.float16)
        rhs = np.array([22.]).astype(np.float16)
        out = np.array([0.]).astype(np.float16)

        lhs_ptr = _ranked_arg(lhs)
        rhs_ptr = _ranked_arg(rhs)
        out_ptr = _ranked_arg(out)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [33.]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)


run(testF16MemrefAdd)


# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

        lhs = np.array([1.+2.j]).astype(np.complex128)
        rhs = np.array([3.+4.j]).astype(np.complex128)
        out = np.array([0.+0.j]).astype(np.complex128)

        lhs_ptr = _ranked_arg(lhs)
        rhs_ptr = _ranked_arg(rhs)
        out_ptr = _ranked_arg(out)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [4.+6.j]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)


run(testComplexMemrefAdd)


# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

        lhs = np.array([5.+6.j]).astype(np.complex64)
        rhs = np.array([7.+8.j]).astype(np.complex64)
        out = np.array([0.+0.j]).astype(np.complex64)

        lhs_ptr = _unranked_arg(lhs)
        rhs_ptr = _unranked_arg(rhs)
        out_ptr = _unranked_arg(out)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [12.+14.j]
        element_type = np.dtype(np.complex64)
        round_tripped = unranked_memref_to_numpy(out_ptr[0], element_type)
        log(round_tripped)


run(testComplexUnrankedMemrefAdd)


# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    with Context():
        module = Module.parse("""
      module {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2: // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index): // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4: // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5: // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6: // pred: ^bb1
          return
        }
      }
        """)
        # The three draws stay in this order so the random stream is consumed
        # exactly as before.
        lhs = np.random.randn(2, 2).astype(np.float32)
        rhs = np.random.randn(2, 2).astype(np.float32)
        out = np.random.randn(2, 2).astype(np.float32)

        lhs_ptr = _ranked_arg(lhs)
        rhs_ptr = _ranked_arg(rhs)
        out_ptr = _ranked_arg(out)

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        expected = lhs + rhs
        log(np.allclose(expected, out))


run(testDynamicMemrefAdd2D)


# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Invoke a function that calls into an externally loaded runner library.

    The module stores 42.0 into its argument and prints it through
    `printMemrefF32`, which is only declared in the module; the libraries
    passed as `shared_libs` are what make the call resolvable.
    """
    with Context():
        module = Module.parse("""
      module {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """)
        arg0 = np.array([0.0]).astype(np.float32)

        arg0_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg0)))

        # The library directory and file suffix differ per platform. macOS
        # needs its own branch because shared libraries there end in .dylib,
        # not .so.
        # NOTE(review): the paths are relative, presumably to the directory
        # the test is executed from - confirm against the lit configuration.
        if sys.platform == 'win32':
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ]
        elif sys.platform == 'darwin':
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ]
        else:
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.so",
                "../../../../lib/libmlir_c_runner_utils.so"
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module),
            opt_level=3,
            shared_libs=shared_libs)
        execution_engine.invoke("main", arg0_memref_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)


# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Invoke a function that reads `nanoTime` from a loaded runner library.

    The module calls `nanoTime`, stores the result in a one-element memref and
    prints it through `printMemrefI64`. Both functions are only declared in
    the module, so the libraries passed as `shared_libs` must provide them.
    """
    with Context():
        module = Module.parse("""
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

        # The library directory and file suffix differ per platform. macOS
        # needs its own branch because shared libraries there end in .dylib,
        # not .so.
        # NOTE(review): the paths are relative, presumably to the directory
        # the test is executed from - confirm against the lit configuration.
        if sys.platform == 'win32':
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ]
        elif sys.platform == 'darwin':
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ]
        else:
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.so",
                "../../../../lib/libmlir_c_runner_utils.so"
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module),
            opt_level=3,
            shared_libs=shared_libs)
        execution_engine.invoke("main")
        # The timestamp varies from run to run, so only the shape of the
        # printed output is matched.
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)