# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.passmanager import *
from mlir.execution_engine import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print `args` to stderr and flush immediately."""
  print(*args, file=sys.stderr)
  sys.stderr.flush()


# Each *_boiler string below is MLIR assembly defining a `main` entry point. It
# is appended textually to the function built by a test (see `transform`), so
# the callee name and signature used in its `call` must match that function.

# Driver for `matmul_on_buffers`: A (4x16) is filled with 1.0, B (16x8) with
# 2.0 and C (4x8) with 0.0; `main` returns C[0, 0] after the call.
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

# Driver for `fill_on_buffers`: passes min=-1000.0, max=1000.0, seed=42 and a
# 4x16 i32 buffer; `main` returns element [0, 0] of that buffer.
fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# Driver for a `conv_on_buffers` function. No test in this file currently
# references it.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# Driver for `pooling_on_buffers`: the input is filled with 1.0 and then
# overwritten with 42.0, 77.0 and -13.0 at positions 0, 1 and 2 of the third
# dimension; `main` returns output[0, 0, 0, 0].
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
  """Combine `module`'s first op with `boilerplate` and run the lowering passes.

  Args:
    module: Module whose first top-level operation is the function under test.
    boilerplate: MLIR assembly text appended after that operation; the tests
      pass one of the `*_boiler` strings defining `main`.

  Returns:
    A newly parsed module containing both pieces, after the pass pipeline
    below has been run on it.

  Must be called with an active MLIR context, since `Module.parse` is invoked
  without an explicit one.
  """
  # Imported only for their side effects; nothing from these modules is
  # referenced by name. NOTE(review): presumably they register the passes
  # named in the pipeline string below - confirm.
  import mlir.conversions
  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  mod = Module.parse(
      str(module.operation.regions[0].blocks[0].operations[0].operation) +
      boilerplate)
  pm = PassManager.parse(
      "builtin.func(convert-linalg-to-loops, lower-affine, " +
      "convert-scf-to-std, arith-expand, std-expand), convert-vector-to-llvm," +
      "convert-memref-to-llvm, convert-std-to-llvm," +
      "reconcile-unrealized-casts")
  pm.run(mod)
  return mod


def test_matmul_builtin():
  """Run `linalg.matmul` (default emission) and log C[0, 0].

  With A all 1.0 and B all 2.0, each output element sums 16 products of
  1.0 * 2.0, hence the expected 32.0.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    execution_engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result f32.
    # Arguments must be passed as pointers.
    c_float_p = ctypes.c_float * 1
    res = c_float_p(-1.)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: 32.0


test_matmul_builtin()


def test_matmul_generic():
  """Same as `test_matmul_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result f32.
    # Arguments must be passed as pointers.
    c_float_p = ctypes.c_float * 1
    res = c_float_p(-1.)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: 32.0


test_matmul_generic()


def test_fill_builtin():
  """Run `linalg.fill_rng_2d` (default emission) and log element [0, 0]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out])

    execution_engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -480


test_fill_builtin()


def test_fill_generic():
  """Same as `test_fill_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -480


test_fill_generic()


def test_max_pooling_builtin():
  """Run `linalg.pooling_nhwc_max` (default emission) and log output[0,0,0,0]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
  """Same as `test_max_pooling_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
  """Run `linalg.pooling_nhwc_min` (default emission) and log output[0,0,0,0]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
  """Same as `test_min_pooling_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -13


test_min_pooling_generic()