# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import ctypes
import gc
import sys

import numpy as np

from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *

# NOTE: `ctypes` and `np` are imported explicitly above because the tests in
# this file use them directly; they should not depend on whatever names the
# wildcard imports happen to re-export.


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush the stream right away."""
    print(*args, file=sys.stderr, flush=True)


def run(f):
    """Execute the test callable ``f`` and verify no Context outlives it.

    A banner line containing the name of ``f`` is logged first. After ``f``
    returns, a garbage collection is forced and the number of live Context
    objects is asserted to be zero.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Round-trip an ExecutionEngine through its C API capsule."""
    with Context():
        module = Module.parse(r"""
llvm.func @none() {
  llvm.return
}
    """)
        engine = ExecutionEngine(module)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        # Release the first wrapper before a second one is created from the
        # same capsule.
        engine._testing_release()
        engine_from_capsule = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(engine_from_capsule))


run(testCapsule)


# Test invalid ExecutionEngine creation
# CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an engine from a module that was not lowered must fail."""
    with Context():
        # Builtin function
        module = Module.parse(r"""
    func.func @foo() { return }
    """)
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            _unexpected_engine = ExecutionEngine(module)
        except RuntimeError as err:
            log("Got RuntimeError: ", err)


run(testInvalidModule)


def lowerToLLVM(module):
    """Run the lowering pass pipeline on ``module`` and return that module."""
    pipeline = ",".join((
        "convert-complex-to-llvm",
        "convert-memref-to-llvm",
        "convert-func-to-llvm",
        "reconcile-unrealized-casts",
    ))
    PassManager.parse(pipeline).run(module)
    return module


# Test simple ExecutionEngine execution
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    with Context():
        module = Module.parse(r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
    """)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Nothing to check other than no exception thrown here.
        engine.invoke("void")


run(testInvokeVoid)


# Test argument passing and result with a simple float addition.
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Add two f32 values through the engine and log the result."""
    with Context():
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
    """)
        engine = ExecutionEngine(lowerToLLVM(module))
        # Prepare arguments: two input floats and one result.
        # Arguments must be passed as pointers, so each value lives in a
        # one-element ctypes array.
        float_cell = ctypes.c_float * 1
        lhs = float_cell(42.)
        rhs = float_cell(2.)
        result = float_cell(-1.)
        engine.invoke("add", lhs, rhs, result)
        # CHECK: 42.0 + 2.0 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], result[0]))


run(testInvokeFloatAdd)


# Test callback
# CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Call back into a Python function registered as a runtime symbol."""
    # Define a callback function that takes a float and an integer and returns a float.
    callback_type = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float,
                                     ctypes.c_int)

    @callback_type
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
    """)
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        # Prepare arguments: one input float, one input integer and one result.
        # Arguments must be passed as pointers.
        float_cell = ctypes.c_float * 1
        int_cell = ctypes.c_int * 1
        lhs = float_cell(42.)
        rhs = int_cell(2)
        result = float_cell(-1.)
        engine.invoke("add", lhs, rhs, result)
        # The callback halves both inputs, hence the doubling when logging.
        # CHECK: 42.0 + 2 = 44.0
        log("{0} + {1} = {2}".format(lhs[0], rhs[0], result[0] * 2))


run(testBasicCallback)


# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked memrefs to a Python callback that logs them as arrays."""
    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
    callback_type = ctypes.CFUNCTYPE(
        None, ctypes.POINTER(UnrankedMemRefDescriptor))

    @callback_type
    def callback(a):
        arr = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
""")
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)

        dense = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        dense_arg = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(dense)))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        engine.invoke("callback_memref", dense_arg)

        base = np.array([5, 6, 7], dtype=np.float32)
        # A zero stride along the second axis repeats every element of `base`
        # across a row of the (3, 4) view.
        strided = np.lib.stride_tricks.as_strided(
            base, strides=(4, 0), shape=(3, 4))
        strided_arg = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(strided)))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        engine.invoke("callback_memref", strided_arg)


run(testUnrankedMemRefCallback)


# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a ranked 2x2 memref to a Python callback that logs it."""
    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
    descriptor_type = make_nd_memref_descriptor(
        2, np.ctypeslib.as_ctypes_type(np.float32))
    callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))

    @callback_type
    def callback(a):
        arr = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(arr)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        module = Module.parse(r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
""")
        engine = ExecutionEngine(lowerToLLVM(module))
        engine.register_runtime("some_callback_into_python", callback)
        matrix = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        matrix_arg = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(matrix)))
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        engine.invoke("callback_memref", matrix_arg)


run(testRankedMemRefCallback)


# Test addition of two memrefs.
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a 1-element f32 memref and a rank-0 f32 memref into a third one."""
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf32>
        %2 = memref.load %arg1[] : memref<f32>
        %3 = arith.addf %1, %2 : f32
        memref.store %3, %arg2[%0] : memref<1xf32>
        return
      }
    } """)
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        # Each argument is handed over as a pointer to a pointer to its
        # ranked memref descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log("{0} + {1} = {2}".format(lhs, rhs, out))


run(testMemrefAdd)


# Test addition of two f16 memrefs
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """Add two 1-element f16 memrefs and read the result back as numpy."""
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xf16>,
                      %arg1: memref<1xf16>,
                      %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xf16>
        %2 = memref.load %arg1[%0] : memref<1xf16>
        %3 = arith.addf %1, %2 : f16
        memref.store %3, %arg2[%0] : memref<1xf16>
        return
      }
    } """)

        lhs = np.array([11.]).astype(np.float16)
        rhs = np.array([22.]).astype(np.float16)
        out = np.array([0.]).astype(np.float16)

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [33.]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)


