// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s

// -----

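// Contraction ops are vectorized to an elementwise multiply over the full
// iteration space followed by a vector.multi_reduction over the reduction
// dimension(s), as checked below.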
// CHECK-LABEL: contraction_dot
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {

// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [0] : vector<1584xf32> to f32
  linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
            outs(%C: memref<f32>)
  return
}

// -----

// CHECK-LABEL: contraction_matvec
func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {

// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
  linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
            outs(%C: memref<1584xf32>)
  return
}

// -----

// CHECK-LABEL: contraction_matmul
func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
  linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
            outs(%C: memref<1584x1584xf32>)
  return
}

// -----

// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
  linalg.batch_matmul
    ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
   outs(%C: memref<1584x1584x1584xf32>)
  return
}

// -----

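// Same contraction written as a linalg.generic: each operand is read into the
// full <8x32x16> iteration-space vector according to its indexing map before
// the elementwise multiply and the [2] reduction.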
#matmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = [
    affine_map<(m, n, k) -> (m, k)>,
    affine_map<(m, n, k) -> (k, n)>,
    affine_map<(m, n, k) -> (m, n)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func @vectorization_test
func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                         %C: memref<8x32xf32>) {
  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
  //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
  linalg.generic #matmul_trait
    ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
   outs(%C : memref<8x32xf32>) {
    ^bb(%a: f32, %b: f32, %c: f32) :
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e : f32
  }
  return
}

// -----

#matmul_transpose_out_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = [
    affine_map<(m, n, k) -> (m, k)>,
    affine_map<(m, n, k) -> (k, n)>,
    affine_map<(m, n, k) -> (n, m)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func @generic_output_transpose
func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                         %C: memref<32x8xf32>) {
  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
  //       CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
  //       CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
  linalg.generic #matmul_transpose_out_trait
    ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
   outs(%C : memref<32x8xf32>) {
    ^bb(%a: f32, %b: f32, %c: f32) :
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e : f32
  }
  return
}

// -----

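// A transposed output map is folded into the permutation_map of the
// transfer_write instead of materializing a vector.transpose.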
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK: func @generic_interchanged_transpose
func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> {
  // CHECK: %[[IN:.+]] = vector.transfer_read
  // CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]]
  %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
  %1 = linalg.generic {indexing_maps = [#map0, #map1],
                       iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<12x128x32xf32>)
    outs(%0 : tensor<128x12x32xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<128x12x32xf32>
  return %1 : tensor<128x12x32xf32>
}

// -----

#matmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = [
    affine_map<(m, n, k) -> (m, k)>,
    affine_map<(m, n, k) -> (k, n)>,
    affine_map<(m, n, k) -> (m, n)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func @vectorization_test_integer
func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                 %C: memref<8x32xi32>) {
  //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
  //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
  //       CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
  //       CHECK: vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
  //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
  linalg.generic #matmul_trait
    ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
   outs(%C : memref<8x32xi32>) {
    ^bb(%a: i32, %b: i32, %c: i32) :
      %d = arith.muli %a, %b: i32
      %e = arith.addi %c, %d: i32
      linalg.yield %e : i32
  }
  return
}

// -----

// CHECK-LABEL: func @vectorization_test_2
func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                         %C: memref<8x32xf32>) {
  //       CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
  //       CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
  linalg.matmul
    ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
   outs(%C: memref<8x32xf32>)
  return
}

// -----

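// A scalar input (empty indexing map) is broadcast to the full output vector
// shape before being written back.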
// CHECK-LABEL: func @test_vectorize_scalar_input
func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
  //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
  linalg.generic {
    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
    iterator_types = ["parallel", "parallel"]}
   ins(%arg0 : f32)
  outs(%A: memref<8x16xf32>) {
    ^bb(%0: f32, %1: f32) :
      linalg.yield %0 : f32
  }
  return
}

// -----

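// complex<f32> is not a valid vector element type, so this op must be left
// unvectorized.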
// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
  // CHECK-NOT: vector.broadcast
  // CHECK-NOT: vector.transfer_write
  linalg.generic {
    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
    iterator_types = ["parallel", "parallel"]}
   ins(%arg0 : complex<f32>)
  outs(%A: memref<8x16xcomplex<f32>>) {
    ^bb(%0: complex<f32>, %1: complex<f32>) :
      linalg.yield %0 : complex<f32>
  }
  return
}

// -----

// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
  //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
  linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>)
  return
}

// -----

// CHECK-LABEL: func @test_vectorize_fill_scalar
func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
  // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
  //      CHECK:   %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
  //      CHECK:   vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
  linalg.fill ins(%arg0 : f32) outs(%A : memref<f32>)
  return
}

// -----

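// memref.copy vectorizes to a plain transfer_read / transfer_write pair.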
// CHECK-LABEL: func @test_vectorize_copy
func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
  //       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
  //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
  memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
  return
}

// -----

// CHECK-LABEL: func @test_vectorize_copy_scalar
func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
  //  CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
  //       CHECK:   %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
  //       CHECK:   %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
  //       CHECK:   %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
  //       CHECK:   vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
  memref.copy %A, %B : memref<f32> to memref<f32>
  return
}

// -----

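// linalg.index becomes a dense iota constant broadcast to the output shape; a
// non-innermost index additionally needs a vector.transpose (see the second
// test below).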
// CHECK-LABEL: func @test_vectorize_trailing_index
  //  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
func.func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
  //   CHECK-DAG:   %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
  //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  outs(%arg0: memref<1x2x4x8xindex>) {
  ^bb0(%arg1: index):
  //       CHECK:   %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex>
  //       CHECK:   vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
    %0 = linalg.index 3 : index
    linalg.yield %0 : index
  }
  return
}

// -----

// CHECK-LABEL: func @test_vectorize_inner_index
  //  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
func.func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
  //   CHECK-DAG:   %[[CST0:.*]] = arith.constant dense<[0, 1]> : vector<2xindex>
  //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
  linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  outs(%arg0: memref<1x2x4x8xindex>) {
  ^bb0(%arg1: index):
  //       CHECK:   %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex>
  //       CHECK:   %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>
  //       CHECK:   vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
    %0 = linalg.index 1 : index
    linalg.yield %0 : index
  }
  return
}

// -----

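// Elementwise generic op: every scalar op in the body maps 1-1 to its vector
// counterpart, captured scalars are broadcast, and constants become splats.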
// CHECK-LABEL: func @generic_vectorize
  //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
  //  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
func.func @generic_vectorize(%arg0: memref<4x256xf32>,
                        %arg1: memref<4x256xf32>,
                        %arg2: memref<256xf32>, %i: f32) {
  //   CHECK-DAG:   %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<4x256xf32>
  //   CHECK-DAG:   %[[CST1:.*]] = arith.constant dense<1.000000e+00> : vector<4x256xf32>
  //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
  %c1_f32 = arith.constant 1.0 : f32
  linalg.generic {
    args_in = 0 : i64,
    args_out = 10 : i64,
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
  ins(%arg1, %arg2: memref<4x256xf32>, memref<256xf32>)
  outs(
    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
    memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
    memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
    memref<4x256xf32>, memref<4x256xf32>) {
  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
  //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
    %arg14 : f32):
  //       CHECK:   %[[ADD:.*]] = arith.addf %[[V0]], %[[V1]] : vector<4x256xf32>
    %6 = arith.addf %arg4, %arg6 : f32
  //       CHECK:   %[[CMP:.*]] = arith.cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
    %7 = arith.cmpf ogt, %arg3, %arg6 : f32
  //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
    %8 = arith.constant 2.0 : f32
  //       CHECK:   %[[DIV:.*]] = arith.divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
    %9 = arith.divf %arg5, %i : f32
  //       CHECK:   %[[EXP:.*]] = math.exp2 %[[V3]] : vector<4x256xf32>
    %10 = math.exp2 %arg5 : f32
  //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
    %11 = arith.mulf %arg5, %8 : f32
  //       CHECK:   %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32>
    %12 = math.rsqrt %arg5 : f32
  //       CHECK:   %[[SEL:.*]] = arith.select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
    %13 = arith.select %7, %arg5, %arg6 : f32
  //       CHECK:   %[[SUB:.*]] = arith.subf %[[V3]], %[[V0]] : vector<4x256xf32>
    %14 = arith.subf %arg5, %arg4 : f32
  //       CHECK:   %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32>
    %15 = math.tanh %arg5 : f32
  //       CHECK:   vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
  //       CHECK:   vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
      f32, f32, f32, f32, f32, f32, f32, f32
  }
  return
}

// -----

// CHECK-LABEL: func @generic_vectorize_tensor
//  CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
//  CHECK-SAME:  %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
func.func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
  %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
  %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
  %c1_f32 = arith.constant 1.0 : f32
  %r:10 = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>,
      affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
  ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
  outs(
    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>) {
  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
    %arg14 : f32):
  //   CHECK-DAG:   %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<4x256xf32>
  //   CHECK-DAG:   %[[CST1:.*]] = arith.constant dense<1.000000e+00> : vector<4x256xf32>
  //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
  //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
  //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
  //       CHECK:   %[[ADD:.*]] = arith.addf %[[V0]], %[[V1]] : vector<4x256xf32>
    %6 = arith.addf %arg4, %arg6 : f32
  //       CHECK:   %[[CMP:.*]] = arith.cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
    %7 = arith.cmpf ogt, %arg3, %arg6 : f32
  //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
    %8 = arith.constant 2.0 : f32
  //       CHECK:   %[[DIV:.*]] = arith.divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
    %9 = arith.divf %arg5, %i : f32
  //       CHECK:   %[[EXP:.*]] = math.exp2 %[[V3]] : vector<4x256xf32>
    %10 = math.exp2 %arg5 : f32
  //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
    %11 = arith.mulf %arg5, %8 : f32
  //       CHECK:   %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32>
    %12 = math.rsqrt %arg5 : f32
  //       CHECK:   %[[SEL:.*]] = arith.select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
    %13 = arith.select %7, %arg5, %arg6 : f32
  //       CHECK:   %[[SUB:.*]] = arith.subf %[[V3]], %[[V0]] : vector<4x256xf32>
    %14 = arith.subf %arg5, %arg4 : f32
  //       CHECK:   %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32>
    %15 = math.tanh %arg5 : f32
  //       CHECK:   %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
  //       CHECK:   %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
      f32, f32, f32, f32, f32, f32, f32, f32
  } -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>)
  //       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
  return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
}

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0, 0, 0, 0)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)>
//     CHECK: func @generic_vectorize_broadcast_transpose
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[CF:.*]] = arith.constant 0.000000e+00 : f32
//     CHECK:   %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
//     CHECK:   %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32>
//     CHECK:   %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32>
//     CHECK:   %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
//     CHECK:   %[[SUB:.*]] = arith.subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32>
//     CHECK:   %[[ADD0:.*]] = arith.addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32>
//     CHECK:   %[[ADD1:.*]] = arith.addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32>
//     CHECK: vector.transfer_write %[[ADD1]], {{.*}} : vector<4x4x4x4xf32>, memref<4x4x4x4xf32>
func.func @generic_vectorize_broadcast_transpose(
  %A: memref<4xf32>, %B: memref<4x4xf32>, %C: memref<4x4x4x4xf32>) {
  linalg.generic {
  indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
                   affine_map<(d0, d1, d2, d3) -> (d0)>,
                   affine_map<(d0, d1, d2, d3) -> (d2)>,
                   affine_map<(d0, d1, d2, d3) -> (d2, d0)>,
                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  ins(%B, %A, %A, %B: memref<4x4xf32>, memref<4xf32>, memref<4xf32>, memref<4x4xf32>)
  outs(%C : memref<4x4x4x4xf32>) {
  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
    %s = arith.subf %arg0, %arg1 : f32
    %a = arith.addf %arg2, %s : f32
    %b = arith.addf %arg3, %a : f32
    linalg.yield %b : f32
  }
  return
}

// -----

// Test different input maps.
#matmul_trait = {
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
    affine_map<(d0, d1, d2, d3) -> (d3, d1)>,
    affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}

// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
//       CHECK: func @vectorization_transpose
//       CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32>
//       CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32>
//       CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32>
//       CHECK: arith.addf {{.*}} : vector<7x14x8x16xf32>
//       CHECK: arith.addf {{.*}} : vector<7x14x8x16xf32>
//       CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32>
func.func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
                         %C: memref<16x14x7x8xf32>, %D: memref<7x14x8x16xf32>) {
  linalg.generic #matmul_trait
    ins(%A, %B, %C : memref<14x7xf32>, memref<16x14xf32>, memref<16x14x7x8xf32>)
   outs(%D : memref<7x14x8x16xf32>) {
    ^bb(%a: f32, %b: f32, %c: f32, %d: f32) :
      %e = arith.addf %a, %b: f32
      %f = arith.addf %e, %c: f32
      linalg.yield %f : f32
  }
  return
}

// -----

// CHECK-LABEL: func @matmul_tensors
//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
//  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
func.func @matmul_tensors(
  %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
    -> tensor<8x12xf32> {
  //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
  //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
  //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
  //
  // linalg.matmul gets expanded to a 3-D multiply-reduce; later
  // canonicalizations convert it to a 2-D vector.contract.
  //       CHECK:   %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
  //       CHECK:   %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
  //       CHECK:   %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
                     outs(%arg2: tensor<8x12xf32>)
    -> tensor<8x12xf32>
  //       CHECK:   return %[[W]] : tensor<8x12xf32>
  return %0 : tensor<8x12xf32>
}

// -----

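// tensor.pad with a static result shape and a constant padding value
// vectorizes into a splat fill of the init tensor, a read of the source
// (padded along the not-in-bounds dim), and a write at the low-pad offset.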
// CHECK-LABEL: func @pad_static(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<2x?x2xf32>, %[[PAD:.*]]: f32
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
//   CHECK-DAG:   %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
//   CHECK-DAG:   %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32>
//       CHECK:   %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32>
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32>
//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x3x2xf32>, tensor<2x3x4xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
  %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index):
      tensor.yield %pad_value : f32
    } : tensor<2x?x2xf32> to tensor<2x3x4xf32>
  return %0 : tensor<2x3x4xf32>
}

// -----

// CHECK-LABEL: func @pad_static_source(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<2x5x2xf32>, %[[PAD:.*]]: f32
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
//       CHECK:   %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32>
//       CHECK:   %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32>
//       CHECK:   %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32>
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x5x2xf32>, tensor<2x6x4xf32>
//       CHECK:   return %[[WRITE]]
func.func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tensor<2x6x4xf32> {
  %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index):
      tensor.yield %pad_value : f32
    } : tensor<2x5x2xf32> to tensor<2x6x4xf32>
  return %0 : tensor<2x6x4xf32>
}

// -----

// CHECK-LABEL: func @pad_static_dynamic(
//  CHECK-SAME:                          %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
//       CHECK:   %[[V0:.*]] = arith.addi %[[LOW]], %[[C2]] : index
//       CHECK:   %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : index
//       CHECK:   %[[V2:.*]] = arith.addi %[[HIGH]], %[[C5]] : index
//       CHECK:   %[[DIM3:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32>
//       CHECK:   %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index
//       CHECK:   %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index
//       CHECK:   %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32>
//       CHECK:   %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32>
//       CHECK:   %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32>
//       CHECK:   %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
  %0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %pad_value : f32
    } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
  return %0 : tensor<6x?x?x?xf32>
}

// -----

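// A transfer_read of a padded tensor folds the pad away: the source is read
// directly and the constant pad value becomes the transfer's padding value.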
// CHECK-LABEL: func @pad_and_transfer_read
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5.0
//       CHECK:   %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %c6 = arith.constant 6.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = vector.transfer_read %0[%c0, %c0], %c6
      : tensor<10x13xf32>, vector<7x9xf32>
  return %1 : vector<7x9xf32>
}

// -----

func.func private @make_vector() -> vector<7x9xf32>

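// A transfer_write into a padded tensor whose result is only consumed through
// an extract_slice of the unpadded region is rewritten to write directly into
// the pad's source.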
// CHECK-LABEL: func @pad_and_transfer_write_static
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_write_static(
    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[5, 7] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<10x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
  return %3 : tensor<5x6xf32>
}

// -----

func.func private @make_vector() -> vector<7x9xf32>

// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
//   CHECK-NOT:   tensor.pad
//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
//       CHECK:   %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
//       CHECK:   return %[[RESULT]]
func.func @pad_and_transfer_write_dynamic_static(
    %arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
      : tensor<?x?xf32> to tensor<?x6xf32>
  %0 = tensor.pad %s low[0, 0] high[%padding, 7] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<?x6xf32> to tensor<?x13xf32>
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<?x13xf32>
  %3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
  return %3 : tensor<?x6xf32>
}

// -----

func.func private @make_vector() -> tensor<12x13xf32>

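// tensor.pad used as the source of an insert_slice becomes a padded
// transfer_read of the source plus an in-bounds transfer_write into the
// insert_slice destination.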
// CHECK-LABEL: func @pad_and_insert_slice_source
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5.0
//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
//       CHECK:   return %[[WRITE]]
func.func @pad_and_insert_slice_source(
    %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0] high[2, 3] {
    ^bb0(%arg2: index, %arg3: index):
      tensor.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<7x9xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
  return %r : tensor<12x13xf32>
}

// -----

func.func private @make_vector() -> tensor<12x13xf32>

// CHECK-LABEL: func @pad_and_insert_slice_dest
// Check that the insert_slice is not rewritten when the padded result is used
// as the destination operand.
//       CHECK:   %[[T1:.*]] = call @make_vector() : () -> tensor<12x13xf32>
//       CHECK:   = tensor.insert_slice %[[T1]] into
func.func @pad_and_insert_slice_dest(
    %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 7] {
    ^bb0(%arg2: index, %arg3: index, %arg4: index):
      tensor.yield %c5 : f32
  } : tensor<1x5x6xf32> to tensor<1x12x13xf32>
  %1 = call @make_vector() : () -> tensor<12x13xf32>
  %r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
  return %r : tensor<1x12x13xf32>
}

// -----

// CHECK-LABEL: func @pad_tensor_non_const_pad_value
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   tensor.pad
//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
//       CHECK:   %[[FILL:.*]] = tensor.generate
//       CHECK:     %[[RES:.*]] = arith.mulf
//       CHECK:     tensor.yield %[[RES]] : f32
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : tensor<5x6xf32>, vector<5x6xf32>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C3]], %[[C4]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<12x13xf32>
//       CHECK:   return %[[WRITE]]
func.func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5.0 : f32
  %0 = tensor.pad %arg0 low[3, 4] high[4, 3] {
    ^bb0(%arg1: index, %arg2: index):
      %i1 = arith.index_cast %arg1 : index to i32
      %i2 = arith.index_cast %arg2 : index to i32
      %f1 = arith.sitofp %i1 : i32 to f32
      %f2 = arith.sitofp %i2 : i32 to f32
      %m = arith.mulf %f1, %f2 : f32
      tensor.yield %m : f32
  } : tensor<5x6xf32> to tensor<12x13xf32>
  return %0 : tensor<12x13xf32>
}

// -----

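// Reductions vectorize to vector.multi_reduction over the reduction
// dimensions; the combining kind is inferred from the scalar op in the body.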
// CHECK-LABEL: func @sum_exp
func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
  -> tensor<4x16xf32>
{
  // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
  // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
  // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
  // CHECK: return {{.*}} : tensor<4x16xf32>
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2) -> (d0, d1)>
      ],
      iterator_types = ["parallel", "parallel", "reduction"]
    } ins(%input : tensor<4x16x8xf32>) outs(%output : tensor<4x16xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      %1 = math.exp %arg0 : f32
      %2 = arith.addf %1, %arg1 : f32
      linalg.yield %2 : f32
    } -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// -----

// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1) -> (0, 0, d1, d0)>
// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: func @sum_exp_2
func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>)
  -> tensor<5x2xf32>
{
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
  // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
  // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
  // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
  // CHECK: vector.multi_reduction <add>, {{.*}}, %{{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
  // CHECK: return {{.*}} : tensor<5x2xf32>
  %0 = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
        affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
        affine_map<(d0, d1, d2, d3) -> (d3, d0)>
      ],
      iterator_types = ["parallel", "reduction", "reduction", "parallel"]
    } ins(%input, %input_2 : tensor<3x2xf32>, tensor<5x4xf32>) outs(%output : tensor<5x2xf32>) {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %1 = math.exp %arg0 : f32
      %2 = math.exp %arg1 : f32
      %3 = arith.addf %1, %2 : f32
      %4 = arith.addf %3, %arg2 : f32
      linalg.yield %4 : f32
    } -> tensor<5x2xf32>
  return %0 : tensor<5x2xf32>
}

// -----

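// The red_* tests below check combining-kind inference for maxf, minf, mul,
// or, and, and xor. The neutral element is the constant fed to linalg.fill.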
// CHECK-LABEL:   func @red_max_2d(
func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
  // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
  // CHECK: linalg.init_tensor [4] : tensor<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  %ident = arith.constant -3.40282e+38 : f32
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
  ^bb0(%in0: f32, %out0: f32):
    %max = arith.maxf %in0, %out0 : f32
    linalg.yield %max : f32
  } -> tensor<4xf32>
  return %red : tensor<4xf32>
}

// -----

// CHECK-LABEL:   func @red_min_2d(
func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
  // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
  // CHECK: linalg.init_tensor [4] : tensor<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
  // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  %maxf32 = arith.constant 3.40282e+38 : f32
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
  ^bb0(%in0: f32, %out0: f32):
    %min = arith.minf %out0, %in0 : f32
    linalg.yield %min : f32
  } -> tensor<4xf32>
  return %red : tensor<4xf32>
}

// -----

// CHECK-LABEL:   func @red_mul_2d(
func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
  // CHECK: linalg.init_tensor [4] : tensor<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
  // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
  %ident = arith.constant 1.0 : f32
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
  ^bb0(%in0: f32, %out0: f32):
    %mul = arith.mulf %in0, %out0 : f32
    linalg.yield %mul : f32
  } -> tensor<4xf32>
  return %red : tensor<4xf32>
}

// -----

// CHECK-LABEL:   func @red_or_2d(
func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
  // CHECK: linalg.init_tensor [4] : tensor<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
  // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  %ident = arith.constant false
  %init = linalg.init_tensor [4] : tensor<4xi1>
  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xi1>) outs(%fill : tensor<4xi1>) {
  ^bb0(%in0: i1, %out0: i1):
    %or = arith.ori %in0, %out0 : i1
    linalg.yield %or : i1
  } -> tensor<4xi1>
  return %red : tensor<4xi1>
}

// -----

// CHECK-LABEL:   func @red_and_2d(
func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
  // CHECK: linalg.init_tensor [4] : tensor<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
  // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  %ident = arith.constant true
  %init = linalg.init_tensor [4] : tensor<4xi1>
  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xi1>) outs(%fill : tensor<4xi1>) {
  ^bb0(%in0: i1, %out0: i1):
    %and = arith.andi %in0, %out0 : i1
    linalg.yield %and : i1
  } -> tensor<4xi1>
  return %red : tensor<4xi1>
}

// -----

// CHECK-LABEL:   func @red_xor_2d(
func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
  // CHECK: linalg.init_tensor [4] : tensor<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
  // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
  %ident = arith.constant false
  %init = linalg.init_tensor [4] : tensor<4xi1>
  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
                         ins(%arg0 : tensor<4x4xi1>) outs(%fill : tensor<4xi1>) {
  ^bb0(%in0: i1, %out0: i1):
    %xor = arith.xori %in0, %out0 : i1
    linalg.yield %xor : i1
  } -> tensor<4xi1>
  return %red : tensor<4xi1>
}

// -----

// CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>

// CHECK-LABEL:   func @explicit_broadcast(
func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M5]]} : tensor<4x1xf32>, vector<4x4xf32>
  // CHECK: subf {{.*}} : vector<4x4xf32>
  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
  %c0 = arith.constant 0.0 : f32
  %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0, 0)>,
                                          affine_map<(d0, d1) -> (d0, d1)>],
   iterator_types = ["parallel", "parallel"]}
   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
   outs(%fill : tensor<4x4xf32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      linalg.yield %40 : f32
    } -> tensor<4x4xf32>
  return %red : tensor<4x4xf32>
}

// -----

// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>

// CHECK-LABEL:   func @fused_broadcast_red_2d
func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4xf32> {
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
  // CHECK: subf {{.*}} : vector<4x4xf32>
  // CHECK: math.exp {{.*}} : vector<4x4xf32>
  // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32>
  // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
  %c0 = arith.constant 0.0 : f32
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0, 0)>,
                                          affine_map<(d0, d1) -> (d0)>],
   iterator_types = ["parallel", "reduction"]}
   ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x1xf32>)
   outs(%fill : tensor<4xf32>) {
    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
      %40 = arith.subf %arg7, %arg8 : f32
      %41 = math.exp %40 : f32
      %42 = arith.addf %41, %arg9 : f32
      linalg.yield %42 : f32
    } -> tensor<4xf32>
  return %red : tensor<4xf32>
}

// -----

//  CHECK-LABEL: func @reduce_1d(
//   CHECK-SAME:   %[[A:.*]]: tensor<32xf32>
func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
  //  CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
  //  CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  %f0 = arith.constant 0.000000e+00 : f32

  //      CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
  %0 = linalg.init_tensor [] : tensor<f32>

  //      CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
  // CHECK-SAME:   : vector<f32>, tensor<f32>
  %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
  //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
  // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
  //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
  // CHECK-SAME:   : vector<32xf32> to f32
  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
  //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
  // CHECK-SAME:   : vector<f32>, tensor<f32>
  %2 = linalg.generic {
         indexing_maps = [affine_map<(d0) -> (d0)>,
                          affine_map<(d0) -> ()>],
         iterator_types = ["reduction"]}
         ins(%arg0 : tensor<32xf32>)
         outs(%1 : tensor<f32>) {
    ^bb0(%a: f32, %b: f32):
      %3 = arith.addf %a, %b : f32
      linalg.yield %3 : f32
    } -> tensor<f32>

  return %2 : tensor<f32>
}

// -----

// This test checks that vectorization does not occur when an input indexing map
// is not a projected permutation. In the future, this can be converted to a
// positive test when support is added.

// CHECK-LABEL:   func @not_projected_permutation
func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf32> {
  %c0 = arith.constant 0.0 : f32
  %init = linalg.init_tensor [6, 6, 3, 3] : tensor<6x6x3x3xf32>
  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<6x6x3x3xf32>) -> tensor<6x6x3x3xf32>
  // CHECK: linalg.generic
  %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>,
                                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
   iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
   ins(%arg0 : tensor<8x8xf32>)
   outs(%fill : tensor<6x6x3x3xf32>) {
    ^bb0(%arg7: f32, %arg9: f32):
      linalg.yield %arg7 : f32
    } -> tensor<6x6x3x3xf32>
  return %result : tensor<6x6x3x3xf32>
}