1// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
2// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
3// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
4// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
5// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
6
// Indexing maps for a 1-D dot product: both operands are read at the single
// (reduction) iteration index `i`, and the empty result map `()` makes the
// accumulator/result a scalar.
7#dotp_accesses = [
8  affine_map<(i) -> (i)>,
9  affine_map<(i) -> (i)>,
10  affine_map<(i) -> ()>
11]
12#dotp_trait = {
13  indexing_maps = #dotp_accesses,
14  iterator_types = ["reduction"]
15}
16
// 1-D f32 dot product: expected to lower to one elementwise arith.mulf and a
// single vector.reduction <add> that folds in the scalar accumulator.
17// CHECK-LABEL: func @extract_contract1
18// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
19// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
20// CHECK-SAME: %[[C:.*2]]: f32
21// CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
22// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
23// CHECK:      return %[[R]] : f32
24
25func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
26  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
27    : vector<4xf32>, vector<4xf32> into f32
28  return %0 : f32
29}
30
// Integer variant of @extract_contract1: same shape of lowering, but with
// arith.muli feeding the vector.reduction <add>.
31// CHECK-LABEL: func @extract_contract1_int
32// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
33// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
34// CHECK-SAME: %[[C:.*2]]: i32
35// CHECK:      %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
36// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
37// CHECK:      return %[[R]] : i32
38
39func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
40  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
41    : vector<4xi32>, vector<4xi32> into i32
42  return %0 : i32
43}
44
// Matrix-vector contraction: the 2-D operand is read at (i, j), the 1-D
// operand at the reduced index j, and the result at the parallel index i.
45#matvec_accesses = [
46  affine_map<(i, j) -> (i, j)>,
47  affine_map<(i, j) -> (j)>,
48  affine_map<(i, j) -> (i)>
49]
50#matvec_trait = {
51  indexing_maps = #matvec_accesses,
52  iterator_types = ["parallel", "reduction"]
53}
54
// Mat-vec lowering: each row of A is multiplied with B and reduced to a
// scalar, the scalars are inserted into a zero vector, and the accumulator C
// is added once at the end (arith.addf) instead of inside each reduction.
55// CHECK-LABEL: func @extract_contract2
56// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
57// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
58// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
59// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
60// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
61// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
62// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
63// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
64// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
65// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
66// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
67// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
68// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
69// CHECK:      return %[[T10]] : vector<2xf32>
70
71func.func @extract_contract2(%arg0: vector<2x3xf32>,
72                        %arg1: vector<3xf32>,
73                        %arg2: vector<2xf32>) -> vector<2xf32> {
74  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
75    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
76  return %0 : vector<2xf32>
77}
78
// Integer variant of @extract_contract2: row-wise arith.muli + reduction,
// with the accumulator added once at the end via arith.addi.
79// CHECK-LABEL: func @extract_contract2_int
80// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
81// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
82// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
83// CHECK:      %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
84// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32>
85// CHECK:      %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
86// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xi32> into i32
87// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
88// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32>
89// CHECK:      %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
90// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xi32> into i32
91// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
92// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
93// CHECK:      return %[[T10]] : vector<2xi32>
94func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
95                        %arg1: vector<3xi32>,
96                        %arg2: vector<2xi32>) -> vector<2xi32> {
97  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
98    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
99  return %0 : vector<2xi32>
100}
101
// Vector-matrix contraction: the 1-D operand comes first (read at the reduced
// index j) and the 2-D operand second (read at (i, j)); the result is read
// at the parallel index i.
102#vecmat_accesses = [
103  affine_map<(i, j) -> (j)>,
104  affine_map<(i, j) -> (i, j)>,
105  affine_map<(i, j) -> (i)>
106]
107#vecmat_trait = {
108  indexing_maps = #vecmat_accesses,
109  iterator_types = ["parallel", "reduction"]
110}
111
// Vec-mat lowering: same row-wise strategy as @extract_contract2, but rows
// are taken from the 2-D second operand B and multiplied with the 1-D A.
112// CHECK-LABEL: func @extract_contract3
113// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
114// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
115// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
116// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
117// CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
118// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
119// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
120// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
121// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
122// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
123// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
124// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
125// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
126// CHECK:      return %[[T10]] : vector<2xf32>
127
128func.func @extract_contract3(%arg0: vector<3xf32>,
129                        %arg1: vector<2x3xf32>,
130                        %arg2: vector<2xf32>) -> vector<2xf32> {
131  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
132    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
133  return %0 : vector<2xf32>
134}
135
// Matrix-matrix contraction: lhs read at (i, k), rhs at (k, j), result at
// (i, j); i and j are parallel and k is reduced.
136#matmat_accesses = [
137  affine_map<(i, j, k) -> (i, k)>,
138  affine_map<(i, j, k) -> (k, j)>,
139  affine_map<(i, j, k) -> (i, j)>
140]
141#matmat_trait = {
142  indexing_maps = #matmat_accesses,
143  iterator_types = ["parallel", "parallel", "reduction"]
144}
145
// 2x2 matmul under the default lowering: B is transposed once, then every
// result element (i, j) is the vector.reduction <add> of row i of A times
// row j of the transposed B, inserted into a zero matrix; the accumulator C
// is added once at the end. The transpose operand is matched through the
// captured %[[B]] rather than a hard-coded SSA name.
146// CHECK-LABEL: func @extract_contract4
147// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
148// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
149// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
150// CHECK:    %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
151// CHECK:    %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
152// CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
153// CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
154// CHECK:    %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
155// CHECK:    %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
156// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
157//
158// CHECK:    %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
159// CHECK:    %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
160// CHECK:    %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
161// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
162//
163// CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
164// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
165// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
166// CHECK:    %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
167// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
168//
169// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
170// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
171// CHECK:    %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
172// CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
173//
174// CHECK:    %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
175// CHECK:    return %[[T52]] : vector<2x2xf32>
176
177func.func @extract_contract4(%arg0: vector<2x2xf32>,
178                        %arg1: vector<2x2xf32>,
179                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
180  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
181    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
182  return %0 : vector<2x2xf32>
183}
184
// Full 2-D reduction: both operands are read at the same (i, j), both
// dimensions are reductions, and the result is a scalar.
185#contraction2d_accesses = [
186  affine_map<(i, j) -> (i, j)>,
187  affine_map<(i, j) -> (i, j)>,
188  affine_map<(i, j) -> ()>
189]
190#contraction2d_trait = {
191  indexing_maps = #contraction2d_accesses,
192  iterator_types = ["reduction", "reduction"]
193}
194
// Full contraction to a scalar: matching rows of A and B are multiplied and
// reduced, with the scalar accumulator threaded through the reductions
// (C feeds the first reduction, whose result feeds the second).
195// CHECK-LABEL: func @full_contract1
196// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
197// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
198// CHECK-SAME: %[[C:.*2]]: f32
199// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
200// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
201// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
202// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
203// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
204// CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
205// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
206// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
207// CHECK:      return %[[T8]] : f32
208
209func.func @full_contract1(%arg0: vector<2x3xf32>,
210                     %arg1: vector<2x3xf32>,
211                     %arg2: f32) -> f32 {
212  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
213    : vector<2x3xf32>, vector<2x3xf32> into f32
214  return %0 : f32
215}
216
// Like #contraction2d_trait, but the second operand is read at (j, i), i.e.
// transposed relative to the first operand.
217#contraction2d_trans_accesses = [
218  affine_map<(i, j) -> (i, j)>,
219  affine_map<(i, j) -> (j, i)>,
220  affine_map<(i, j) -> ()>
221]
222#contraction2d_trans_trait = {
223  indexing_maps = #contraction2d_trans_accesses,
224  iterator_types = ["reduction", "reduction"]
225}
226
// Transposed full contraction: because B is read at (j, i), each column of B
// is gathered element-by-element (vector.extract + vector.insert into a zero
// vector<3xf32>) before being multiplied with a row of A and reduced; the
// scalar accumulator is threaded through the two reductions.
227// CHECK-LABEL: func @full_contract2
228// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
229// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
230// CHECK-SAME: %[[C:.*2]]: f32
231// CHECK:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
232// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
233// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32>
234// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
235// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32>
236// CHECK:      %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
237// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
238// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
239// CHECK:      %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
240// CHECK:      %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
241//
242// CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
243// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32>
244// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
245// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32>
246// CHECK:      %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
247// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
248// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
249// CHECK:      %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
250// CHECK:      %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
251// CHECK:      return %[[T23]] : f32
252
253func.func @full_contract2(%arg0: vector<2x3xf32>,
254                     %arg1: vector<3x2xf32>,
255                     %arg2: f32) -> f32 {
256  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
257    : vector<2x3xf32>, vector<3x2xf32> into f32
258  return %0 : f32
259}
260
// Float outer product without accumulator: each element of A is splatted to
// vector<3xf32>, multiplied with B (arith.mulf), and the product is inserted
// as one row of a zero vector<2x3xf32>.
261// CHECK-LABEL: func @outerproduct_noacc
262// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
263// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
264// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
265// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
266// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
267// CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
268// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
269// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
270// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
271// CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
272// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
273// CHECK:      return %[[T7]] : vector<2x3xf32>
274
275func.func @outerproduct_noacc(%arg0: vector<2xf32>,
276                         %arg1: vector<3xf32>) -> vector<2x3xf32> {
277  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
278  return %0: vector<2x3xf32>
279}
280
// Float outer product with accumulator: each row is a vector.fma of the
// splatted A element, B, and the matching row of C. The fused rows are
// inserted into a fresh zero vector<2x3xf32>, not into C.
281// CHECK-LABEL: func @outerproduct_acc
282// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
283// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
284// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
285// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
286// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
287// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
288// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
289// CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
290// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
291// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
292// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
293// CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
294// CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
295// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
296// CHECK:      return %[[T9]] : vector<2x3xf32>
297
298func.func @outerproduct_acc(%arg0: vector<2xf32>,
299                       %arg1: vector<3xf32>,
300                       %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
301  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
302  return %0: vector<2x3xf32>
303}
304
// Integer variant of @outerproduct_noacc: splat + arith.muli per row,
// inserted into a zero vector<2x3xi32>.
305// CHECK-LABEL: func @outerproduct_noacc_int
306// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
307// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
308// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
309// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
310// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
311// CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
312// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
313// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
314// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
315// CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
316// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
317// CHECK:      return %[[T7]] : vector<2x3xi32>
318func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
319                             %arg1: vector<3xi32>) -> vector<2x3xi32> {
320  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
321  return %0: vector<2x3xi32>
322}
323
// Integer outer product with accumulator: unlike the float case, no
// vector.fma is expected. Each row is arith.muli of the splatted A element
// and B, followed by arith.addi with the matching row of C.
324// CHECK-LABEL: func @outerproduct_acc_int
325// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
326// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
327// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
328// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
329// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
330// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
331// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
332// CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
333// CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
334// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
335// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
336// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
337// CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
338// CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
339// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
340// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
341// CHECK:      return %[[T11]] : vector<2x3xi32>
342func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
343                           %arg1: vector<3xi32>,
344                           %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
345  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
346  return %0: vector<2x3xi32>
347}
348
// Vector-by-scalar outerproduct (float, no accumulator): the scalar is
// splatted to vector<16xf32> and multiplied elementwise with the vector.
349// CHECK-LABEL: func @axpy_fp(
350// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
351// CHECK-SAME: %[[B:.*1]]: f32)
352// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
353// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
354// CHECK: return %[[T1]] : vector<16xf32>
355func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
356   %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
357   return %0: vector<16xf32>
358}
359
// Vector-by-scalar outerproduct (float, with accumulator): lowers to a splat
// of the scalar followed by a single vector.fma with the accumulator.
360// CHECK-LABEL: func @axpy_fp_add(
361// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
362// CHECK-SAME: %[[B:.*1]]: f32,
363// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
364// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
365// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
366// CHECK: return %[[T1]] : vector<16xf32>
367func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
368   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
369   return %0: vector<16xf32>
370}
371
// Integer variant of @axpy_fp: splat of the scalar followed by arith.muli.
372// CHECK-LABEL: func @axpy_int(
373// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
374// CHECK-SAME: %[[B:.*1]]: i32)
375// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
376// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
377// CHECK: return %[[T1]] : vector<16xi32>
378func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
379   %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
380   return %0: vector<16xi32>
381}
382
// Integer vector-by-scalar outerproduct with accumulator: no vector.fma is
// expected; the lowering is splat, arith.muli, then arith.addi with C.
383// CHECK-LABEL: func @axpy_int_add(
384// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
385// CHECK-SAME: %[[B:.*1]]: i32,
386// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
387// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
388// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
389// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
390// CHECK: return %[[T2]] : vector<16xi32>
391func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
392   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
393   return %0: vector<16xi32>
394}
395
// A shape_cast to the identical type is expected to disappear entirely: the
// function returns its argument directly.
396// CHECK-LABEL: func @nop_shape_cast
397// CHECK-SAME: %[[A:.*]]: vector<16xf32>
398// CHECK:      return %[[A]] : vector<16xf32>
399
400func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
401  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
402  return %0 : vector<16xf32>
403}
404
// Round-trip shape_cast (16 -> 4x4 -> 16). The intended expectation is that
// the pair cancels and the argument is returned directly, but the two
// directives below are spelled `HECK` so FileCheck ignores them; presumably
// they should be re-enabled once PR49590 is resolved.
405// CHECK-LABEL: func @cancel_shape_cast
406// FIXME: PR49590
407// HECK-SAME: %[[A:.*]]: vector<16xf32>
408// HECK:      return %[[A]] : vector<16xf32>
409
410func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
411  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
412  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
413  return %1 : vector<16xf32>
414}
415
416// Shape casts between 2-D and 1-D vectors, as needed when converting to
417// llvm.matrix operations. 2-D -> 1-D is expected to lower to row extracts
// plus vector.insert_strided_slice; 1-D -> 2-D is expected to lower to
// vector.extract_strided_slice plus row inserts.
418// CHECK-LABEL: func @shape_casts
419func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
420  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
421  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
422  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32>
423  //
424  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
425  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
426  //
427  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32>
428  //
429  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
430  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
431  //
432  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
433  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
434  %r0 = arith.addf %0, %0: vector<4xf32>
435  //
436  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
437  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
438  // CHECK-SAME: vector<4xf32> to vector<2xf32>
439  //
440  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
441  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
442  //
443  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
444  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
445  // CHECK-SAME: vector<4xf32> to vector<2xf32>
446  //
447  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
448  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
449  //
450  %1 = vector.shape_cast %r0  : vector<4xf32> to vector<2x2xf32>
451  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
452  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
453}
454
// shape_cast between two different 2-D shapes (3x2 -> 2x3) is lowered
// element-by-element: each scalar is extracted from A and inserted at the
// same row-major linear position of a zero vector<2x3xf32>.
455// CHECK-LABEL: func @shape_cast_2d2d
456// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
457// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
458// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
459// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
460// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
461// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
462// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
463// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
464// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
465// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
466// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
467// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
468// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
469// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
470// CHECK: return %[[T11]] : vector<2x3xf32>
471
472func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
473  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
474  return %s : vector<2x3xf32>
475}
476
// 3-D -> 1-D shape_cast (1x3x2 -> 6): each scalar is extracted in row-major
// order and inserted at the matching index of a zero vector<6xf32>.
477// CHECK-LABEL: func @shape_cast_3d1d
478// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
479// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
480// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
481// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
482// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
483// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
484// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
485// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
486// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
487// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
488// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
489// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
490// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
491// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
492// CHECK: return %[[T11]] : vector<6xf32>
493
494func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
495  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
496  return %s : vector<6xf32>
497}
498
// 1-D -> 3-D shape_cast (6 -> 2x1x3): each scalar is extracted in order and
// inserted at its row-major position of a zero vector<2x1x3xf32>.
499// CHECK-LABEL: func @shape_cast_1d3d
500// CHECK-SAME: %[[A:.*]]: vector<6xf32>
501// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
502// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
503// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
504// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
505// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
506// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
507// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
508// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
509// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
510// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
511// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
512// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
513// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
514// CHECK: return %[[T11]] : vector<2x1x3xf32>
515
516func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
517  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
518  return %s : vector<2x1x3xf32>
519}
520
// 2x4 * 4x3 matmul, checked under several lowering strategies selected by
// the RUN lines at the top of this file: flattening to vector.matrix_multiply
// (MATRIX prefix) and a chain of vector.outerproduct ops (OUTERPRODUCT
// prefix).
521// MATRIX-LABEL: func @matmul
522// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
523// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
524// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
525//      MATRIX:  %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
526//      MATRIX:  %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
527//      MATRIX:  %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
528//      MATRIX:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
529//      MATRIX:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
530//      MATRIX:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
531//      MATRIX:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
532//      MATRIX:  %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
533//      MATRIX:  %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
534//      MATRIX:  %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
535//      MATRIX:  %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
536//      MATRIX:  %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
537//      MATRIX:  %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
538//      MATRIX:  %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
539//      MATRIX:  %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
540//      MATRIX:  %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
541//      MATRIX:  %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
542//      MATRIX:  %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
543//      MATRIX:  %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
544//      MATRIX:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
545//      MATRIX:  %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
546
547// OUTERPRODUCT-LABEL: func @matmul
548// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
549// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
550// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
551//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
552// OUTERPRODUCT-SAME:  : vector<2x4xf32> to vector<4x2xf32>
553//
554//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32>
555//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
556//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
557// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
558//
559//      OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32>
560//      OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
561//      OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
562// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
563//
564//      OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32>
565//      OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
566//      OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
567// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
568//
569//      OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32>
570//      OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
571//      OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
572// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
573//
574//      OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>
575
// NOTE(review): none of the RUN lines at the top of this file passes
// --check-prefix=REDUCE, so the REDUCE directives below are documentation
// only and are not executed by FileCheck. They describe the default lowering
// (row-times-transposed-column arith.mulf + vector.reduction <add>, then one
// final arith.addf with the accumulator), as exercised by @extract_contract4.
576// REDUCE-LABEL: func @matmul
577// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
578// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
579// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
580//
581//      REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
582//      REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
583// REDUCE-SAME:  : vector<4x3xf32> to vector<3x4xf32>
584//
585//      REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
586// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32>
587// REDUCE-NEXT: %[[ab00:.*]] = arith.mulf %[[a0]], %[[b0]] : vector<4xf32>
588// REDUCE-NEXT: %[[s00:.*]] = vector.reduction <add>, %[[ab00]] : vector<4xf32> into f32
589// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32>
590//
591//      ...
592//
593//      REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
594//      REDUCE: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32>
595// REDUCE-NEXT: %[[ab12:.*]] = arith.mulf %[[a1]], %[[b2]] : vector<4xf32>
596// REDUCE-NEXT: %[[s12:.*]] = vector.reduction <add>, %[[ab12]] : vector<4xf32> into f32
597// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32>
598//
//      REDUCE: %[[SUM:.*]] = arith.addf %[[r12]], %[[C]] : vector<2x3xf32>
599//      REDUCE: return %[[SUM]] : vector<2x3xf32>
600func.func @matmul(%arg0: vector<2x4xf32>,
601                          %arg1: vector<4x3xf32>,
602                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
603  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
604    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
605  return %0 : vector<2x3xf32>
606}
607
// Scalar -> 1-D broadcast: expected to become a single vector.splat.
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME:    %[[SCALAR:.*0]]: f32
// CHECK:         %[[SPLAT:.*]] = vector.splat %[[SCALAR]] : vector<2xf32>
// CHECK:         return %[[SPLAT]] : vector<2xf32>

