1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import ctypes
4import sys
5from mlir.ir import *
6from mlir.dialects import builtin
7from mlir.dialects import linalg
8from mlir.dialects import std
9from mlir.passmanager import *
10from mlir.execution_engine import *
11
12
13# Log everything to stderr and flush so that we have a unified stream to match
14# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print `args` to stderr and flush the stream right away.

  Writing to the same, eagerly flushed stream that MLIR uses for its own
  errors/info keeps all output in one ordered stream for matching.
  """
  print(*args, file=sys.stderr, flush=True)
18
19
# MLIR source for the `main` driver used by the matmul tests. It fills a 4x16
# buffer with 1.0, a 16x8 buffer with 2.0 and a 4x8 output buffer with 0.0,
# calls `matmul_on_buffers` on them, and returns output element [0, 0] as f32.
# `transform` appends this text to the printed function under test.
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
43
# MLIR source for the `main` driver used by the fill tests. It allocates a
# 4x16 i32 buffer, calls `fill_on_buffers` with min -1000.0, max 1000.0 and
# seed 42, and returns buffer element [0, 0] as i32.
fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
61
# MLIR source for a `main` driver around `conv_on_buffers`. It fills a
# 1x4x16x1 f64 input with 1.0, a 2x2x1 f64 filter with 2.0 and a 1x2x4x1 i32
# output with 0, calls the function, and returns output element [0, 0, 0, 0].
# NOTE(review): no test in the visible part of this file uses this string --
# confirm whether a conv test exists elsewhere or the string is stale.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
85
# MLIR source for the `main` driver shared by the max- and min-pooling tests.
# The 1x4x16x1 f64 input is filled with 1.0 and then 42.0, 77.0 and -13.0 are
# stored at [0, 0, 0, 0], [0, 0, 1, 0] and [0, 0, 2, 0]. The 2x2 `shape`
# operand is filled with 1.0 and the 1x2x4x1 i32 output with 0. After calling
# `pooling_on_buffers`, output element [0, 0, 0, 0] is returned as i32.
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
117
118
def transform(module, boilerplate):
  """Combine the function under test with a driver and run the lowering passes.

  Args:
    module: Module whose first top-level operation is the function to lower.
    boilerplate: MLIR source text appended after that function's printed form
      (in this file, one of the `*_boiler` strings defining `main`).

  Returns:
    The newly parsed module, after the pass pipeline below has run on it.
  """
  # Imported only for their side effects; presumably they make the passes and
  # conversions named in the pipeline available -- confirm against the
  # bindings.
  import mlir.conversions
  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  top_level_ops = module.operation.regions[0].blocks[0].operations
  func_text = str(top_level_ops[0].operation)
  combined = Module.parse(func_text + boilerplate)

  pipeline = (
      "builtin.func(convert-linalg-to-loops, lower-affine, " +
      "convert-scf-to-std, arith-expand, std-expand), convert-vector-to-llvm," +
      "convert-memref-to-llvm, convert-std-to-llvm," +
      "reconcile-unrealized-casts")
  PassManager.parse(pipeline).run(combined)
  return combined
136
137
def test_matmul_builtin():
  """Build, lower and run `linalg.matmul`, then log result element [0, 0]."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    elem_type = F32Type.get()
    lhs_type = MemRefType.get((4, 16), elem_type)
    rhs_type = MemRefType.get((16, 8), elem_type)
    out_type = MemRefType.get((4, 8), elem_type)
    with InsertionPoint(module.body):

      # The function name must stay `matmul_on_buffers`: the driver text in
      # `matmul_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(lhs_type, rhs_type, out_type)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one f32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", result)

    # Each output element is a 16-term dot product of 1.0 * 2.0.
    log("RESULT: ", result[0])
    # CHECK: RESULT: 32.0


