# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg
from mlir.passmanager import *
from mlir.execution_engine import *

from mlir.dialects.linalg.opdsl.lang import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
  """Prints `args` to stderr and flushes the stream immediately."""
  print(*args, file=sys.stderr)
  sys.stderr.flush()


# Each `*_boiler` string below is MLIR text defining a `main` entry point. It
# is appended to the kernels built by a test (see `transform`), calls those
# kernels by name, and returns a single scalar loaded from the output buffers.

elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

matmul_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant -1 : i8
  %v2 = arith.constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xi8>
  %B = memref.alloc() : memref<16x8xf32>
  %C0 = memref.alloc() : memref<4x8xf32>
  %C1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : i8) outs(%A : memref<4x16xi8>)
  linalg.fill ins(%v2 : f32) outs(%B : memref<16x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%C0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%C1 : memref<4x8xf32>)

  call @matmul_signed_on_buffers(%A, %B, %C0) :
    (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
  call @matmul_unsigned_on_buffers(%A, %B, %C1) :
    (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""

fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
  """Combines `module` with `boilerplate` and runs the lowering pipeline.

  The top-level operations of `module` are printed, joined with newlines,
  concatenated with `boilerplate`, and parsed into a fresh module. The pass
  pipeline below is then run on that new module.

  Args:
    module: Module containing the kernels under test.
    boilerplate: MLIR text appended after the printed kernels.

  Returns:
    The newly parsed module after the pass pipeline has run on it.
  """
  import mlir.conversions
  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  ops = module.operation.regions[0].blocks[0].operations
  mod = Module.parse("\n".join(str(op) for op in ops) + boilerplate)

  pm = PassManager.parse(
      "func.func(convert-linalg-to-loops, lower-affine, "
      "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
      "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm,"
      "reconcile-unrealized-casts")
  pm.run(mod)
  return mod


def _invoke_main_and_log(module, boilerplate, result_ctype):
  """Lowers `module` plus `boilerplate`, invokes `main`, and logs its result.

  Every test in this file calls this from inside its `with Context()` block.

  Args:
    module: Module holding the kernels that the boilerplate's `main` calls.
    boilerplate: MLIR text defining the `main` entry point.
    result_ctype: ctypes scalar type for the value returned by `main`; the
      tests pass `ctypes.c_float` for f32 and `ctypes.c_int` for i32 results.
  """
  execution_engine = ExecutionEngine(transform(module, boilerplate))

  # TODO: FFI-based solution to allow testing and printing with python code.
  # Prepare arguments: one result slot, initialized to -1.
  # Arguments must be passed as pointers, hence the one-element ctypes array.
  res = (result_ctype * 1)(-1)
  execution_engine.invoke("main", res)

  log("RESULT: ", res[0])


def test_elemwise_builtin():
  """Runs the element-wise unary/binary kernels without `emit_generic`."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
          MemRefType.get((4, 8), f32))
      def elemwise_exp_add_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out])
        linalg.elemwise_binary(out, rhs, outs=[out])

      @func.FuncOp.from_py_func(
          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
          MemRefType.get((4, 8), f32))
      def elemwise_log_mul_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
        linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

    _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float)
    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
    # CHECK: RESULT: 4.71828


test_elemwise_builtin()


def test_elemwise_generic():
  """Runs the element-wise unary/binary kernels with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
          MemRefType.get((4, 8), f32))
      def elemwise_exp_add_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
        linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

      @func.FuncOp.from_py_func(
          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
          MemRefType.get((4, 8), f32))
      def elemwise_log_mul_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(
            lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
        linalg.elemwise_binary(
            out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)

    _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float)
    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
    # CHECK: RESULT: 4.71828


test_elemwise_generic()


def test_matmul_builtin():
  """Runs i8 x f32 matmul with the default and the unsigned cast."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i8 = IntegerType.get_signless(8)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_signed_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

      @func.FuncOp.from_py_func(
          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_unsigned_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)

    _invoke_main_and_log(module, matmul_boiler, ctypes.c_float)
    # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
    # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
    # CHECK: RESULT: 8128


test_matmul_builtin()


def test_matmul_generic():
  """Same as `test_matmul_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i8 = IntegerType.get_signless(8)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_signed_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

      @func.FuncOp.from_py_func(
          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
          MemRefType.get((4, 8), f32))
      def matmul_unsigned_on_buffers(lhs, rhs, out):
        linalg.matmul(
            lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)

    _invoke_main_and_log(module, matmul_boiler, ctypes.c_float)
    # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
    # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
    # CHECK: RESULT: 8128


test_matmul_generic()


def test_fill_builtin():
  """Fills 0-d, 1-d and 2-d i32 buffers from f32 scalars."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
      def fill_0d_on_buffers(value, out):
        linalg.fill(value, outs=[out])

      @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
      def fill_1d_on_buffers(value, out):
        linalg.fill(value, outs=[out])

      @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
      def fill_2d_on_buffers(value, out):
        linalg.fill(value, outs=[out])

    _invoke_main_and_log(module, fill_boiler, ctypes.c_int)
    # CHECK: RESULT: 6


test_fill_builtin()


def test_fill_generic():
  """Same as `test_fill_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
      def fill_0d_on_buffers(value, out):
        linalg.fill(value, outs=[out], emit_generic=True)

      @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
      def fill_1d_on_buffers(value, out):
        linalg.fill(value, outs=[out], emit_generic=True)

      @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
      def fill_2d_on_buffers(value, out):
        linalg.fill(value, outs=[out], emit_generic=True)

    _invoke_main_and_log(module, fill_boiler, ctypes.c_int)
    # CHECK: RESULT: 6


test_fill_generic()


def test_fill_rng_builtin():
  """Fills a 4x16 i32 buffer via `fill_rng_2d` with fixed min/max/seed."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out])

    _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int)
    # CHECK: RESULT: -480


test_fill_rng_builtin()


def test_fill_rng_generic():
  """Same as `test_fill_rng_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

    _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int)
    # CHECK: RESULT: -480


test_fill_rng_generic()


def test_max_pooling_builtin():
  """Runs NHWC max pooling with strides [2, 4] and dilations [1, 2]."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
  """Same as `test_max_pooling_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
  """Runs NHWC min pooling with strides [2, 4] and no explicit dilations."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      # Set the strides and use the default dilations.
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

    _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
    # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
  """Same as `test_min_pooling_builtin`, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      # Set the strides and use the default dilations.
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input, shape, outs=[output], strides=[2, 4], emit_generic=True)

    _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
    # CHECK: RESULT: -13


test_min_pooling_generic()