1// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
2// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CANON
3
// Transform script for the first test section: matches the linalg.generic ops
// nested under the root handle and splits them along dimension 0 after
// index 42. The next split-input-file separator comes only after both
// @one_d_static and @one_d_static_overflow, so this script covers both.
4transform.with_pdl_patterns {
5^bb0(%arg0: !pdl.operation):
6  transform.sequence %arg0 {
7  ^bb1(%arg1: !pdl.operation):
8    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    // The split yields two result handles; neither is used further here.
9    %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
10  }
11}
12
// Declaration only: callee invoked from the bodies of the payload
// linalg.generic ops in this section (one f32 and two index operands).
13func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
14
15// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
16// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>
17
// Static split with the split point (42) inside the 100-element dimension:
// the directives below describe a 42-element low part followed by a
// 58-element high part.
18// CHECK-LABEL: @one_d_static
19// CHECK-SAME:  %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
20func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // Low part: slices [0, 42) of the input and output, result inserted back
  // into the original output.
21  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
22  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
23  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
24  // CHECK:   ins(%[[IN_SLICE_LOW]]
25  // CHECK:   outs(%[[OUT_SLICE_LOW]]
26  // CHECK:   linalg.index 0
27  // CHECK:   func.call @elem
28  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
29  //
  // High part: reads the input at offset 42, writes into the partial result
  // of the low part, and shifts the loop index by 42 before the call.
30  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
31  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
32  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
33  // CHECK:   ins(%[[IN_SLICE_HIGH]]
34  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
35  // CHECK:   %[[IDX:.+]] = linalg.index 0
36  // CHECK:   affine.apply #[[$ADD_42_MAP]](%[[IDX]])
37  // CHECK:   func.call @elem
38  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
39  %0 = linalg.generic {
40    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
41    iterator_types = ["parallel"]
42  }
43  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
44  ^bb0(%0: f32, %1: f32):
    // The body uses linalg.index, which is why the high part needs the
    // index offset verified above.
45    %i = linalg.index 0 : index
46    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
47    linalg.yield %call_res : f32
48  } -> tensor<100xf32>
49
50  // CHECK: return %[[RES]]
51  return %0 : tensor<100xf32>
52}
53
// Static split where the requested split point (42, from the transform script
// of this section) exceeds the 10-element dimension: the directives describe
// a low part of size 10 and a high part of size 0. This function is verified
// both by the default run and by the canonicalization run.
54// CHECK-LABEL: @one_d_static_overflow
55// CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
56// CANON-LABEL:  @one_d_static_overflow
57// CANON-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
58func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
  // Default run, low part: full-size slices [0, 10).
59  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
60  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
61  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
62  // CHECK:   ins(%[[IN_SLICE_LOW]]
63  // CHECK:   outs(%[[OUT_SLICE_LOW]]
64  // CHECK:   linalg.index 0
65  // CHECK:   func.call @elem
66  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
67  //
68  // Due to overflow, the first part of the split computes everything and the
69  // insert/extract slices are folded away by the canonicalizer.
70  // CANON: %[[RES_PARTIAL:.+]] = linalg.generic
71  // CANON:   ins(%[[IN]]
72  // CANON:   outs(%[[OUT]]
73  // CANON:   linalg.index 0
74  // CANON:   func.call @elem
75  // The second part operates on zero-sized slices that are not currently
76  // folded away.
77  //
  // Default run, high part: zero-sized slices at offset 10, index shifted by 10.
78  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
79  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
80  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
81  // CHECK:   ins(%[[IN_SLICE_HIGH]]
82  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
83  // CHECK:   %[[IDX:.+]] = linalg.index 0
84  // CHECK:   affine.apply #[[$ADD_10_MAP]](%[[IDX]])
85  // CHECK:   func.call @elem
86  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
87  %0 = linalg.generic {
88    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
89    iterator_types = ["parallel"]
90  }
91  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
92  ^bb0(%0: f32, %1: f32):
93    %i = linalg.index 0 : index
94    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
95    linalg.yield %call_res : f32
96  } -> tensor<10xf32>
97  return %0 : tensor<10xf32>
98}
99
100// -----
101
// Transform script for the dynamic-split section: the split point is not a
// static number but a handle to the matched func.call ops in the payload.
102transform.with_pdl_patterns {
103^bb0(%arg0: !pdl.operation):
104  transform.sequence %arg0 {
105  ^bb1(%arg1: !pdl.operation):
106    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
107    %1 = transform.structured.match ops{["func.call"]} in %arg1
    // %1 supplies the split point for the targets in %0.
108    transform.structured.split %0 after %1 { dimension = 0 }
109  }
110}
111
// Declaration only: returns the index value that the payload function of this
// section uses as the dynamic split point.
112func.func private @get_size() -> index
113
114// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
115// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
116
// Dynamic split point: the split size is the result of the call to @get_size.
// The directives below verify that the low-part size is the affine.min of that
// value and the static dimension size (100), and that all slices have a
// dynamic size.
117// CHECK-LABEL: @dynamic
118func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
119  // CHECK: %[[SPLIT:.+]] = call @get_size
120  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
  // The input and output values are captured on their first use here, since
  // the function signature is not matched by a directive.
