# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush immediately."""
    print(*args, file=sys.stderr)
    sys.stderr.flush()


def run(f):
    """Run the test function ``f`` after logging a ``TEST:`` header line.

    After ``f`` returns, a garbage collection is forced and the harness asserts
    that no MLIR ``Context`` is still alive, so a test that leaks a context
    fails immediately.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C-API capsule."""
    with Context():
        module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
        execution_engine = ExecutionEngine(module)
        execution_engine_capsule = execution_engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(execution_engine_capsule))
        # NOTE(review): presumably this releases the Python object's ownership
        # of the native engine so that re-wrapping the capsule below does not
        # lead to a double free — confirm against the bindings.
        execution_engine._testing_release()
        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
        # CHECK: _mlir.execution_engine.ExecutionEngine
        log(repr(execution_engine1))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an engine from a module not lowered to LLVM raises RuntimeError."""
    with Context():
        # Builtin function
        module = Module.parse(r"""
    func @foo() { return }
    """)
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            # Only the raised error matters; the engine itself is never used.
            ExecutionEngine(module)
        except RuntimeError as e:
            log("Got RuntimeError: ", e)


run(testInvalidModule)


def lowerToLLVM(module):
    """Run the ``convert-std-to-llvm`` pipeline over ``module`` in place.

    Returns the same module object so calls can be chained into the
    ExecutionEngine constructor.
    """
    # Imported only for its side effects; presumably this registers the
    # conversion passes so the pipeline string below can be parsed.
    import mlir.conversions
    pm = PassManager.parse("convert-std-to-llvm")
    pm.run(module)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    with Context():
        module = Module.parse(r"""
func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        execution_engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Pass two f32 arguments and read back the f32 result slot."""
    with Context():
        module = Module.parse(r"""
func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = std.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers; a one-element ctypes array
        # provides the storage for each value, and the result slot is last.
        c_float_p = ctypes.c_float * 1
        arg0 = c_float_p(42.)
        arg1 = c_float_p(2.)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))


run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call back into Python from JIT-compiled code through a registered symbol."""
    # Define a callback function that takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input integer, and one result.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        c_int_p = ctypes.c_int * 1
        arg0 = c_float_p(42.)
        arg1 = c_int_p(2)
        res = c_float_p(-1.)
        execution_engine.invoke("add", arg0, arg1, res)
        # The callback returns (a + b) / 2, so the result is doubled before
        # logging to print the plain sum.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Hand contiguous and strided NumPy arrays to a Python callback as memref<*xf32>."""
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
        )
        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
        # A (3, 4) view in which row i repeats inp_arr_1[i]: the row stride is
        # 4 bytes (one float32) and the column stride is 0.
        strided_arr = np.lib.stride_tricks.as_strided(
            inp_arr_1, strides=(4, 0), shape=(3, 4)
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
            ),
        )


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Hand a 2x2 float32 NumPy array to a Python callback as memref<2x2xf32>."""
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    @ctypes.CFUNCTYPE(
        None,
        ctypes.POINTER(
            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
        ),
    )
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(
            r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.register_runtime("some_callback_into_python", callback)
        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        execution_engine.invoke(
            "callback_memref",
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
        )


run(testRankedMemRefCallback)


# Test addition of two memrefs: a 1-D memref<1xf32> and a 0-D memref<f32>.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a one-element memref and a 0-D memref, storing into an output memref."""
    with Context():
        module = Module.parse(
            """
    module {
      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """
        )
        arg1 = np.array([32.5]).astype(np.float32)
        # A 0-D array, matching the memref<f32> parameter.
        arg2 = np.array(6).astype(np.float32)
        res = np.array([0]).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(arg1, arg2, res))


run(testMemrefAdd)


# Test element-wise addition of a static 2x2 memref and a dynamic ?x? memref.
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Add a static and a dynamically shaped 2-D memref through explicit CFG loops."""
    with Context():
        module = Module.parse(
            """
    module {
      func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
        %c0 = constant 0 : index
        %c2 = constant 2 : index
        %c1 = constant 1 : index
        br ^bb1(%c0 : index)
      ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
        %1 = cmpi slt, %0, %c2 : index
        cond_br %1, ^bb2, ^bb6
      ^bb2:  // pred: ^bb1
        %c0_0 = constant 0 : index
        %c2_1 = constant 2 : index
        %c1_2 = constant 1 : index
        br ^bb3(%c0_0 : index)
      ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
        %3 = cmpi slt, %2, %c2_1 : index
        cond_br %3, ^bb4, ^bb5
      ^bb4:  // pred: ^bb3
        %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
        %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
        %6 = addf %4, %5 : f32
        memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
        %7 = addi %2, %c1_2 : index
        br ^bb3(%7 : index)
      ^bb5:  // pred: ^bb3
        %8 = addi %0, %c1 : index
        br ^bb1(%8 : index)
      ^bb6:  // pred: ^bb1
        return
      }
    }
    """
        )
        arg1 = np.random.randn(2, 2).astype(np.float32)
        arg2 = np.random.randn(2, 2).astype(np.float32)
        # Initial contents are irrelevant: every element is overwritten.
        res = np.random.randn(2, 2).astype(np.float32)

        arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
        arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
        res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke(
            "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
        )
        # CHECK: True
        log(np.allclose(arg1 + arg2, res))


run(testDynamicMemrefAdd2D)