func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
  %splat = vector.broadcast %arg0 : f32 to vector<2xf32>
  return %splat : vector<2xf32>
}
617
// Scalar -> 2-D broadcast: still a single vector.splat of the full 2-D shape.
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME:    %[[SCALAR:.*0]]: f32
// CHECK:         %[[SPLAT:.*]] = vector.splat %[[SCALAR]] : vector<2x3xf32>
// CHECK:         return %[[SPLAT]] : vector<2x3xf32>

func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
  %splat = vector.broadcast %arg0 : f32 to vector<2x3xf32>
  return %splat : vector<2x3xf32>
}
627
// Scalar -> 3-D broadcast: still a single vector.splat of the full 3-D shape.
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME:    %[[SCALAR:.*0]]: f32
// CHECK:         %[[SPLAT:.*]] = vector.splat %[[SCALAR]] : vector<2x3x4xf32>
// CHECK:         return %[[SPLAT]] : vector<2x3x4xf32>

func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
  %splat = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
  return %splat : vector<2x3x4xf32>
}
637
// Broadcast to the identical type: the argument itself is expected to be returned.
// CHECK-LABEL: func @broadcast_vec1d_from_vec1d
// CHECK-SAME:    %[[SRC:.*0]]: vector<2xf32>
// CHECK:         return %[[SRC]] : vector<2xf32>