test_matmul_builtin()
164
165
def test_matmul_generic():
  """Same as the builtin matmul test, but built with `emit_generic=True`."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    elem_type = F32Type.get()
    lhs_type = MemRefType.get((4, 16), elem_type)
    rhs_type = MemRefType.get((16, 8), elem_type)
    out_type = MemRefType.get((4, 8), elem_type)
    with InsertionPoint(module.body):

      # The function name must stay `matmul_on_buffers`: the driver text in
      # `matmul_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(lhs_type, rhs_type, out_type)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one f32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", result)

    # Each output element is a 16-term dot product of 1.0 * 2.0.
    log("RESULT: ", result[0])
    # CHECK: RESULT: 32.0


test_matmul_generic()
192
193
def test_fill_builtin():
  """Build, lower and run `linalg.fill_rng_2d`, then log element [0, 0]."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    out_type = MemRefType.get((4, 16), i32)
    with InsertionPoint(module.body):

      # The function name must stay `fill_on_buffers`: the driver text in
      # `fill_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(f64, f64, i32, out_type)
      def fill_on_buffers(lower, upper, seed, out):
        linalg.fill_rng_2d(lower, upper, seed, outs=[out])

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: -480


test_fill_builtin()
219
220
def test_fill_generic():
  """Same as the builtin fill test, but built with `emit_generic=True`."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    out_type = MemRefType.get((4, 16), i32)
    with InsertionPoint(module.body):

      # The function name must stay `fill_on_buffers`: the driver text in
      # `fill_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(f64, f64, i32, out_type)
      def fill_on_buffers(lower, upper, seed, out):
        linalg.fill_rng_2d(lower, upper, seed, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: -480


test_fill_generic()
246
247
def test_max_pooling_builtin():
  """Build, lower and run `linalg.pooling_nhwc_max`, then log the result."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_type = MemRefType.get((1, 4, 16, 1), f64)
    window_type = MemRefType.get((2, 2), f64)
    output_type = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      # The function name must stay `pooling_on_buffers`: the driver text in
      # `pooling_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(input_type, window_type, output_type)
      def pooling_on_buffers(image, shape, output):
        linalg.pooling_nhwc_max(
            image, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()
277
278
def test_max_pooling_generic():
  """Same as the builtin max-pooling test, but with `emit_generic=True`."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_type = MemRefType.get((1, 4, 16, 1), f64)
    window_type = MemRefType.get((2, 2), f64)
    output_type = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      # The function name must stay `pooling_on_buffers`: the driver text in
      # `pooling_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(input_type, window_type, output_type)
      def pooling_on_buffers(image, shape, output):
        linalg.pooling_nhwc_max(
            image, shape, outs=[output], strides=[2, 4], dilations=[1, 2],
            emit_generic=True)

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()
313
314
def test_min_pooling_builtin():
  """Build, lower and run `linalg.pooling_nhwc_min`, then log the result."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_type = MemRefType.get((1, 4, 16, 1), f64)
    window_type = MemRefType.get((2, 2), f64)
    output_type = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      # The function name must stay `pooling_on_buffers`: the driver text in
      # `pooling_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(input_type, window_type, output_type)
      def pooling_on_buffers(image, shape, output):
        linalg.pooling_nhwc_min(
            image, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: -13


test_min_pooling_builtin()
343
344
def test_min_pooling_generic():
  """Same as the builtin min-pooling test, but with `emit_generic=True`."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_type = MemRefType.get((1, 4, 16, 1), f64)
    window_type = MemRefType.get((2, 2), f64)
    output_type = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      # The function name must stay `pooling_on_buffers`: the driver text in
      # `pooling_boiler` calls it by that symbol.
      @builtin.FuncOp.from_py_func(input_type, window_type, output_type)
      def pooling_on_buffers(image, shape, output):
        linalg.pooling_nhwc_min(
            image, shape, outs=[output], strides=[2, 4], dilations=[1, 2],
            emit_generic=True)

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # `main` yields one i32. Arguments must be passed as pointers, so a
    # one-element ctypes array (seeded with -1) receives the result.
    result = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result)

    log("RESULT: ", result[0])
    # CHECK: RESULT: -13


test_min_pooling_generic()
378