1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import ctypes
4import sys
5from mlir.ir import *
6from mlir.dialects import builtin
7from mlir.dialects import linalg
8from mlir.dialects import std
9from mlir.passmanager import *
10from mlir.execution_engine import *
11
12
def log(*args):
  """Print *args* to stderr and flush right away.

  Everything is routed to stderr so that this output and whatever MLIR emits
  on stderr end up in one ordered stream for FileCheck to match against.
  """
  stream = sys.stderr
  print(*args, file=stream)
  stream.flush()
18
19
# MLIR driver for the matmul tests. transform() appends this text after the
# ops of a module that must define @matmul_on_buffers. @main fills A (4x16)
# with 1.0, B (16x8) with 2.0 and C (4x8) with 0.0. It then calls
# @matmul_on_buffers(A, B, C) and returns C[0, 0] as an f32. Every element of
# C is a 16-term sum of 1.0 * 2.0, so the expected value is 32.0.
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
43
# MLIR driver for the fill tests. transform() appends this text after the ops
# of a module that must define @fill_0d_on_buffers, @fill_1d_on_buffers and
# @fill_2d_on_buffers, each taking (f32, i32 memref). @main fills a 0-d,
# a 16-element and a 4x16 i32 buffer with 1.0, 2.0 and 3.0 respectively. It
# loads O0[], O1[8] and O2[2, 8] and returns their i32 sum (expected 6).
# %c0 is defined but unused.
fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""
72
# MLIR driver for the fill_rng tests. transform() appends this text after the
# ops of a module that must define @fill_rng_on_buffers(f64, f64, i32, 4x16
# i32 memref). @main calls it with min = -1000.0, max = 1000.0 and seed = 42,
# then returns O[0, 0]. The exact value depends on the generator implemented
# by the lowered fill_rng op.
fill_rng_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
90
# MLIR driver for a convolution test. It expects the module to define
# @conv_on_buffers(1x4x16x1 f64, 2x2x1 f64, 1x2x4x1 i32). @main fills the
# input with 1.0, the filter with 2.0 and the output with 0, calls
# @conv_on_buffers and returns output[0, 0, 0, 0].
# NOTE(review): no test visible in this file references conv_boiler or
# defines @conv_on_buffers. Confirm whether a conv test was dropped or
# whether this string can be removed.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
114
# MLIR driver for the pooling tests. transform() appends this text after the
# ops of a module that must define @pooling_on_buffers(1x4x16x1 f64,
# 2x2 f64, 1x2x4x1 i32). @main fills the input with 1.0, then overwrites
# three entries: input[0,0,0,0] = 42, input[0,0,1,0] = 77 and
# input[0,1,0,0] = -13. It fills the 2x2 %shape buffer with 1.0 and the
# output with 0, calls @pooling_on_buffers and returns output[0, 0, 0, 0].
# %shape presumably only conveys the 2x2 window extent to the pooling op
# (its contents look irrelevant) - TODO confirm. %c2 is defined but unused.
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
146
147
def transform(module, boilerplate):
  """Merge *module* with an MLIR *boilerplate* snippet and lower it to LLVM.

  Args:
    module: Module whose top-level operations are printed and placed ahead
      of the boilerplate text.
    boilerplate: MLIR source appended after those operations. The tests pass
      a `@main` driver that calls the functions built in *module*.

  Returns:
    A newly parsed Module after the lowering pass pipeline has run on it.
  """
  # Not referenced below; presumably imported for their registration side
  # effects so that the pass names in the pipeline resolve - TODO confirm.
  import mlir.conversions
  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  top_level_ops = module.operation.regions[0].blocks[0].operations
  source = "\n".join(str(op) for op in top_level_ops) + boilerplate
  lowered = Module.parse(source)

  # Adjacent literals concatenate to exactly the original pipeline text.
  pipeline = ("builtin.func(convert-linalg-to-loops, lower-affine, "
              "convert-scf-to-cf, arith-expand, memref-expand), "
              "convert-vector-to-llvm,"
              "convert-memref-to-llvm, convert-std-to-llvm,"
              "reconcile-unrealized-casts")
  PassManager.parse(pipeline).run(lowered)
  return lowered
165
166
def test_matmul_builtin():
  """JIT-run the named linalg.matmul op on 4x16 x 16x8 f32 buffers.

  `matmul_boiler` supplies the @main driver. The value it returns is logged
  so the expectation comment below can match it.
  """
  with Context(), Location.unknown():
    module = Module.create()
    elem = F32Type.get()
    lhs_type = MemRefType.get((4, 16), elem)
    rhs_type = MemRefType.get((16, 8), elem)
    out_type = MemRefType.get((4, 8), elem)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(lhs_type, rhs_type, out_type)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # @main returns a single f32. Arguments must be passed as pointers, so
    # the result slot is a one-element ctypes array preset to -1.
    result = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: 32.0