func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
  %same = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
  return %same : vector<2xf32>
}
646
// 1-D -> 2-D broadcast: the source is inserted into every row of a
// zero-initialized 3x2 constant.
// CHECK-LABEL: func @broadcast_vec2d_from_vec1d
// CHECK-SAME:    %[[SRC:.*0]]: vector<2xf32>
// CHECK:         %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK:         %[[ROW0:.*]] = vector.insert %[[SRC]], %[[ZERO]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[ROW1:.*]] = vector.insert %[[SRC]], %[[ROW0]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[ROW2:.*]] = vector.insert %[[SRC]], %[[ROW1]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         return %[[ROW2]] : vector<3x2xf32>

func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
  %rows = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
  return %rows : vector<3x2xf32>
}
659
// 1-D -> 3-D broadcast: the source is first replicated into a 3x2 slab, and
// that slab is then inserted into each of the four outer positions.
// CHECK-LABEL: func @broadcast_vec3d_from_vec1d
// CHECK-SAME:    %[[SRC:.*0]]: vector<2xf32>
// CHECK:         %[[Z2D:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK:         %[[Z3D:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK:         %[[R0:.*]] = vector.insert %[[SRC]], %[[Z2D]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[R1:.*]] = vector.insert %[[SRC]], %[[R0]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[SLAB:.*]] = vector.insert %[[SRC]], %[[R1]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[P0:.*]] = vector.insert %[[SLAB]], %[[Z3D]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P1:.*]] = vector.insert %[[SLAB]], %[[P0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P2:.*]] = vector.insert %[[SLAB]], %[[P1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P3:.*]] = vector.insert %[[SLAB]], %[[P2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         return %[[P3]] : vector<4x3x2xf32>