run(testF16MemrefAdd)


# Test addition of two complex memrefs
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """Add two 1-element complex<f64> memrefs and read the result back."""
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<1xcomplex<f64>>,
                      %arg1: memref<1xcomplex<f64>>,
                      %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
        %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
        %3 = complex.add %1, %2 : complex<f64>
        memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
        return
      }
    } """)

        lhs = np.array([1.+2.j]).astype(np.complex128)
        rhs = np.array([3.+4.j]).astype(np.complex128)
        out = np.array([0.+0.j]).astype(np.complex128)

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [4.+6.j]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)


run(testComplexMemrefAdd)


# Test addition of two complex unranked memrefs
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """Add two complex<f32> values passed as unranked memrefs."""
    with Context():
        module = Module.parse("""
    module {
      func.func @main(%arg0: memref<*xcomplex<f32>>,
                      %arg1: memref<*xcomplex<f32>>,
                      %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
        %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
        %0 = arith.constant 0 : index
        %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
        %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
        %3 = complex.add %1, %2 : complex<f32>
        memref.store %3, %C[%0] : memref<1xcomplex<f32>>
        return
      }
    } """)

        lhs = np.array([5.+6.j]).astype(np.complex64)
        rhs = np.array([7.+8.j]).astype(np.complex64)
        out = np.array([0.+0.j]).astype(np.complex64)

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(
                ctypes.pointer(get_unranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log("{0} + {1} = {2}".format(lhs, rhs, out))

        # test to-numpy utility
        # CHECK: [12.+14.j]
        round_tripped = unranked_memref_to_numpy(out_ptr[0],
                                                 np.dtype(np.complex64))
        log(round_tripped)


run(testComplexUnrankedMemrefAdd)


# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Add two 2x2 f32 memrefs element-wise and compare against numpy."""
    with Context():
        module = Module.parse("""
      module {
        func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
          %c0 = arith.constant 0 : index
          %c2 = arith.constant 2 : index
          %c1 = arith.constant 1 : index
          cf.br ^bb1(%c0 : index)
        ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
          %1 = arith.cmpi slt, %0, %c2 : index
          cf.cond_br %1, ^bb2, ^bb6
        ^bb2: // pred: ^bb1
          %c0_0 = arith.constant 0 : index
          %c2_1 = arith.constant 2 : index
          %c1_2 = arith.constant 1 : index
          cf.br ^bb3(%c0_0 : index)
        ^bb3(%2: index): // 2 preds: ^bb2, ^bb4
          %3 = arith.cmpi slt, %2, %c2_1 : index
          cf.cond_br %3, ^bb4, ^bb5
        ^bb4: // pred: ^bb3
          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
          %6 = arith.addf %4, %5 : f32
          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
          %7 = arith.addi %2, %c1_2 : index
          cf.br ^bb3(%7 : index)
        ^bb5: // pred: ^bb3
          %8 = arith.addi %0, %c1 : index
          cf.br ^bb1(%8 : index)
        ^bb6: // pred: ^bb1
          return
        }
      }
        """)
        # The three arrays are drawn in the same order as before: both inputs
        # first, then the output buffer that the call overwrites.
        lhs = np.random.randn(2, 2).astype(np.float32)
        rhs = np.random.randn(2, 2).astype(np.float32)
        out = np.random.randn(2, 2).astype(np.float32)

        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arr)))
            for arr in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(module))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))


run(testDynamicMemrefAdd2D)


# Test loading of shared libraries.
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Load the runner utility libraries and call a printing routine."""
    with Context():
        module = Module.parse("""
      module {
      func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
        %c0 = arith.constant 0 : index
        %cst42 = arith.constant 42.0 : f32
        memref.store %cst42, %arg0[%c0] : memref<1xf32>
        %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
        call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
        return
      }
      func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
     } """)
        buffer = np.array([0.0]).astype(np.float32)

        buffer_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(buffer)))

        # Library locations for the platforms that need something other than
        # the default `.so` files.
        libs_by_platform = {
            'win32': [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ],
            'darwin': [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ],
        }
        default_libs = [
            "../../../../lib/libmlir_runner_utils.so",
            "../../../../lib/libmlir_c_runner_utils.so"
        ]
        shared_libs = libs_by_platform.get(sys.platform, default_libs)

        engine = ExecutionEngine(
            lowerToLLVM(module),
            opt_level=3,
            shared_libs=shared_libs)
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]


run(testSharedLibLoad)


# Test that nano time clock is available.
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Call `nanoTime` from the runner utility libraries and print its value.

    The shared libraries are chosen per platform. The `darwin` case uses the
    `.dylib` files, the same choice the shared-library loading test in this
    file makes; every other non-Windows platform uses the `.so` files.
    """
    with Context():
        module = Module.parse("""
      module {
      func.func @main() attributes { llvm.emit_c_interface } {
        %now = call @nanoTime() : () -> i64
        %memref = memref.alloca() : memref<1xi64>
        %c0 = arith.constant 0 : index
        memref.store %now, %memref[%c0] : memref<1xi64>
        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
        call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
        return
      }
      func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
      func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
    }""")

        if sys.platform == 'win32':
            shared_libs = [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll"
            ]
        elif sys.platform == 'darwin':
            # macOS builds produce `.dylib` files, not `.so` files.
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib"
            ]
        else:
            shared_libs = [
                "../../../../lib/libmlir_runner_utils.so",
                "../../../../lib/libmlir_c_runner_utils.so"
            ]

        execution_engine = ExecutionEngine(
            lowerToLLVM(module),
            opt_level=3,
            shared_libs=shared_libs)
        execution_engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]


run(testNanoTime)