1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import ctypes
4import sys
5from mlir.ir import *
6from mlir.dialects import builtin
7from mlir.dialects import linalg
8from mlir.dialects import std
9from mlir.passmanager import *
10from mlir.execution_engine import *
11
12
# Everything is reported on stderr, flushed on every call, so that this
# script's messages and the errors/info MLIR emits on stderr end up in one
# ordered stream for the output matcher.
def log(*args):
  """Print `args` space-separated on stderr and flush immediately."""
  print(*args, file=sys.stderr, flush=True)
18
19
# MLIR driver text that `transform` appends to the printed
# `@matmul_on_buffers` function. `@main` fills A (4x16) with 1.0, B (16x8)
# with 2.0 and C (4x8) with 0.0, calls `@matmul_on_buffers(A, B, C)` and
# returns element C[0, 0] as an f32.
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0.0 : f32
  %v1 = constant 1.0 : f32
  %v2 = constant 2.0 : f32

  %A = memref.alloc() : memref<4x16xf32>
  %B = memref.alloc() : memref<16x8xf32>
  %C = memref.alloc() : memref<4x8xf32>
  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
  linalg.fill(%v0, %C) : f32, memref<4x8xf32>

  call @matmul_on_buffers(%A, %B, %C) :
    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
43
# MLIR driver text for the fill tests. `@main` allocates a 4x16 i32 buffer,
# calls `@fill_on_buffers` with min = -1000.0, max = 1000.0 and seed = 42,
# and returns element [0, 0] of that buffer as an i32.
fill_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = constant -1000.0 : f64
  %max = constant 1000.0 : f64
  %seed = constant 42 : i32

  call @fill_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
61
# MLIR driver text around `@conv_on_buffers`. `@main` fills a 1x4x16x1 f64
# input with 1.0, a 2x2x1 f64 filter with 2.0 and a 1x2x4x1 i32 output with 0,
# calls `@conv_on_buffers` and returns output[0, 0, 0, 0] as an i32.
# NOTE(review): no test in the visible part of this file references this
# string -- confirm whether the convolution tests live elsewhere.
conv_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0 : i32
  %v1 = constant 1.0 : f64
  %v2 = constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
85
# MLIR driver text for the pooling tests. The 1x4x16x1 f64 input is filled
# with 1.0 and then 42.0, 77.0 and -13.0 are stored at [0, 0, 0], [0, 0, 1]
# and [0, 0, 2] (last index 0); the 2x2 shape operand is filled with 1.0 and
# the 1x2x4x1 i32 output with 0. `@main` calls `@pooling_on_buffers` and
# returns output[0, 0, 0, 0] as an i32.
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = constant 0 : i32
  %v42 = constant 42.0 : f64
  %v77 = constant 77.0 : f64
  %v-13 = constant -13.0 : f64
  %v1 = constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
117
118
def transform(module, boilerplate):
  """Combine the first op of `module` with `boilerplate` and lower the result.

  The first operation in the body of `module` is printed to text, the
  `boilerplate` MLIR source is appended, and the concatenation is parsed into
  a fresh module. That module is then run through a fixed pass pipeline
  ending in the LLVM conversions and returned.
  """
  # NOTE(review): these imports are not referenced by name below; presumably
  # they are needed so that the passes named in the pipeline are registered --
  # confirm before removing.
  import mlir.conversions
  import mlir.dialects.linalg.passes
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
  # Atm we have to resort to string concatenation.
  first_op = module.operation.regions[0].blocks[0].operations[0].operation
  combined = Module.parse(str(first_op) + boilerplate)

  pipeline = ("builtin.func(convert-linalg-to-loops, lower-affine, "
              "convert-scf-to-std), convert-vector-to-llvm,"
              "convert-memref-to-llvm,convert-std-to-llvm")
  PassManager.parse(pipeline).run(combined)
  return combined