func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
  %cube = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
  return %cube : vector<4x3x2xf32>
}
677
// 2-D -> 3-D broadcast: the whole 3x2 source is inserted at each outer index
// of a zero-initialized 4x3x2 constant.
// CHECK-LABEL: func @broadcast_vec3d_from_vec2d
// CHECK-SAME:    %[[SRC:.*0]]: vector<3x2xf32>
// CHECK:         %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK:         %[[P0:.*]] = vector.insert %[[SRC]], %[[ZERO]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P1:.*]] = vector.insert %[[SRC]], %[[P0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P2:.*]] = vector.insert %[[SRC]], %[[P1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[P3:.*]] = vector.insert %[[SRC]], %[[P2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         return %[[P3]] : vector<4x3x2xf32>

func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
  %cube = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
  return %cube : vector<4x3x2xf32>
}
691
// Stretching a 1-element vector: the element is extracted and then splatted.
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME:    %[[SRC:.*0]]: vector<1xf32>
// CHECK:         %[[ELT:.*]] = vector.extract %[[SRC]][0] : vector<1xf32>
// CHECK:         %[[SPLAT:.*]] = vector.splat %[[ELT]] : vector<4xf32>
// CHECK:         return %[[SPLAT]] : vector<4xf32>

func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
  %stretched = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
  return %stretched : vector<4xf32>
}
702
// Stretching the leading unit dim: the single row is extracted once and
// inserted into every row of a zero-initialized 3x4 constant.
// CHECK-LABEL: func @broadcast_stretch_at_start
// CHECK-SAME:    %[[SRC:.*0]]: vector<1x4xf32>
// CHECK:         %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
// CHECK:         %[[ROW:.*]] = vector.extract %[[SRC]][0] : vector<1x4xf32>
// CHECK:         %[[I0:.*]] = vector.insert %[[ROW]], %[[ZERO]] [0] : vector<4xf32> into vector<3x4xf32>
// CHECK:         %[[I1:.*]] = vector.insert %[[ROW]], %[[I0]] [1] : vector<4xf32> into vector<3x4xf32>
// CHECK:         %[[I2:.*]] = vector.insert %[[ROW]], %[[I1]] [2] : vector<4xf32> into vector<3x4xf32>
// CHECK:         return %[[I2]] : vector<3x4xf32>

