# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.passmanager import *
from mlir.execution_engine import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
  print(*args, file=sys.stderr)
  sys.stderr.flush()


matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""

fill_rng_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# NOTE(review): conv_boiler is not referenced by any test in this file; either
# add a test that calls @conv_on_buffers through it or remove it.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
  """Combine `module` with `boilerplate` and lower the result towards LLVM.

  Args:
    module: Module whose top-level operations (the functions a test built)
      are printed and re-parsed into the new module.
    boilerplate: MLIR source text appended after those operations; in this
      file it defines the `main` entry point that calls them.

  Returns:
    A freshly parsed Module on which the pass pipeline below has been run.
  """
  # Imported only for their side effects: none of these names is referenced
  # below. Presumably they register the passes named in the pipeline string.
  import mlir.conversions
  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  ops = module.operation.regions[0].blocks[0].operations
  mod = Module.parse("\n".join(str(op) for op in ops) + boilerplate)

  pm = PassManager.parse(
      "builtin.func(convert-linalg-to-loops, lower-affine, " +
      "convert-scf-to-cf, arith-expand, memref-expand), " +
      "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm, " +
      "reconcile-unrealized-casts")
  pm.run(mod)
  return mod


def _invoke_main(module, boilerplate, result_ctype):
  """JIT-compile `module` plus `boilerplate` and return what `main` returns.

  Must be called while an MLIR Context is active, because `transform` calls
  `Module.parse` without passing an explicit context.

  Args:
    module: Module holding the functions that the boilerplate's `main` calls.
    boilerplate: MLIR text defining a `main` with `llvm.emit_c_interface`
      that returns a single scalar.
    result_ctype: ctypes scalar type matching that result (`ctypes.c_float`
      for f32, `ctypes.c_int` for i32).

  Returns:
    The scalar written by `main`, as a Python number.
  """
  execution_engine = ExecutionEngine(transform(module, boilerplate))

  # TODO: FFI-based solution to allow testing and printing with python code.
  # Arguments must be passed as pointers, so the single result lands in a
  # one-element ctypes array. It starts as -1, so a result that was never
  # written would show up as -1 in the log.
  res = (result_ctype * 1)(-1)
  execution_engine.invoke("main", res)
  return res[0]


def test_matmul_builtin():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    log("RESULT: ", _invoke_main(module, matmul_boiler, ctypes.c_float))
    # 16 (reduction size) * 1.0 (A) * 2.0 (B).
    # CHECK: RESULT: 32.0


test_matmul_builtin()


def test_matmul_generic():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    log("RESULT: ", _invoke_main(module, matmul_boiler, ctypes.c_float))
    # CHECK: RESULT: 32.0


test_matmul_generic()


def test_fill_builtin():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
      def fill_0d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
      def fill_1d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
      def fill_2d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

    log("RESULT: ", _invoke_main(module, fill_boiler, ctypes.c_int))
    # Sum of one element from each of the 1.0 / 2.0 / 3.0 fills.
    # CHECK: RESULT: 6


test_fill_builtin()


def test_fill_generic():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
      def fill_0d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
      def fill_1d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
      def fill_2d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

    log("RESULT: ", _invoke_main(module, fill_boiler, ctypes.c_int))
    # CHECK: RESULT: 6


test_fill_generic()


def test_fill_rng_builtin():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out])

    log("RESULT: ", _invoke_main(module, fill_rng_boiler, ctypes.c_int))
    # CHECK: RESULT: -480


test_fill_rng_builtin()


def test_fill_rng_generic():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

    log("RESULT: ", _invoke_main(module, fill_rng_boiler, ctypes.c_int))
    # CHECK: RESULT: -480


test_fill_rng_generic()


def test_max_pooling_builtin():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    log("RESULT: ", _invoke_main(module, pooling_boiler, ctypes.c_int))
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    log("RESULT: ", _invoke_main(module, pooling_boiler, ctypes.c_int))
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        # Set the strides and use the default dilations.
        linalg.pooling_nhwc_min(
            input,
            shape,
            outs=[output],
            strides=[2, 4])

    log("RESULT: ", _invoke_main(module, pooling_boiler, ctypes.c_int))
    # -13 is the smallest of the stored values 42, 77, -13 and the 1.0 fill.
    # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        # Set the strides and use the default dilations.
        linalg.pooling_nhwc_min(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            emit_generic=True)

    log("RESULT: ", _invoke_main(module, pooling_boiler, ctypes.c_int))
    # CHECK: RESULT: -13


test_min_pooling_generic()