test_matmul_builtin()
193
194
def test_matmul_generic():
  """JIT-run linalg.matmul emitted as a linalg.generic on f32 buffers.

  Same buffers and `matmul_boiler` driver as the named-op variant. Only the
  emitted form of the op differs (emit_generic=True).
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    operand_types = (MemRefType.get((4, 16), f32),
                     MemRefType.get((16, 8), f32),
                     MemRefType.get((4, 8), f32))
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(*operand_types)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # One f32 result, passed by pointer as a single-element ctypes array.
    out_slot = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", out_slot)

    log("RESULT: ", out_slot[0])
    # CHECK: RESULT: 32.0


test_matmul_generic()
221
222
def test_fill_builtin():
  """JIT-run the named fill op on 0-d, 1-d and 2-d i32 buffers.

  Each generated function fills an i32 buffer from an f32 scalar.
  `fill_boiler` calls all three and returns the sum of one element of each.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    buf_0d = MemRefType.get([], i32)
    buf_1d = MemRefType.get([16], i32)
    buf_2d = MemRefType.get([4, 16], i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f32, buf_0d)
      def fill_0d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

      @builtin.FuncOp.from_py_func(f32, buf_1d)
      def fill_1d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

      @builtin.FuncOp.from_py_func(f32, buf_2d)
      def fill_2d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out])

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # @main returns a single i32; arguments must be passed as pointers.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: 6


test_fill_builtin()
256
257
def test_fill_generic():
  """JIT-run fill emitted as linalg.generic on 0-d, 1-d and 2-d i32 buffers.

  Mirrors the named-op fill variant (same `fill_boiler` driver) but passes
  emit_generic=True.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    scalar_memref = MemRefType.get([], i32)
    vector_memref = MemRefType.get([16], i32)
    matrix_memref = MemRefType.get([4, 16], i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f32, scalar_memref)
      def fill_0d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

      @builtin.FuncOp.from_py_func(f32, vector_memref)
      def fill_1d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

      @builtin.FuncOp.from_py_func(f32, matrix_memref)
      def fill_2d_on_buffers(value, out):
        linalg.fill_tensor(value, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # One i32 result, handed over by pointer as a single-element array.
    out_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", out_slot)

    log("RESULT: ", out_slot[0])
    # CHECK: RESULT: 6


test_fill_generic()
291
292
def test_fill_rng_builtin():
  """JIT-run the named fill_rng_2d op on a 4x16 i32 buffer.

  `fill_rng_boiler` calls the generated function with min = -1000.0,
  max = 1000.0 and seed = 42, and returns element [0, 0], which is logged.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min_value, max_value, seed, out):
        # Named *_value so the `min`/`max` builtins are not shadowed.
        linalg.fill_rng_2d(min_value, max_value, seed, outs=[out])

    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -480


test_fill_rng_builtin()
318
319
def test_fill_rng_generic():
  """JIT-run fill_rng_2d emitted as linalg.generic on a 4x16 i32 buffer.

  Same `fill_rng_boiler` driver (min = -1000.0, max = 1000.0, seed = 42) as
  the named-op variant, so the logged element [0, 0] must agree with it.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
      def fill_rng_on_buffers(min_value, max_value, seed, out):
        # Named *_value so the `min`/`max` builtins are not shadowed.
        linalg.fill_rng_2d(
            min_value, max_value, seed, outs=[out], emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -480


test_fill_rng_generic()
345
346
def test_max_pooling_builtin():
  """JIT-run the named pooling_nhwc_max op (strides [2, 4], dilations [1, 2]).

  `pooling_boiler` plants 42, 77 and -13 in the input near the origin and
  returns output[0, 0, 0, 0], which is logged.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input_buf, shape, output):
        # `input_buf` rather than `input` so the builtin is not shadowed.
        linalg.pooling_nhwc_max(
            input_buf, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()
376
377
def test_max_pooling_generic():
  """JIT-run pooling_nhwc_max emitted as linalg.generic.

  Same strides [2, 4], dilations [1, 2] and `pooling_boiler` driver as the
  named-op variant; output[0, 0, 0, 0] is logged.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input_buf, shape, output):
        # `input_buf` rather than `input` so the builtin is not shadowed.
        linalg.pooling_nhwc_max(
            input_buf,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()
412
413
def test_min_pooling_builtin():
  """JIT-run the named pooling_nhwc_min op with strides [2, 4].

  Dilations are left at their defaults. `pooling_boiler` plants 42, 77 and
  -13 near the origin of the input; output[0, 0, 0, 0] is logged.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input_buf, shape, output):
        # Set the strides and use the default dilations.
        # `input_buf` rather than `input` so the builtin is not shadowed.
        linalg.pooling_nhwc_min(
            input_buf,
            shape,
            outs=[output],
            strides=[2, 4])

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -13


test_min_pooling_builtin()
446
447
def test_min_pooling_generic():
  """JIT-run pooling_nhwc_min emitted as linalg.generic with strides [2, 4].

  Mirrors the named-op min-pooling variant: default dilations and the same
  `pooling_boiler` driver; output[0, 0, 0, 0] is logged.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
          MemRefType.get((1, 2, 4, 1), i32))
      def pooling_on_buffers(input_buf, shape, output):
        # Set the strides and use the default dilations.
        # `input_buf` rather than `input` so the builtin is not shadowed.
        linalg.pooling_nhwc_min(
            input_buf,
            shape,
            outs=[output],
            strides=[2, 4],
            emit_generic=True)

    execution_engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result i32.
    # Arguments must be passed as pointers.
    c_int_p = ctypes.c_int * 1
    res = c_int_p(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])
    # CHECK: RESULT: -13


test_min_pooling_generic()
481