# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.passmanager import *
from mlir.execution_engine import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
  print(*args, file=sys.stderr)
  sys.stderr.flush()


# Driver for the matmul tests: fills A (4x16) with 1.0, B (16x8) with 2.0 and
# C (4x8) with 0.0, calls @matmul_on_buffers, and returns C[0, 0].
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0.0 : f32
  %v1 = constant 1.0 : f32
  %v2 = constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

# Driver for the fill tests: calls @fill_on_buffers with min=-1000.0,
# max=1000.0, seed=42 on a 4x16 i32 buffer and returns O[0, 0].
fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = constant -1000.0 : f64
  %max = constant 1000.0 : f64
  %seed = constant 42 : i32

  call @fill_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# Driver for a convolution test: fills input with 1.0, filter with 2.0 and
# output with 0, calls @conv_on_buffers, and returns output[0, 0, 0, 0].
# NOTE(review): no function in this file currently uses conv_boiler; either
# add the matching conv tests or remove it.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0 : i32
  %v1 = constant 1.0 : f64
  %v2 = constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# Driver for the pooling tests: input is all 1.0 except
# input[0, 0, 0, 0] = 42, input[0, 0, 1, 0] = 77 and input[0, 0, 2, 0] = -13.
# Calls @pooling_on_buffers with a 2x2 window buffer and returns
# output[0, 0, 0, 0].
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0 : i32
  %v42 = constant 42.0 : f64
  %v77 = constant 77.0 : f64
  %v-13 = constant -13.0 : f64
  %v1 = constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
  """Combine the kernel with a driver and lower the result to the LLVM dialect.

  The first operation in `module`'s body (the kernel function built by a
  test) is printed, `boilerplate` is appended to that text, and the
  concatenation is parsed into a fresh module. That module is then run
  through a fixed lowering pipeline.

  Args:
    module: Module whose first top-level operation is the kernel function.
    boilerplate: MLIR source text defining a `@main` that calls the kernel.

  Returns:
    The newly parsed and lowered module.
  """
  # Imported for their side effects; presumably these register the passes
  # named in the pipeline below.
  import mlir.conversions
  import mlir.dialects.linalg.passes
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  mod = Module.parse(
      str(module.operation.regions[0].blocks[0].operations[0].operation) +
      boilerplate)
  pm = PassManager.parse(
      "builtin.func(convert-linalg-to-loops, lower-affine, " +
      "convert-scf-to-std), convert-vector-to-llvm," +
      "convert-memref-to-llvm,convert-std-to-llvm")
  pm.run(mod)
  return mod


def _run_main(module, boilerplate, c_type, init):
  """Lower, JIT-compile and invoke `main`, returning its single scalar result.

  Must be called while the MLIR Context that owns `module` is active.

  Args:
    module: Module holding the kernel function (see `transform`).
    boilerplate: MLIR driver text defining `@main`.
    c_type: ctypes scalar type matching `@main`'s result type.
    init: Sentinel initial value for the result slot.

  Returns:
    The value `main` wrote into the result slot.
  """
  execution_engine = ExecutionEngine(transform(module, boilerplate))

  # TODO: FFI-based solution to allow testing and printing with python code.
  # Arguments must be passed as pointers, so the single result lives in a
  # one-element ctypes array.
  res = (c_type * 1)(init)
  execution_engine.invoke("main", res)
  return res[0]


def test_matmul_builtin():
  """Named linalg.matmul on 4x16 * 16x8 f32 buffers."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    # Each output element sums 16 products of 1.0 * 2.0, i.e. 32.0.
    log("RESULT: ", _run_main(module, matmul_boiler, ctypes.c_float, -1.))
    # CHECK: RESULT: 32.0


test_matmul_builtin()


def test_matmul_generic():
  """Same as test_matmul_builtin, but emitting linalg.generic."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    log("RESULT: ", _run_main(module, matmul_boiler, ctypes.c_float, -1.))
    # CHECK: RESULT: 32.0


test_matmul_generic()


def test_fill_builtin():
  """Named linalg.fill_rng_2d into a 4x16 i32 buffer."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out])

    log("RESULT: ", _run_main(module, fill_boiler, ctypes.c_int, -1))
    # CHECK: RESULT: -480


test_fill_builtin()


def test_fill_generic():
  """Same as test_fill_builtin, but emitting linalg.generic."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

    log("RESULT: ", _run_main(module, fill_boiler, ctypes.c_int, -1))
    # CHECK: RESULT: -480


test_fill_generic()


def test_max_pooling_builtin():
  """Named linalg.pooling_nhwc_max with strides [2, 4] and dilations [1, 2]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    log("RESULT: ", _run_main(module, pooling_boiler, ctypes.c_int, -1))
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
  """Same as test_max_pooling_builtin, but emitting linalg.generic."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    log("RESULT: ", _run_main(module, pooling_boiler, ctypes.c_int, -1))
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
  """Named linalg.pooling_nhwc_min with strides [2, 4] and dilations [1, 2]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    log("RESULT: ", _run_main(module, pooling_boiler, ctypes.c_int, -1))
    # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
  """Same as test_min_pooling_builtin, but emitting linalg.generic."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    log("RESULT: ", _run_main(module, pooling_boiler, ctypes.c_int, -1))
    # CHECK: RESULT: -13


test_min_pooling_generic()