func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
  %stretched = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
  return %stretched : vector<3x4xf32>
}
716
// Stretching the trailing unit dim: for each row, the single element is
// extracted, splatted to width 3, and inserted back at that row.
// CHECK-LABEL: func @broadcast_stretch_at_end
// CHECK-SAME:    %[[SRC:.*0]]: vector<4x1xf32>
// CHECK:         %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
// CHECK:         %[[E0:.*]] = vector.extract %[[SRC]][0, 0] : vector<4x1xf32>
// CHECK:         %[[S0:.*]] = vector.splat %[[E0]] : vector<3xf32>
// CHECK:         %[[I0:.*]] = vector.insert %[[S0]], %[[ZERO]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK:         %[[E1:.*]] = vector.extract %[[SRC]][1, 0] : vector<4x1xf32>
// CHECK:         %[[S1:.*]] = vector.splat %[[E1]] : vector<3xf32>
// CHECK:         %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK:         %[[E2:.*]] = vector.extract %[[SRC]][2, 0] : vector<4x1xf32>
// CHECK:         %[[S2:.*]] = vector.splat %[[E2]] : vector<3xf32>
// CHECK:         %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK:         %[[E3:.*]] = vector.extract %[[SRC]][3, 0] : vector<4x1xf32>
// CHECK:         %[[S3:.*]] = vector.splat %[[E3]] : vector<3xf32>
// CHECK:         %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK:         return %[[I3]] : vector<4x3xf32>

func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
  %stretched = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
  return %stretched : vector<4x3xf32>
}
738
// Stretching a middle unit dim: for each outer index, the 1-D slice is
// extracted, replicated into a fresh zero 3x2 slab, and the slab is inserted
// at that outer index of the result.
// CHECK-LABEL: func @broadcast_stretch_in_middle
// CHECK-SAME:    %[[SRC:.*0]]: vector<4x1x2xf32>
// CHECK:         %[[Z3D:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK:         %[[Z2D:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK:         %[[V0:.*]] = vector.extract %[[SRC]][0, 0] : vector<4x1x2xf32>
// CHECK:         %[[R00:.*]] = vector.insert %[[V0]], %[[Z2D]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[R01:.*]] = vector.insert %[[V0]], %[[R00]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[SLAB0:.*]] = vector.insert %[[V0]], %[[R01]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[OUT0:.*]] = vector.insert %[[SLAB0]], %[[Z3D]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[V1:.*]] = vector.extract %[[SRC]][1, 0] : vector<4x1x2xf32>
// CHECK:         %[[R10:.*]] = vector.insert %[[V1]], %[[Z2D]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[R11:.*]] = vector.insert %[[V1]], %[[R10]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[SLAB1:.*]] = vector.insert %[[V1]], %[[R11]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[OUT1:.*]] = vector.insert %[[SLAB1]], %[[OUT0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[V2:.*]] = vector.extract %[[SRC]][2, 0] : vector<4x1x2xf32>
// CHECK:         %[[R20:.*]] = vector.insert %[[V2]], %[[Z2D]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[R21:.*]] = vector.insert %[[V2]], %[[R20]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[SLAB2:.*]] = vector.insert %[[V2]], %[[R21]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[OUT2:.*]] = vector.insert %[[SLAB2]], %[[OUT1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         %[[V3:.*]] = vector.extract %[[SRC]][3, 0] : vector<4x1x2xf32>
// CHECK:         %[[R30:.*]] = vector.insert %[[V3]], %[[Z2D]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[R31:.*]] = vector.insert %[[V3]], %[[R30]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[SLAB3:.*]] = vector.insert %[[V3]], %[[R31]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK:         %[[OUT3:.*]] = vector.insert %[[SLAB3]], %[[OUT2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:         return %[[OUT3]] : vector<4x3x2xf32>

func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
  %stretched = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
  return %stretched : vector<4x3x2xf32>
}
769
// A constant 1-D mask becomes a dense i1 constant with the first 4 lanes set.
// CHECK-LABEL: func @genbool_1d
// CHECK:         %[[MASK:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>
// CHECK:         return %[[MASK]] : vector<8xi1>

func.func @genbool_1d() -> vector<8xi1> {
  %mask = vector.constant_mask [4] : vector<8xi1>
  return %mask : vector<8xi1>
}
778
// A constant 2-D mask: a 1-D row pattern is inserted into the first two rows
// of an all-false 4x4 constant.
// CHECK-LABEL: func @genbool_2d
// CHECK:         %[[ROW:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK:         %[[FALSE:.*]] = arith.constant dense<false> : vector<4x4xi1>
// CHECK:         %[[M0:.*]] = vector.insert %[[ROW]], %[[FALSE]] [0] : vector<4xi1> into vector<4x4xi1>
// CHECK:         %[[M1:.*]] = vector.insert %[[ROW]], %[[M0]] [1] : vector<4xi1> into vector<4x4xi1>
// CHECK:         return %[[M1]] : vector<4x4xi1>