121  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
122  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
123  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
124  // CHECK:   ins(%[[IN_SLICE_LOW]]
125  // CHECK:   outs(%[[OUT_SLICE_LOW]]
126  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
127  //
  // The high part must read from the same input as the low part and write
  // into the partial result produced by the low part, so the earlier captures
  // are referenced below instead of being captured again (a fresh capture
  // would accept any value and verify nothing).
128  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
129  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
130  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
131  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
132  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
133  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
134  // CHECK:   ins(%[[IN_SLICE_HIGH]]
135  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
136  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
  // The call result is not used by the payload itself; the transform script
  // of this section picks up the func.call op as the split point.
137  %0 = func.call @get_size() : () -> index
138  %1 = linalg.generic {
139    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
140    iterator_types = ["parallel"]
141  }
142  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
143  ^bb0(%3: f32, %4: f32):
144    %5 = arith.addf %3, %4 : f32
145    linalg.yield %5 : f32
146  } -> tensor<100xf32>
147  return %1 : tensor<100xf32>
148}
149
150// -----
151
// Transform script for the two-dimensional section: splits the matched
// linalg.generic ops along dimension 0 after index 4, then splits only the
// second result of that split along dimension 1 after index 16.
152transform.with_pdl_patterns {
153^bb0(%arg0: !pdl.operation):
154  transform.sequence %arg0 {
155  ^bb1(%arg1: !pdl.operation):
156    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
157    %1:2 = transform.structured.split %0 after 4 { dimension = 0}
    // %1#0 is left as is; only %1#1 is split again.
158    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
159  }
160}
161
// Declaration only: callee invoked from the body of the payload
// linalg.generic op in this section with the element and both loop indices.
162func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
163
// Nested split of a 10x34 op. The directives describe three resulting
// linalg.generic ops of shapes 4x34, 6x16 and 6x18, and the chain of
// tensor.insert_slice ops that reassembles their results.
164// CHECK-LABEL: @two_d
165func.func @two_d(%arg0: tensor<10x34xf32>,
166                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
167  // Check the overall structure: split along the dimension 0, and then split
168  // the second half only along the dimension 1.
169  // CHECK:      %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
170  // CHECK:      %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
171  // CHECK:      %[[RES_1:.+]] = linalg.generic
172  // CHECK-SAME:   ins(%[[IN_1]] : tensor<4x34xf32>)
173  // CHECK-SAME:   outs(%[[OUT_1]] : tensor<4x34xf32>)
174  // CHECK:      %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
175  //
  // Second half of the first split: sliced once more from its own slices.
176  // CHECK:      %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
177  // CHECK:      %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
178  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
179  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
180  // CHECK:      %[[RES_21:.+]] = linalg.generic
181  // CHECK-SAME:   ins(%[[IN_21]] : tensor<6x16xf32>)
182  // CHECK-SAME:   outs(%[[OUT_21]] : tensor<6x16xf32>)
183  // CHECK:      %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
184  //
185  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
186  // CHECK:      %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
187  // CHECK:      %[[RES_22:.+]] = linalg.generic
188  // CHECK-SAME:   ins(%[[IN_22]] : tensor<6x18xf32>)
189  // CHECK-SAME:   outs(%[[OUT_22]] : tensor<6x18xf32>)
  // The inner result is inserted into the inner partial result first, and
  // that value is then inserted into the partial result of the outer split.
190  // CHECK:      %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
191  // CHECK:      %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
192  %0 = linalg.generic {
193    indexing_maps = [affine_map<(i, j) -> (i, j)>,
194                     affine_map<(i, j) -> (i, j)>],
195    iterator_types = ["parallel", "parallel"]
196  }
197  ins(%arg0: tensor<10x34xf32>)
198  outs(%arg1: tensor<10x34xf32>) {
199  ^bb0(%0: f32, %1: f32):
200    %i = linalg.index 0 : index
201    %j = linalg.index 1 : index
202    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
203    linalg.yield %call_res : f32
204  } -> tensor<10x34xf32>
205  return %0 : tensor<10x34xf32>
206}
207
208// -----
209
// Negative test: the split op is written in generic form so that it can carry
// a static_split_point of -1 while having no dynamic split point operand; the
// diagnostic annotation below states the error this must produce.
// NOTE(review): -1 is presumably the sentinel meaning "dynamic" - confirm
// against the op definition.
210transform.sequence {
211^bb1(%arg1: !pdl.operation):
212  // expected-error @below {{expects either a dynamic or a static split point to be provided}}
213  %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
214}
215
216// -----
217
// Negative test script: uses the matched func.call ops as the dynamic split
// point. In this section the call returns i64, and the diagnostic annotation
// below states the error that is reported for the split op.
218transform.with_pdl_patterns {
219^bb0(%arg0: !pdl.operation):
220  transform.sequence %arg0 {
221  ^bb1(%arg1: !pdl.operation):
222    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
223    %1 = transform.structured.match ops{["func.call"]} in %arg1
224    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
225    transform.structured.split %0 after %1 { dimension = 0 }
226  }
227}
228
// Declaration only: unlike the positive dynamic test, this variant returns
// i64 rather than index.
229func.func private @get_size() -> i64
230
// Payload for the non-index split point test: the only func.call in this
// function produces an i64 value, and the note annotation below is attached
// to that call.
231func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
232  // expected-note @below {{dynamic split point}}
233  %0 = func.call @get_size() : () -> i64
234  %1 = linalg.generic {
235    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
236    iterator_types = ["parallel"]
237  }
238  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
239  ^bb0(%3: f32, %4: f32):
    // The body just forwards the input element.
