1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import builtin
5from mlir.dialects import func
6from mlir.dialects import linalg
7
8from mlir.dialects.linalg.opdsl.lang import *
9
# Element-type variables for pooling_poly's input (T1) and kernel (T2)
# operands; binding them separately lets the op accept mixed element types.
T1, T2 = TV.T1, TV.T2
12
13
@linalg_structured_op
def pooling_poly(
    I=TensorDef(T1, S.N, S.H, S.W, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Polymorphic 2-D NHWC pooling with a configurable reduction and cast.

  Each output element O[n, oh, ow, c] is the `reduce` (over kh, kw) of the
  input window elements converted to the output type U with `cast`. The
  reduction defaults to `BinaryFn.max_signed` and the cast to
  `TypeFn.cast_signed`; strides and dilations default to [1, 1].

  K is never read by the body; it is indexed only through its
  `index_dims=[D.kh, D.kw]`, so it presumably just supplies the window
  extents (KH, KW) -- confirm against the OpDSL docs.
  """
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
  # Input row/column: output position scaled by the stride plus the window
  # offset scaled by the dilation.
  in_row = D.oh * S.SH + D.kh * S.DH
  in_col = D.ow * S.SW + D.kw * S.DW
  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
      cast(U, I[D.n, in_row, in_col, D.c]))
27
28
# Build one function per pooling variant and print the module so FileCheck
# can verify the emitted linalg.generic ops. Every variant calls pooling_poly
# with strides [2, 4] and dilations [1, 2]; they differ only in element types
# and in the `reduce` / `cast` attributes.
with Context(), Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  i32 = IntegerType.get_signless(32)
  with InsertionPoint(module.body):

    # Pooling indexing maps. Strides [2, 4] and dilations [1, 2] fold into
    # the input map as (oh * 2 + kh, ow * 4 + kw * 2).
    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>

    # The body checks reference the captured block arguments (%[[IN]],
    # %[[OUT]]) and the cast result (%[[IN_CAST]]) instead of re-capturing
    # them, so they pin which values each op actually consumes.
    # CHECK-LABEL: @test_f32i32_max_pooling
    # CHECK: linalg.generic
    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
    # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN]] : f32 to i32
    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST]] : i32
    # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
    # CHECK-NEXT: -> tensor<1x2x4x1xi32>
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), i32))
    def test_f32i32_max_pooling(input, shape, init_result):
      return pooling_poly(
          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

    # Unsigned reduce + unsigned cast select the unsigned arith ops.
    # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
    # CHECK:   = arith.fptoui
    # CHECK:   = arith.maxui
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), i32))
    def test_f32i32_max_unsigned_pooling(input, shape, init_result):
      return pooling_poly(
          input,
          shape,
          outs=[init_result],
          reduce=BinaryFn.max_unsigned,
          cast=TypeFn.cast_unsigned,
          strides=[2, 4],
          dilations=[1, 2])

    # With matching f32 input/output types no cast op is expected before the
    # reduction.
    # CHECK-LABEL: @test_f32f32_max_pooling
    # CHECK: linalg.generic
    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN]] : f32
    # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
    # CHECK-NEXT: -> tensor<1x2x4x1xf32>
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), f32))
    def test_f32f32_max_pooling(input, shape, init_result):
      return pooling_poly(
          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

    # Signed min reduction with the default (signed) cast.
    # CHECK-LABEL: @test_f32i32_min_pooling
    # CHECK:   = arith.fptosi
    # CHECK:   = arith.minsi
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), i32))
    def test_f32i32_min_pooling(input, shape, init_result):
      return pooling_poly(
          input,
          shape,
          outs=[init_result],
          reduce=BinaryFn.min_signed,
          strides=[2, 4],
          dilations=[1, 2])

    # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
    # CHECK:   = arith.fptoui
    # CHECK:   = arith.minui
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), i32))
    def test_f32i32_min_unsigned_pooling(input, shape, init_result):
      return pooling_poly(
          input,
          shape,
          outs=[init_result],
          reduce=BinaryFn.min_unsigned,
          cast=TypeFn.cast_unsigned,
          strides=[2, 4],
          dilations=[1, 2])

    # On float operands, min_signed is expected to emit arith.minf.
    # CHECK-LABEL: @test_f32f32_min_pooling
    # CHECK:   = arith.minf
    @func.FuncOp.from_py_func(
        RankedTensorType.get((1, 4, 16, 1), f32),
        RankedTensorType.get((2, 2), f32),
        RankedTensorType.get((1, 2, 4, 1), f32))
    def test_f32f32_min_pooling(input, shape, init_result):
      return pooling_poly(
          input,
          shape,
          outs=[init_result],
          reduce=BinaryFn.min_signed,
          strides=[2, 4],
          dilations=[1, 2])


print(module)
140