func.func @genbool_2d() -> vector<4x4xi1> {
  %mask = vector.constant_mask [2, 2] : vector<4x4xi1>
  return %mask : vector<4x4xi1>
}
790
// A constant 3-D mask: the row pattern goes into slab position 0, and that
// slab goes into outer position 0 of an all-false 2x3x4 constant.
// CHECK-LABEL: func @genbool_3d
// CHECK:         %[[ROW:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK:         %[[FALSE2D:.*]] = arith.constant dense<false> : vector<3x4xi1>
// CHECK:         %[[FALSE3D:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
// CHECK:         %[[SLAB:.*]] = vector.insert %[[ROW]], %[[FALSE2D]] [0] : vector<4xi1> into vector<3x4xi1>
// CHECK:         %[[MASK:.*]] = vector.insert %[[SLAB]], %[[FALSE3D]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
// CHECK:         return %[[MASK]] : vector<2x3x4xi1>

func.func @genbool_3d() -> vector<2x3x4xi1> {
  %mask = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
  return %mask : vector<2x3x4xi1>
}
803
// A dynamic 1-D mask is expected to remain a single vector.create_mask.
// CHECK-LABEL: func @genbool_var_1d(
// CHECK-SAME:    %[[BOUND:.*]]: index)
// CHECK:         %[[MASK:.*]] = vector.create_mask %[[BOUND]] : vector<3xi1>
// CHECK:         return %[[MASK]] : vector<3xi1>

func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
  %mask = vector.create_mask %arg0 : vector<3xi1>
  return %mask : vector<3xi1>
}
813
// A dynamic 2-D mask: the trailing-dim 1-D mask is built once. For each row i,
// arith.select picks between it and an all-false row based on (rows > i).
// CHECK-LABEL: func @genbool_var_2d(
// CHECK-SAME:    %[[ROWS:.*0]]: index,
// CHECK-SAME:    %[[COLS:.*1]]: index)
// CHECK:         %[[FALSE1D:.*]] = arith.constant dense<false> : vector<3xi1>
// CHECK:         %[[FALSE2D:.*]] = arith.constant dense<false> : vector<2x3xi1>
// CHECK:         %[[IDX0:.*]] = arith.constant 0 : index
// CHECK:         %[[IDX1:.*]] = arith.constant 1 : index
// CHECK:         %[[ROWMASK:.*]] = vector.create_mask %[[COLS]] : vector<3xi1>
// CHECK:         %[[CMP0:.*]] = arith.cmpi sgt, %[[ROWS]], %[[IDX0]] : index
// CHECK:         %[[SEL0:.*]] = arith.select %[[CMP0]], %[[ROWMASK]], %[[FALSE1D]] : vector<3xi1>
// CHECK:         %[[INS0:.*]] = vector.insert %[[SEL0]], %[[FALSE2D]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK:         %[[CMP1:.*]] = arith.cmpi sgt, %[[ROWS]], %[[IDX1]] : index
// CHECK:         %[[SEL1:.*]] = arith.select %[[CMP1]], %[[ROWMASK]], %[[FALSE1D]] : vector<3xi1>
// CHECK:         %[[INS1:.*]] = vector.insert %[[SEL1]], %[[INS0]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK:         return %[[INS1]] : vector<2x3xi1>

func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
  %mask = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
  return %mask : vector<2x3xi1>
}
834
// A dynamic 3-D mask is built recursively:
// - The innermost 1-D mask is selected into the single middle slot.
// - The resulting 1x7 slab is selected into each of the two outer slots.
// CHECK-LABEL: func @genbool_var_3d(
// CHECK-SAME:    %[[D0:.*0]]: index,
// CHECK-SAME:    %[[D1:.*1]]: index,
// CHECK-SAME:    %[[D2:.*2]]: index)
// CHECK-DAG:     %[[FALSE1D:.*]] = arith.constant dense<false> : vector<7xi1>
// CHECK-DAG:     %[[FALSE2D:.*]] = arith.constant dense<false> : vector<1x7xi1>
// CHECK-DAG:     %[[FALSE3D:.*]] = arith.constant dense<false> : vector<2x1x7xi1>
// CHECK-DAG:     %[[IDX0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[IDX1:.*]] = arith.constant 1 : index
// CHECK:         %[[INNER:.*]] = vector.create_mask %[[D2]] : vector<7xi1>
// CHECK:         %[[CMP_MID:.*]] = arith.cmpi sgt, %[[D1]], %[[IDX0]] : index
// CHECK:         %[[SEL_MID:.*]] = arith.select %[[CMP_MID]], %[[INNER]], %[[FALSE1D]] : vector<7xi1>
// CHECK:         %[[SLAB:.*]] = vector.insert %[[SEL_MID]], %[[FALSE2D]] [0] : vector<7xi1> into vector<1x7xi1>
// CHECK:         %[[CMP0:.*]] = arith.cmpi sgt, %[[D0]], %[[IDX0]] : index
// CHECK:         %[[SEL0:.*]] = arith.select %[[CMP0]], %[[SLAB]], %[[FALSE2D]] : vector<1x7xi1>
// CHECK:         %[[OUT0:.*]] = vector.insert %[[SEL0]], %[[FALSE3D]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK:         %[[CMP1:.*]] = arith.cmpi sgt, %[[D0]], %[[IDX1]] : index
// CHECK:         %[[SEL1:.*]] = arith.select %[[CMP1]], %[[SLAB]], %[[FALSE2D]] : vector<1x7xi1>
// CHECK:         %[[OUT1:.*]] = vector.insert %[[SEL1]], %[[OUT0]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK:         return %[[OUT1]] : vector<2x1x7xi1>