135
136
def test_matmul_builtin():
  """Build `matmul_on_buffers` with linalg.matmul, run it, log C[0, 0]."""
  with Context(), Location.unknown():
    module = Module.create()
    elem_ty = F32Type.get()
    lhs_ty = MemRefType.get((4, 16), elem_ty)
    rhs_ty = MemRefType.get((16, 8), elem_ty)
    acc_ty = MemRefType.get((4, 8), elem_ty)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(lhs_ty, rhs_ty, acc_ty)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out])

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single f32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: 32.0


test_matmul_builtin()
163
164
def test_matmul_generic():
  """Same as the builtin matmul test, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    elem_ty = F32Type.get()
    lhs_ty = MemRefType.get((4, 16), elem_ty)
    rhs_ty = MemRefType.get((16, 8), elem_ty)
    acc_ty = MemRefType.get((4, 8), elem_ty)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(lhs_ty, rhs_ty, acc_ty)
      def matmul_on_buffers(lhs, rhs, out):
        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, matmul_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single f32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_float * 1)(-1.)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: 32.0


test_matmul_generic()
191
192
def test_fill_builtin():
  """Build `fill_on_buffers` with linalg.fill_rng_2d, run it, log O[0, 0]."""
  with Context(), Location.unknown():
    module = Module.create()
    bound_ty = F64Type.get()
    int_ty = IntegerType.get_signless(32)
    out_ty = MemRefType.get((4, 16), int_ty)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(bound_ty, bound_ty, int_ty, out_ty)
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out])

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: -480


test_fill_builtin()
218
219
def test_fill_generic():
  """Same as the builtin fill test, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    bound_ty = F64Type.get()
    int_ty = IntegerType.get_signless(32)
    out_ty = MemRefType.get((4, 16), int_ty)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(bound_ty, bound_ty, int_ty, out_ty)
      def fill_on_buffers(min, max, seed, out):
        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

    engine = ExecutionEngine(transform(module, fill_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: -480


test_fill_generic()
245
246
def test_max_pooling_builtin():
  """Build `pooling_on_buffers` with linalg.pooling_nhwc_max and run it."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_ty = MemRefType.get((1, 4, 16, 1), f64)
    window_ty = MemRefType.get((2, 2), f64)
    output_ty = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(input_ty, window_ty, output_ty)
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_builtin()
276
277
def test_max_pooling_generic():
  """Same as the builtin max-pooling test, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_ty = MemRefType.get((1, 4, 16, 1), f64)
    window_ty = MemRefType.get((2, 2), f64)
    output_ty = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(input_ty, window_ty, output_ty)
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_max(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # 77 is not selected due to the dilation 2 in the second dimension.
    # CHECK: RESULT: 42


test_max_pooling_generic()
312
313
def test_min_pooling_builtin():
  """Build `pooling_on_buffers` with linalg.pooling_nhwc_min and run it."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_ty = MemRefType.get((1, 4, 16, 1), f64)
    window_ty = MemRefType.get((2, 2), f64)
    output_ty = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(input_ty, window_ty, output_ty)
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: -13


test_min_pooling_builtin()
342
343
def test_min_pooling_generic():
  """Same as the builtin min-pooling test, but with `emit_generic=True`."""
  with Context(), Location.unknown():
    module = Module.create()
    f64 = F64Type.get()
    i32 = IntegerType.get_signless(32)
    input_ty = MemRefType.get((1, 4, 16, 1), f64)
    window_ty = MemRefType.get((2, 2), f64)
    output_ty = MemRefType.get((1, 2, 4, 1), i32)
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(input_ty, window_ty, output_ty)
      def pooling_on_buffers(input, shape, output):
        linalg.pooling_nhwc_min(
            input,
            shape,
            outs=[output],
            strides=[2, 4],
            dilations=[1, 2],
            emit_generic=True)

    engine = ExecutionEngine(transform(module, pooling_boiler))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # The single i32 result is handed over as a pointer: a one-element ctypes
    # array pre-set to -1.
    result_slot = (ctypes.c_int * 1)(-1)
    engine.invoke("main", result_slot)

    log("RESULT: ", result_slot[0])
    # CHECK: RESULT: -13


test_min_pooling_generic()
377