1// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
2
3// CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
4// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
5// CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
6// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
7// CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
8// CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
9
// `linalg.dot` tagged with the "MEM" transformation marker. The directives
// below expect, in this order: an scf.for starting at 0 with step 8000, an
// scf.for starting at 0 with step 1, and in the innermost body three loads,
// an arith.mulf, an arith.addf and a store (the scalar multiply-accumulate).
// NOTE(review): presumably 8000 is the tile size the test pass associates with
// "MEM" and the tiled op is then rewritten into loops — confirm in the pass.
func.func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
          %y: memref<?xf32, offset: ?, strides: [1]>,
          %v: memref<f32>) {
  linalg.dot { __internal_linalg_transform__ = "MEM" }
    ins(%x, %y: memref<?xf32, offset: ?, strides: [1]>,
                memref<?xf32, offset: ?, strides: [1]>)
    outs(%v: memref<f32>)

  return
}
// CHECK-LABEL: func @dot
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[c8000:.*]] = arith.constant 8000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
// CHECK:               load
// CHECK:               load
// CHECK:               load
// CHECK:               arith.mulf
// CHECK:               arith.addf
// CHECK:               store
32
// `linalg.matvec` with no transformation marker attribute. The directives
// below expect, in this order: an scf.parallel with step 5, an scf.for with
// step 6, and a linalg.matvec whose operand types carry the strided layout
// maps captured at the top of the file (2-D with unit inner stride, and 1-D).
// NOTE(review): presumably the test pass tiles unmarked matvec ops by (5, 6)
// with a parallel outer loop — confirm in the pass.
func.func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %x: memref<?xf32, offset: ?, strides: [1]>,
             %y: memref<?xf32, offset: ?, strides: [1]>) {
  linalg.matvec
    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?xf32, offset: ?, strides: [1]>)
    outs(%y: memref<?xf32, offset: ?, strides: [1]>)
  return
}
// CHECK-LABEL: func @matvec
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
// CHECK:         scf.parallel {{.*}} step (%[[c5]])
// CHECK:           scf.for {{.*}} step %[[c6]]
// CHECK:             linalg.matvec
// CHECK:               ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?xf32, #[[$STRIDED_1D]]>)
// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
51
// `linalg.matmul` tagged with the "MEM" transformation marker. The directives
// below expect twelve scf.for loops, all starting at 0, appearing in this
// order with steps 2000, 3000, 4000, 200, 300, 400, 20, 30, 40, 2, 3, 4,
// followed by a linalg.matmul whose operands all carry the 2-D strided layout
// (unit inner stride) captured at the top of the file.
// NOTE(review): the step pattern reads like four successive tiling levels of
// three loops each; the exact tile sizes and level chaining are configured in
// the test pass, which is not visible here — confirm there.
func.func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.matmul { __internal_linalg_transform__ = "MEM" }
    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?x?xf32, offset: ?, strides: [?, 1]>)
    outs(%C: memref<?x?xf32, offset: ?, strides: [?, 1]>)
  return
}
// CHECK-LABEL: func @matmul
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c2:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[c3:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK:                           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
// CHECK:                             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
// CHECK:                               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
// CHECK:                                 linalg.matmul
// CHECK:                                   ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
// CHECK:                                  outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
90
// Indexing maps expressing a matmul as a linalg.generic over the iteration
// space (m, n, k): the first input is read at (m, k), the second at (k, n),
// and the output at (m, n).
#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
// Generic-op trait for the matmul above: m and n are parallel, k is the
// reduction dimension. Operand counts come from the op's `ins`/`outs` lists,
// so no separate input/output count entries are needed here.
#generic_matmul_trait = {
  indexing_maps = #matmul_accesses,
  library_call = "linalg_matmul",
  iterator_types = ["parallel", "parallel", "reduction"]
}
// linalg.generic matmul (C += A * B in the region) with no transformation
// marker attribute. The directives below expect the op to come back with its
// loops permuted. If the new dims (d0, d1, d2) stand for the original
// (n, k, m), then:
//   - the maps become (d2, d1), (d1, d0) and (d2, d0);
//   - the iterator types become [parallel, reduction, parallel];
//   - library_call is preserved.
// NOTE(review): the permutation itself is configured in the test pass, which
// is not visible here — confirm there.
func.func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
           %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.generic #generic_matmul_trait
    ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
                 memref<?x?xf32, offset: ?, strides: [?, 1]>)
   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
    ^bb(%a: f32, %b: f32, %c: f32):
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e: f32
  }
  return
}
// CHECK-LABEL:  func @permute_generic
// CHECK:        linalg.generic {
// CHECK-SAME:   indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"],
// CHECK-SAME:   library_call = "linalg_matmul"}
// CHECK:          memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
125
// `linalg.matvec` tagged with the "__with_perm__" marker. It uses the same step
// sizes (5 and 6) as the unmarked @matvec case. Unlike that case, the
// directives here expect two sequential loops in swapped order:
//   - an scf.for from 0 with step 6, then
//   - an scf.for from 0 with step 5,
//   - then a linalg.matvec on strided-layout operands.
// NOTE(review): the interchange is configured in the test pass — confirm there.
func.func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %x: memref<?xf32, offset: ?, strides: [1]>,
             %y: memref<?xf32, offset: ?, strides: [1]>) {
  linalg.matvec {__internal_linalg_transform__ = "__with_perm__"}
    ins(%A, %x: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?xf32, offset: ?, strides: [1]>)
   outs(%y: memref<?xf32, offset: ?, strides: [1]>)
  return
}
// CHECK-LABEL: func @matvec_perm
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG:     %[[c6:.*]] = arith.constant 6 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
// CHECK:             linalg.matvec
// CHECK:               ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?xf32, #[[$STRIDED_1D]]>)
// CHECK:              outs({{.*}}: memref<?xf32, #[[$STRIDED_1D]]>)
144
// `linalg.matmul` tagged with the "__with_perm__" marker. The directives below
// expect nine scf.for loops, all starting at 0, in this order with steps
// 3000, 4000, 2000, 300, 200, 400, 20, 30, 40, followed by a linalg.matmul on
// strided-layout operands. Compared with @matmul, the first two levels are
// reordered. No step-2/3/4 level is looked for here.
// NOTE(review): the per-level permutations are configured in the test pass,
// which is not visible here — confirm there.
func.func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
  linalg.matmul {__internal_linalg_transform__ = "__with_perm__"}
    ins(%A, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                memref<?x?xf32, offset: ?, strides: [?, 1]>)
   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>)
  return
}
// CHECK-LABEL: func @matmul_perm
// CHECK-DAG:     %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[c20:.*]] = arith.constant 20 : index
// CHECK-DAG:     %[[c30:.*]] = arith.constant 30 : index
// CHECK-DAG:     %[[c40:.*]] = arith.constant 40 : index
// CHECK-DAG:     %[[c200:.*]] = arith.constant 200 : index
// CHECK-DAG:     %[[c300:.*]] = arith.constant 300 : index
// CHECK-DAG:     %[[c400:.*]] = arith.constant 400 : index
// CHECK-DAG:     %[[c2000:.*]] = arith.constant 2000 : index
// CHECK-DAG:     %[[c3000:.*]] = arith.constant 3000 : index
// CHECK-DAG:     %[[c4000:.*]] = arith.constant 4000 : index
// CHECK:         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
// CHECK:           scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
// CHECK:             scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
// CHECK:               scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
// CHECK:                 scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
// CHECK:                   scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
// CHECK:                     scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK:                       scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK:                         scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK:                                 linalg.matmul
// CHECK:                                  ins({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>, memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
// CHECK:                                   outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>)
177
// Matmul on fully dynamic, identity-layout memrefs tagged with the
// "par__with_perm__" marker. The directives below expect, in this order:
//   - an scf.parallel over [0, dim 1 of %arg1) with step 8,
//   - an scf.for over [0, dim 1 of %arg0) with step 4,
//   - an scf.parallel over [0, dim 0 of %arg0) with step 16,
// i.e. the N, K and M extents of the matmul, visited in that permuted order.
// The memref.dim index operands are matched through captured constants
// instead of the printer-chosen SSA names %c0 / %c1, so the test checks
// structure rather than value naming.
// NOTE(review): tile sizes and the permutation are configured in the test
// pass, which is not visible here — confirm there.
func.func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                 %arg1: memref<?x?xf32>,
                                 %arg2: memref<?x?xf32>) {
  linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"}
    ins(%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
   outs(%arg2: memref<?x?xf32>)
  return
}
// CHECK-LABEL: func @tile_permute_parallel_loop
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
//   CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
//   CHECK-DAG:   %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:   %[[D2:.*]] = memref.dim %[[ARG1]], %[[C1]]
//       CHECK:   scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D2]]) step (%[[C8]])
//       CHECK:     scf.for %{{.*}} = %[[C0]] to %[[D1]] step %[[C4]]
//       CHECK:       scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D0]]) step (%[[C16]])
200