func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
  %mask = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
  return %mask : vector<2x1x7xi1>
}
860
// The reduction dim d0 has size 1 and appears only in the lhs map. For each
// rhs row, the expected lowering:
// - multiplies it by the single lhs row,
// - reduces the product,
// - inserts the result into a zero vector.
// The accumulator is added once at the end.
// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
// CHECK-SAME:    (%[[LHS:.+]]: vector<1x2xi32>, %[[RHS:.+]]: vector<2x2xi32>, %[[ACC:.+]]: vector<2xi32>)
// CHECK-DAG:     %[[ZERO:.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-DAG:     %[[LROW:.+]] = vector.extract %[[LHS]][0] : vector<1x2xi32>
// CHECK-DAG:     %[[RROW0:.+]] = vector.extract %[[RHS]][0] : vector<2x2xi32>
// CHECK:         %[[MUL0:.+]] = arith.muli %[[RROW0]], %[[LROW]] : vector<2xi32>
// CHECK:         %[[RED0:.+]] = vector.reduction <add>, %[[MUL0]] : vector<2xi32> into i32
// CHECK:         %[[INS0:.+]] = vector.insert %[[RED0]], %[[ZERO]] [0] : i32 into vector<2xi32>
// CHECK:         %[[RROW1:.+]] = vector.extract %[[RHS]][1] : vector<2x2xi32>
// CHECK:         %[[MUL1:.+]] = arith.muli %[[RROW1]], %[[LROW]] : vector<2xi32>
// CHECK:         %[[RED1:.+]] = vector.reduction <add>, %[[MUL1]] : vector<2xi32> into i32
// CHECK:         %[[INS1:.+]] = vector.insert %[[RED1]], %[[INS0]] [1] : i32 into vector<2xi32>
// CHECK:         %[[SUM:.+]] = arith.addi %[[INS1]], %[[ACC]] : vector<2xi32>
// CHECK:         return %[[SUM]] : vector<2xi32>

func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
  %contracted = vector.contract {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d1)>],
      iterator_types = ["reduction", "parallel", "reduction"],
      kind = #vector.kind<add>}
    %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
  return %contracted : vector<2xi32>
}
888
// Layout: lhs (m, k), rhs (k, n), acc/result (m, n).
#matmat_accesses_0 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (m, n)>]
#matmat_trait_0 = {indexing_maps = #matmat_accesses_0,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// With k = 1, the lhs is transposed and one vector.outerproduct of its only
// slice with the only rhs row is expected.
// OUTERPRODUCT-LABEL: func @matmul_0
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT:         %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT:         %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[LCOL]], %[[RROW]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<2x3xf32>
func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %result = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %result : vector<2x3xf32>
}
915
// Mixed precision: the f16 operand slices are widened with arith.extf to f32
// before the outer product into the f32 accumulator.
// OUTERPRODUCT-LABEL: func @matmul_0_mixed
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf16>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT:         %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf16>
// OUTERPRODUCT:         %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf16>
// OUTERPRODUCT:         %[[LCOL_EXT:.*]] = arith.extf %[[LCOL]] : vector<2xf16> to vector<2xf32>
// OUTERPRODUCT:         %[[RROW_EXT:.*]] = arith.extf %[[RROW]] : vector<3xf16> to vector<3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[LCOL_EXT]], %[[RROW_EXT]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<2x3xf32>
func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>,
                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %result = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
      : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
  return %result : vector<2x3xf32>
}
934
// Layout: lhs (m, k), rhs (n, k), acc/result (m, n).
#matmat_accesses_1 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (n, k)>,
                      affine_map<(m, n, k) -> (m, n)>]
#matmat_trait_1 = {indexing_maps = #matmat_accesses_1,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// Both operands carry k innermost, so both are transposed before the outer product.
// OUTERPRODUCT-LABEL: func @matmul_1
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT:         %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0]
// OUTERPRODUCT:         %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT:         %[[RCOL:.*]] = vector.extract %[[RHS_T]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[LCOL]], %[[RCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<2x3xf32>
func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>,
                    %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %result = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %result : vector<2x3xf32>
}
962
// Layout: lhs (k, m), rhs (k, n), acc/result (m, n).
#matmat_accesses_2 = [affine_map<(m, n, k) -> (k, m)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (m, n)>]
#matmat_trait_2 = {indexing_maps = #matmat_accesses_2,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// Both operands already have k outermost, so no transpose is expected.
// OUTERPRODUCT-LABEL: func @matmul_2
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT:         %[[LROW:.*]] = vector.extract %[[LHS]][0] : vector<1x2xf32>
// OUTERPRODUCT:         %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[LROW]], %[[RROW]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<2x3xf32>
func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %result = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
      : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %result : vector<2x3xf32>
}
988
// Layout: lhs (k, m), rhs (n, k), acc/result (m, n).
#matmat_accesses_3 = [affine_map<(m, n, k) -> (k, m)>,
                      affine_map<(m, n, k) -> (n, k)>,
                      affine_map<(m, n, k) -> (m, n)>]
#matmat_trait_3 = {indexing_maps = #matmat_accesses_3,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// Only the rhs needs a transpose to bring k outermost.
// OUTERPRODUCT-LABEL: func @matmul_3
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT:         %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0]
// OUTERPRODUCT:         %[[LROW:.*]] = vector.extract %[[LHS]][0] : vector<1x2xf32>
// OUTERPRODUCT:         %[[RCOL:.*]] = vector.extract %[[RHS_T]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[LROW]], %[[RCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<2x3xf32>
func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>,
                    %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %result = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
      : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %result : vector<2x3xf32>
}
1015
// Layout: lhs (m, k), rhs (k, n), acc/result transposed to (n, m).
#matmat_accesses_4 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (n, m)>]
#matmat_trait_4 = {indexing_maps = #matmat_accesses_4,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// Transposed result: the rhs slice becomes the first outer-product operand.
// OUTERPRODUCT-LABEL: func @matmul_4
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT:         %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[RROW]], %[[LCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<3x2xf32>
func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
  %result = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %result : vector<3x2xf32>
}
1042
// NOTE(review): these indexing maps are identical to #matmat_accesses_4, so
// this case exercises the same (m, k) x (k, n) -> (n, m) layout as @matmul_4.
// Presumably a different lhs/rhs permutation was intended. Confirm before
// changing, because the expected IR below would change too.
#matmat_accesses_5 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (n, m)>]
#matmat_trait_5 = {indexing_maps = #matmat_accesses_5,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// OUTERPRODUCT-LABEL: func @matmul_5
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT-DAG:     %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG:     %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[RROW]], %[[LCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<3x2xf32>
func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
  %result = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %result : vector<3x2xf32>
}
1069
// NOTE(review): these indexing maps are identical to #matmat_accesses_4, so
// this case exercises the same (m, k) x (k, n) -> (n, m) layout as @matmul_4.
// Presumably a different lhs/rhs permutation was intended. Confirm before
// changing, because the expected IR below would change too.
#matmat_accesses_6 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (n, m)>]
#matmat_trait_6 = {indexing_maps = #matmat_accesses_6,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// OUTERPRODUCT-LABEL: func @matmul_6
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT-DAG:     %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG:     %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[RROW]], %[[LCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<3x2xf32>
func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
  %result = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %result : vector<3x2xf32>
}
1096
// NOTE(review): these indexing maps are identical to #matmat_accesses_4, so
// this case exercises the same (m, k) x (k, n) -> (n, m) layout as @matmul_4.
// Presumably a different lhs/rhs permutation was intended. Confirm before
// changing, because the expected IR below would change too.
#matmat_accesses_7 = [affine_map<(m, n, k) -> (m, k)>,
                      affine_map<(m, n, k) -> (k, n)>,
                      affine_map<(m, n, k) -> (n, m)>]
