# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush immediately.

    Flushing on every call keeps this script's output ordered relative to the
    messages MLIR emits on the same stream.
    """
    print(*args, file=sys.stderr, flush=True)


def run(f):
    """Run test function ``f`` after logging a ``TEST: <name>`` header.

    After the test returns, garbage is collected and the live count reported
    by ``Context._get_live_count()`` must be zero, i.e. the test must not leak
    a Context.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Re-wrap an ExecutionEngine from its C-API capsule."""
    with Context():
        module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
        execution_engine = ExecutionEngine(module)
        execution_engine_capsule = execution_engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(execution_engine_capsule))
        # NOTE(review): presumably detaches the native engine from this
        # wrapper so that the capsule can be adopted by a new wrapper below.
        execution_engine._testing_release()
        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(execution_engine1))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an engine from a module that was not lowered must raise."""
    with Context():
        # A func-dialect function that has not been run through lowerToLLVM.
        module = Module.parse(r"""
    func.func @foo() { return }
    """)
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)


run(testInvalidModule)


def lowerToLLVM(module):
    """Run the to-LLVM conversion pipeline on ``module`` in place.

    The pipeline converts the complex, memref and func dialects and finishes
    with reconcile-unrealized-casts.

    Returns:
        The same ``module`` object, for call chaining.
    """
    pm = PassManager.parse(
        "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts")
    pm.run(module)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    with Context():
        module = Module.parse(r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        execution_engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Pass two f32 scalars by pointer and read back their sum."""
    with Context():
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        arg0 = c_float_p(42.)
        arg1 = c_float_p(2.)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))


run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call back into a Python function registered as a runtime symbol."""
    # Define a callback function that takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input int, and a float
        # result slot. Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        c_int_p = ctypes.c_int * 1
        arg0 = c_float_p(42.)
        arg1 = c_int_p(2)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback returns half the sum, so double it for the log line.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass contiguous and strided arrays to a Python callback as unranked memrefs."""
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        # A zero stride on the second axis repeats each element four times
        # without copying, so the callback sees a non-contiguous layout.
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
        )


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a 2x2 f32 array to a Python callback as a ranked memref."""
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2,
                                      np.ctypeslib.as_ctypes_type(np.float32))),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))))


run(testRankedMemRefCallback)


# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a 1-element memref and a 0-d memref, storing into a third memref."""
    with Context():
        module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
        arg1 = np.array([32.5]).astype(np.float32)
        # 0-d array to match the rank-0 memref<f32> parameter.
        arg2 = np.array(6).astype(np.float32)
        res = np.array([0]).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
                                res_memref_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))


run(testMemrefAdd)

# Element-wise addition of two one-element f16 memrefs.
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """JIT-add two f16 memrefs and read the result back two ways.

    The sum is read once from the output array itself and once by
    converting the output descriptor with ranked_memref_to_numpy.
    """
    with Context():
        module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

        lhs, rhs, out = (
            np.array([value]).astype(np.float16) for value in (11., 22., 0.))

        # Each argument is a pointer to a pointer to a ranked descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Round-trip the output descriptor back into a numpy array.
        # CHECK: [33.]
        log(ranked_memref_to_numpy(out_ptr[0]))


run(testF16MemrefAdd)


# Element-wise addition of two one-element complex<f64> memrefs.
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """JIT-add two complex128 memrefs via complex.add and read the result back."""
    with Context():
        module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

        lhs, rhs, out = (
            np.array([value]).astype(np.complex128)
            for value in (1.+2.j, 3.+4.j, 0.+0.j))

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Round-trip the output descriptor back into a numpy array.
        # CHECK: [4.+6.j]
        log(ranked_memref_to_numpy(out_ptr[0]))


run(testComplexMemrefAdd)


# Element-wise addition of two complex<f32> memrefs passed as unranked memrefs.
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """JIT-add two complex64 arrays passed through unranked descriptors.

    The MLIR function casts each unranked argument to memref<1xcomplex<f32>>
    before loading from and storing into it.
    """
    with Context():
        module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

        lhs, rhs, out = (
            np.array([value]).astype(np.complex64)
            for value in (5.+6.j, 7.+8.j, 0.+0.j))

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # Unranked descriptors carry no element type, so pass it explicitly.
        # CHECK: [12.+14.j]
        log(unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64)))


run(testComplexUnrankedMemrefAdd)


# Element-wise addition of 2x2 memrefs, one of them with dynamic dimensions.
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """JIT-add two 2x2 f32 arrays with a hand-written loop nest.

    The second operand is declared as memref<?x?xf32>, so its shape comes
    from the runtime descriptor rather than from the type.
    """
    with Context():
        module = Module.parse("""
    module  {
      func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
        %c0 = arith.constant 0 : index
        %c2 = arith.constant 2 : index
        %c1 = arith.constant 1 : index
        cf.br ^bb1(%c0 : index)
      ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
        %1 = arith.cmpi slt, %0, %c2 : index
        cf.cond_br %1, ^bb2, ^bb6
      ^bb2:  // pred: ^bb1
        %c0_0 = arith.constant 0 : index
        %c2_1 = arith.constant 2 : index
        %c1_2 = arith.constant 1 : index
        cf.br ^bb3(%c0_0 : index)
      ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
        %3 = arith.cmpi slt, %2, %c2_1 : index
        cf.cond_br %3, ^bb4, ^bb5
      ^bb4:  // pred: ^bb3
        %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
        %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
        %6 = arith.addf %4, %5 : f32
        memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
        %7 = arith.addi %2, %c1_2 : index
        cf.br ^bb3(%7 : index)
      ^bb5:  // pred: ^bb3
        %8 = arith.addi %0, %c1 : index
        cf.br ^bb1(%8 : index)
      ^bb6:  // pred: ^bb1
        return
      }
    }
    """)
        # Three independent random 2x2 arrays; the output's initial contents
        # are overwritten by the kernel.
        lhs, rhs, out = (
            np.random.randn(2, 2).astype(np.float32) for _ in range(3))

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)


# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Resolve printMemrefF32 from the runner-utils libraries passed as shared_libs.

    The MLIR function stores 42.0 into its argument and prints it through the
    external printMemrefF32 symbol, which only resolves if the libraries load.
    """
    with Context():
        module = Module.parse("""
    module  {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
    } """)
        buffer = np.array([0.0]).astype(np.float32)

        buffer_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(buffer)))

        # Library paths are relative to the test's working directory; any
        # platform not listed falls back to the ELF (.so) names.
        libs_by_platform = {
            'win32': [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ],
            'darwin': [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ],
        }
        shared_libs = libs_by_platform.get(sys.platform, [
            "../../../../lib/libmlir_runner_utils.so",
            "../../../../lib/libmlir_c_runner_utils.so"
        ])

        engine = ExecutionEngine(lowerToLLVM(module),
                                 opt_level=3,
                                 shared_libs=shared_libs)
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)


# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Call the external nanoTime symbol and print its value via printMemrefI64.

    Both symbols are expected to come from the runner-utils libraries passed
    as shared_libs. Only the output format is checked, since the time value
    itself is not deterministic.
    """
    with Context():
        module = Module.parse("""
    module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

        # Library names differ per platform. macOS needs the .dylib names;
        # previously it fell through to the .so branch and pointed at files
        # that do not exist there.
        if sys.platform == 'win32':
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ]
        elif sys.platform == 'darwin':
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ]
        else:
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.so",
                "../../../../lib/libmlir_c_runner_utils.so"
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module),
            opt_level=3,
            shared_libs=shared_libs)
        execution_engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)