240    linalg.yield %3 : f32
241  } -> tensor<100xf32>
242  return %1 : tensor<100xf32>
243}
244
245// -----
246
// Negative test script: the payload of this section contains one
// linalg.generic op and no func.call op, so the split point handle and the
// target handle refer to different numbers of ops (0 versus 1, as stated in
// the diagnostic annotation below).
247transform.with_pdl_patterns {
248^bb0(%arg0: !pdl.operation):
249  transform.sequence %arg0 {
250  ^bb1(%arg1: !pdl.operation):
251    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
252    %1 = transform.structured.match ops{["func.call"]} in %arg1
253    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
254    transform.structured.split %0 after %1 { dimension = 0 }
255  }
256}
257
// Declaration only: nothing in this section calls it.
258func.func private @get_size() -> i64
259
// Payload for the handle-count mismatch test: a single linalg.generic op and
// deliberately no func.call op.
260func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
261  %1 = linalg.generic {
262    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
263    iterator_types = ["parallel"]
264  }
265  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
266  ^bb0(%3: f32, %4: f32):
    // The body just forwards the input element.
267    linalg.yield %3 : f32
268  } -> tensor<100xf32>
269  return %1 : tensor<100xf32>
270}
271
272// -----
273
// Negative test script: applies the split to func.return ops, which are not
// structured ops; the diagnostic annotation below states the reported error.
274transform.with_pdl_patterns {
275^bb0(%arg0: !pdl.operation):
  // This pattern describes func.return ops, but the sequence below does not
  // refer to it by name; the match is done with transform.structured.match.
276  pdl.pattern @func_return : benefit(1) {
277    %0 = pdl.operands
278    %1 = pdl.types
279    %2 = pdl.operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
280    pdl.rewrite %2 with "transform.dialect"
281  }
282
283  transform.sequence %arg0 {
284  ^bb1(%arg1: !pdl.operation):
285    %0 = transform.structured.match ops{["func.return"]} in %arg1
286    // expected-error @below {{only applies to structured ops}}
287    transform.structured.split %0 after 16 { dimension = 1 }
288  }
289}
290
// Payload for the non-structured target test: the function only returns its
// first argument, and the note annotation below is attached to that return.
291func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
292  // expected-note @below {{target op}}
293  return %arg0 : tensor<100xf32>
294}
295
296// -----
297
// Negative test script: requests a split along dimension 1 of the matched
// linalg.generic ops. The payload op of this section has a single iterator,
// and the diagnostic annotation below states the reported error.
298transform.with_pdl_patterns {
299^bb0(%arg0: !pdl.operation):
  // This pattern describes linalg.generic ops, but the sequence below does
  // not refer to it by name; the match is done with
  // transform.structured.match.
300  pdl.pattern @linalg_generic : benefit(1) {
301    %0 = pdl.operands
302    %1 = pdl.types
303    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
304    pdl.rewrite %2 with "transform.dialect"
305  }
306
307  transform.sequence %arg0 {
308  ^bb1(%arg1: !pdl.operation):
309    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
310    // expected-error @below {{dimension 1 does not exist in target op}}
311    transform.structured.split %0 after 16 { dimension = 1 }
312  }
313}
314
// Payload for the missing-dimension test: a linalg.generic op with exactly
// one (parallel) iterator, so only dimension 0 exists. The note annotation
// below is attached to that op.
315func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
316  // expected-note @below {{target op}}
317  %0 = linalg.generic {
318    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
319    iterator_types = ["parallel"]
320  }
321  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
322  ^bb0(%0: f32, %1: f32):
    // The body just forwards the input element.
323    linalg.yield %0 : f32
324  } -> tensor<100xf32>
325  return %0 : tensor<100xf32>
326}
327
328