#matmat_trait_7 = {indexing_maps = #matmat_accesses_7,
                   iterator_types = ["parallel", "parallel", "reduction"]}

// OUTERPRODUCT-LABEL: func @matmul_7
// OUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT:         %[[LHS_T:.*]] = vector.transpose %[[LHS]], [1, 0]
// OUTERPRODUCT-DAG:     %[[LCOL:.*]] = vector.extract %[[LHS_T]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG:     %[[RROW:.*]] = vector.extract %[[RHS]][0] : vector<1x3xf32>
// OUTERPRODUCT:         %[[OP:.*]] = vector.outerproduct %[[RROW]], %[[LCOL]], %[[ACC]]
// OUTERPRODUCT:         return %[[OP]] : vector<3x2xf32>
func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>,
                    %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
  %result = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
      : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %result : vector<3x2xf32>
}
1123
// Under the filtered outer-product lowering, the checks require that a
// vector.contract over the original arguments is still present.
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
// FILTEROUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<4x4xf32>
// FILTEROUTERPRODUCT:         %[[RES:.*]] = vector.contract {{{.*}}} %[[LHS]], %[[RHS]], %[[ACC]]
func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>,
                             %arg2: vector<4x4xf32>) -> vector<4x4xf32> {
  %result = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
      : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  return %result : vector<4x4xf32>
}
1136
// NOTE(review): despite the name, the expectation here matches
// @matmul_4_filtered: a vector.contract over the original arguments must
// appear. Nothing checks for vector.outerproduct. Confirm this is intended.
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
// FILTEROUTERPRODUCT-SAME:    %[[LHS:[a-zA-Z0-9]*]]: vector<3x4xf32>,
// FILTEROUTERPRODUCT-SAME:    %[[RHS:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME:    %[[ACC:[a-zA-Z0-9]*]]: vector<3x4xf32>
// FILTEROUTERPRODUCT:         %[[RES:.*]] = vector.contract {{{.*}}} %[[LHS]], %[[RHS]], %[[ACC]]
func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>,
                                 %arg2: vector<3x4xf32>) -> vector<3x4xf32> {
  %result = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
      : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
  return %result : vector<3x4xf32>
}
1149
// The non-reduced dim d0 is innermost in both operands and both reduction dims
// have size 1, so a single vector.fma on the extracted 1-D slices is expected.
// PARALLEL-LABEL: func @parrallel_contract_lowering
//       PARALLEL:   %[[LHS_SLICE:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[RHS_SLICE:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[FMA:.*]] = vector.fma %[[LHS_SLICE]], %[[RHS_SLICE]], %{{.*}} : vector<4xf32>
//       PARALLEL:   return %[[FMA]] : vector<4xf32>
func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %result = vector.contract {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d0)>],
      iterator_types = ["parallel", "reduction", "reduction"],
      kind = #vector.kind<add>}
    %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %result : vector<4xf32>
}
1159
// The 2-D lhs lacks d0, so it is broadcast to 4x1x1 and then transposed to
// the 1x1x4 operand layout before the vector.fma.
// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast
//       PARALLEL:   %[[BCAST:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
//       PARALLEL:   %[[BCAST_T:.*]] = vector.transpose %[[BCAST]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
//       PARALLEL:   %[[LHS_SLICE:.*]] = vector.extract %[[BCAST_T]][0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[RHS_SLICE:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[FMA:.*]] = vector.fma %[[LHS_SLICE]], %[[RHS_SLICE]], %{{.*}} : vector<4xf32>
//       PARALLEL:   return %[[FMA]] : vector<4xf32>
func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %result = vector.contract {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d0)>],
      iterator_types = ["parallel", "reduction", "reduction"],
      kind = #vector.kind<add>}
    %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %result : vector<4xf32>
}
1171
// The lhs lacks d0, so it is broadcast and then transposed. The rhs has d0 in
// the middle, (d1, d0, d2), so it is transposed from 1x4x1 to 1x1x4. One
// vector.fma of the two extracted 1-D slices is expected.
// PARALLEL-LABEL: func @parrallel_contract_lowering_transpose
//       PARALLEL:   %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
//       PARALLEL:   %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
//       PARALLEL:   %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
//       PARALLEL:   %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32>
//       PARALLEL:   %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
//       PARALLEL:   return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}
1184
// Both iterators are reductions over 1x1 operands and the result is a scalar,
// so a scalar arith.mulf of the two extracted elements followed by an
// arith.addf is expected.
// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar
//       PARALLEL:   %[[LHS_ELT:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
//       PARALLEL:   %[[RHS_ELT:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
//       PARALLEL:   %[[PROD:.*]] = arith.mulf %[[LHS_ELT]], %[[RHS_ELT]] : f32
//       PARALLEL:   %[[SUM:.*]] = arith.addf %[[PROD]], %{{.*}} : f32
//       PARALLEL:   return %[[SUM]] : f32
func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
  %dot = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
      %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
  return %dot : f32
}
1200