1// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
2// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
3// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
4// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
5
// Indexing maps for a 1-D dot product: both operands are read at the single
// dimension i (a reduction) and the accumulator/result is a scalar ().
6#dotp_accesses = [
7  affine_map<(i) -> (i)>,
8  affine_map<(i) -> (i)>,
9  affine_map<(i) -> ()>
10]
11#dotp_trait = {
12  indexing_maps = #dotp_accesses,
13  iterator_types = ["reduction"]
14}
15
// Float dot product with a scalar accumulator. The default lowering is
// expected to emit an elementwise arith.mulf, a vector.reduction "add" over
// the product, and a final scalar arith.addf with the accumulator.
16// CHECK-LABEL: func @extract_contract1
17// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
18// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
19// CHECK-SAME: %[[C:.*2]]: f32
20// CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
21// CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32
22// CHECK:      %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32
23// CHECK:      return %[[ACC]] : f32
24
25func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
26  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
27    : vector<4xf32>, vector<4xf32> into f32
28  return %0 : f32
29}
30
// Integer dot product with a scalar accumulator: same shape of lowering as
// the f32 case, using arith.muli / arith.addi instead of the float ops.
31// CHECK-LABEL: func @extract_contract1_int
32// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
33// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
34// CHECK-SAME: %[[C:.*2]]: i32
35// CHECK:      %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
36// CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32
37// CHECK:      %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32
38// CHECK:      return %[[ACC]] : i32
39
40func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
41  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
42    : vector<4xi32>, vector<4xi32> into i32
43  return %0 : i32
44}
45
// Indexing maps for a matrix-vector product: the matrix is read at (i, j),
// the vector at j (the reduction dim) and the result at i (the parallel dim).
46#matvec_accesses = [
47  affine_map<(i, j) -> (i, j)>,
48  affine_map<(i, j) -> (j)>,
49  affine_map<(i, j) -> (i)>
50]
51#matvec_trait = {
52  indexing_maps = #matvec_accesses,
53  iterator_types = ["parallel", "reduction"]
54}
55
// Float matrix-vector product (2x3 times 3). Each row of A is extracted,
// multiplied elementwise with B and reduced to a scalar. That scalar is
// inserted into a zero vector, and the accumulator C is added once at the end.
56// CHECK-LABEL: func @extract_contract2
57// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
58// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
59// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
60// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
61// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
62// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
63// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
64// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
65// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
66// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
67// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
68// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
69// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
70// CHECK:      return %[[T10]] : vector<2xf32>
71
72func @extract_contract2(%arg0: vector<2x3xf32>,
73                        %arg1: vector<3xf32>,
74                        %arg2: vector<2xf32>) -> vector<2xf32> {
75  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
76    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
77  return %0 : vector<2xf32>
78}
79
// Integer matrix-vector product: same row-by-row multiply/reduce/insert
// lowering as the f32 case, using arith.muli / arith.addi.
80// CHECK-LABEL: func @extract_contract2_int
81// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
82// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
83// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
84// CHECK:      %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
85// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32>
86// CHECK:      %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
87// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32
88// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
89// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32>
90// CHECK:      %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
91// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32
92// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
93// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
94// CHECK:      return %[[T10]] : vector<2xi32>

95func @extract_contract2_int(%arg0: vector<2x3xi32>,
96                            %arg1: vector<3xi32>,
97                            %arg2: vector<2xi32>) -> vector<2xi32> {
98  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
99    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
100  return %0 : vector<2xi32>
101}
102
// Indexing maps for a vector-matrix product: the vector is read at j (the
// reduction dim), the matrix at (i, j) and the result at i (the parallel dim).
103#vecmat_accesses = [
104  affine_map<(i, j) -> (j)>,
105  affine_map<(i, j) -> (i, j)>,
106  affine_map<(i, j) -> (i)>
107]
108#vecmat_trait = {
109  indexing_maps = #vecmat_accesses,
110  iterator_types = ["parallel", "reduction"]
111}
112
// Vector-matrix product. Rows are extracted from the matrix operand B,
// multiplied elementwise with the vector A and reduced. The results are
// inserted into a zero vector, and C is added at the end.
113// CHECK-LABEL: func @extract_contract3
114// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
115// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
116// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
117// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
118// CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
119// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
120// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
121// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
122// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
123// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
124// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
125// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
126// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
127// CHECK:      return %[[T10]] : vector<2xf32>
128
129func @extract_contract3(%arg0: vector<3xf32>,
130                        %arg1: vector<2x3xf32>,
131                        %arg2: vector<2xf32>) -> vector<2xf32> {
132  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
133    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
134  return %0 : vector<2xf32>
135}
136
// Indexing maps for a plain matmul C(i, j) += A(i, k) * B(k, j): i and j are
// parallel dimensions and k is the reduction dimension.
137#matmat_accesses = [
138  affine_map<(i, j, k) -> (i, k)>,
139  affine_map<(i, j, k) -> (k, j)>,
140  affine_map<(i, j, k) -> (i, j)>
141]
142#matmat_trait = {
143  indexing_maps = #matmat_accesses,
144  iterator_types = ["parallel", "parallel", "reduction"]
145}
146
// 2x2 matmul under the default lowering. B is transposed once, so each (i, j)
// entry is the dot product of row i of A with row j of the transposed B. Each
// dot product is computed with mulf + vector.reduction and inserted into a
// zero matrix, and C is added at the end.
147// CHECK-LABEL: func @extract_contract4
148// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
149// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
150// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
151// CHECK:    %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
152// CHECK:    %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
153// CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
154// CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
155// CHECK:    %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
156// CHECK:    %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xf32> into f32
157// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
158//
159// CHECK:    %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
160// CHECK:    %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
161// CHECK:    %[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xf32> into f32
162// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
163//
164// CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
165// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
166// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
167// CHECK:    %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xf32> into f32
168// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
169//
170// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
171// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
172// CHECK:    %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xf32> into f32
173// CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
174//
175// CHECK:    %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
176// CHECK:    return %[[T52]] : vector<2x2xf32>
177
178func @extract_contract4(%arg0: vector<2x2xf32>,
179                        %arg1: vector<2x2xf32>,
180                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
181  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
182    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
183  return %0 : vector<2x2xf32>
184}
185
// Indexing maps for a full contraction of two identically indexed 2-D
// operands down to a scalar; both dimensions are reductions.
186#contraction2d_accesses = [
187  affine_map<(i, j) -> (i, j)>,
188  affine_map<(i, j) -> (i, j)>,
189  affine_map<(i, j) -> ()>
190]
191#contraction2d_trait = {
192  indexing_maps = #contraction2d_accesses,
193  iterator_types = ["reduction", "reduction"]
194}
195
// Full 2-D contraction to a scalar. Each pair of rows is multiplied
// elementwise and reduced. The partial sums are chained through scalar
// arith.addf, starting from the accumulator C.
196// CHECK-LABEL: func @full_contract1
197// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
198// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
199// CHECK-SAME: %[[C:.*2]]: f32
200// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
201// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
202// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
203// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
204// CHECK:      %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32
205// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
206// CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
207// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
208// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
209// CHECK:      %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32
210// CHECK:      return %[[T9]] : f32
211
212func @full_contract1(%arg0: vector<2x3xf32>,
213                     %arg1: vector<2x3xf32>,
214                     %arg2: f32) -> f32 {
215  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
216    : vector<2x3xf32>, vector<2x3xf32> into f32
217  return %0 : f32
218}
219
// Indexing maps for a full 2-D contraction to a scalar where the second
// operand is read transposed, at (j, i); both dimensions are reductions.
220#contraction2d_trans_accesses = [
221  affine_map<(i, j) -> (i, j)>,
222  affine_map<(i, j) -> (j, i)>,
223  affine_map<(i, j) -> ()>
224]
225#contraction2d_trans_trait = {
226  indexing_maps = #contraction2d_trans_accesses,
227  iterator_types = ["reduction", "reduction"]
228}
229
// Full contraction where B is accessed transposed. For each row of A, the
// matching column of B is gathered element by element into a zero
// vector<3xf32>. The row and column are then multiplied and reduced, and the
// partial sums are chained through arith.addf starting from C.
230// CHECK-LABEL: func @full_contract2
231// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
232// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
233// CHECK-SAME: %[[C:.*2]]: f32
234// CHECK:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
235// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
236// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32>
237// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
238// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32>
239// CHECK:      %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
240// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
241// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
242// CHECK:      %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
243// CHECK:      %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32
244// CHECK:      %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32
245//
246// CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
247// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32>
248// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
249// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32>
250// CHECK:      %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
251// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
252// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
253// CHECK:      %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
254// CHECK:      %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32
255// CHECK:      %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32
256// CHECK:      return %[[ACC1]] : f32
257
258func @full_contract2(%arg0: vector<2x3xf32>,
259                     %arg1: vector<3x2xf32>,
260                     %arg2: f32) -> f32 {
261  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
262    : vector<2x3xf32>, vector<3x2xf32> into f32
263  return %0 : f32
264}
265
// Float vector.outerproduct without an accumulator. Each element of A is
// splatted to B's length and multiplied with B, producing one row of the
// result; each row is inserted into a zero matrix.
266// CHECK-LABEL: func @outerproduct_noacc
267// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
268// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
269// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
270// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
271// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
272// CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
273// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
274// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
275// CHECK:      %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
276// CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
277// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
278// CHECK:      return %[[T7]] : vector<2x3xf32>
279
280func @outerproduct_noacc(%arg0: vector<2xf32>,
281                         %arg1: vector<3xf32>) -> vector<2x3xf32> {
282  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
283  return %0: vector<2x3xf32>
284}
285
// Float vector.outerproduct with an accumulator. Each result row is a
// vector.fma of the splatted A element, B, and the matching row of C. The
// rows are inserted into a zero matrix.
286// CHECK-LABEL: func @outerproduct_acc
287// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
288// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
289// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
290// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
291// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
292// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
293// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
294// CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
295// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
296// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
297// CHECK:      %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
298// CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
299// CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
300// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
301// CHECK:      return %[[T9]] : vector<2x3xf32>
302
303func @outerproduct_acc(%arg0: vector<2xf32>,
304                       %arg1: vector<3xf32>,
305                       %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
306  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
307  return %0: vector<2x3xf32>
308}
309
// Integer vector.outerproduct without an accumulator: same row-by-row
// lowering as the f32 case, using arith.muli.
310// CHECK-LABEL: func @outerproduct_noacc_int
311// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
312// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
313// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
314// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
315// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
316// CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
317// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
318// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
319// CHECK:      %[[T5:.*]] = splat %[[T4]] : vector<3xi32>
320// CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
321// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
322// CHECK:      return %[[T7]] : vector<2x3xi32>
323func @outerproduct_noacc_int(%arg0: vector<2xi32>,
324                             %arg1: vector<3xi32>) -> vector<2x3xi32> {
325  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
326  return %0: vector<2x3xi32>
327}
328
// Integer vector.outerproduct with an accumulator. Unlike the f32 case, no
// vector.fma is emitted: each row is an arith.muli followed by an arith.addi
// with the matching row of C.
329// CHECK-LABEL: func @outerproduct_acc_int
330// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
331// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
332// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
333// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
334// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
335// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
336// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
337// CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
338// CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
339// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
340// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
341// CHECK:      %[[T7:.*]] = splat %[[T6]] : vector<3xi32>
342// CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
343// CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
344// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
345// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
346// CHECK:      return %[[T11]] : vector<2x3xi32>
347func @outerproduct_acc_int(%arg0: vector<2xi32>,
348                           %arg1: vector<3xi32>,
349                           %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
350  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
351  return %0: vector<2x3xi32>
352}
353
// vector.outerproduct of a float vector with a scalar and no accumulator:
// lowered to a splat of the scalar followed by an elementwise arith.mulf.
354// CHECK-LABEL: func @axpy_fp(
355// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
356// CHECK-SAME: %[[B:.*1]]: f32)
357// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
358// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
359// CHECK: return %[[T1]] : vector<16xf32>
360func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
361   %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
362   return %0: vector<16xf32>
363}
364
// Vector-times-scalar float vector.outerproduct with an accumulator: lowered
// to a splat of the scalar and a single vector.fma with the accumulator.
365// CHECK-LABEL: func @axpy_fp_add(
366// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
367// CHECK-SAME: %[[B:.*1]]: f32,
368// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
369// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
370// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
371// CHECK: return %[[T1]] : vector<16xf32>
372func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
373   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
374   return %0: vector<16xf32>
375}
376
// Vector-times-scalar integer vector.outerproduct without an accumulator:
// a splat of the scalar followed by an elementwise arith.muli.
377// CHECK-LABEL: func @axpy_int(
378// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
379// CHECK-SAME: %[[B:.*1]]: i32)
380// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
381// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
382// CHECK: return %[[T1]] : vector<16xi32>
383func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
384   %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
385   return %0: vector<16xi32>
386}
387
// Vector-times-scalar integer vector.outerproduct with an accumulator. There
// is no integer fma here: a splat, then arith.muli, then arith.addi with the
// accumulator.
388// CHECK-LABEL: func @axpy_int_add(
389// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
390// CHECK-SAME: %[[B:.*1]]: i32,
391// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
392// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
393// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
394// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
395// CHECK: return %[[T2]] : vector<16xi32>
396func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
397   %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
398   return %0: vector<16xi32>
399}
400
// A shape_cast to the identical type is expected to disappear entirely: the
// function returns its argument directly.
401// CHECK-LABEL: func @nop_shape_cast
402// CHECK-SAME: %[[A:.*]]: vector<16xf32>
403// CHECK:      return %[[A]] : vector<16xf32>
404
405func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
406  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
407  return %0 : vector<16xf32>
408}
409
// A 16 -> 4x4 -> 16 round trip of shape_casts is meant to cancel out, so that
// the argument is returned directly. Only the label is checked for now. The
// two expected-output lines below deliberately use a misspelled prefix so
// that FileCheck ignores them until the FIXME is resolved.
410// CHECK-LABEL: func @cancel_shape_cast
411// FIXME: PR49590
412// HECK-SAME: %[[A:.*]]: vector<16xf32>
413// HECK:      return %[[A]] : vector<16xf32>
414
415func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
416  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
417  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
418  return %1 : vector<16xf32>
419}
420
421// Shape up- and down-casts between 2-D and 1-D vectors, used to support
422// conversion to llvm.matrix operations.
// The 2x2 -> 4 cast lowers to row extracts plus insert_strided_slice. The
// 4 -> 2x2 cast lowers to extract_strided_slice plus row inserts.
423// CHECK-LABEL: func @shape_casts
424func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
425  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
426  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
427  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32>
428  //
429  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
430  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
431  //
432  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32>
433  //
434  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
435  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
436  //
437  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
438  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
439  %r0 = arith.addf %0, %0: vector<4xf32>
440  //
441  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
442  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
443  // CHECK-SAME: vector<4xf32> to vector<2xf32>
444  //
445  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
446  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
447  //
448  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
449  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
450  // CHECK-SAME: vector<4xf32> to vector<2xf32>
451  //
452  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
453  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
454  //
455  %1 = vector.shape_cast %r0  : vector<4xf32> to vector<2x2xf32>
456  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
457  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
458}
459
// 3x2 -> 2x3 shape_cast, lowered one scalar at a time. Elements are visited
// in row-major order of the source, and each is inserted at the same linear
// (row-major) position of a zero 2x3 result.
460// CHECK-LABEL: func @shape_cast_2d2d
461// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
462// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
463// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
464// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
465// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
466// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
467// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
468// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
469// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
470// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
471// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
472// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
473// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
474// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
475// CHECK: return %[[T11]] : vector<2x3xf32>
476
477func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
478  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
479  return %s : vector<2x3xf32>
480}
481
// 1x3x2 -> 6 shape_cast. Each scalar of the 3-D source is extracted in
// row-major order and inserted at consecutive positions of a zero vector<6xf32>.
482// CHECK-LABEL: func @shape_cast_3d1d
483// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
484// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
485// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
486// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
487// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
488// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
489// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
490// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
491// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
492// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
493// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
494// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
495// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
496// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
497// CHECK: return %[[T11]] : vector<6xf32>
498
499func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
500  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
501  return %s : vector<6xf32>
502}
503
// 6 -> 2x1x3 shape_cast. Each scalar of the 1-D source is extracted in order
// and inserted at the matching row-major position of a zero vector<2x1x3xf32>.
504// CHECK-LABEL: func @shape_cast_1d3d
505// CHECK-SAME: %[[A:.*]]: vector<6xf32>
506// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
507// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
508// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
509// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
510// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
511// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
512// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
513// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
514// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
515// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
516// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
517// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
518// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
519// CHECK: return %[[T11]] : vector<2x1x3xf32>
520
521func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
522  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
523  return %s : vector<2x1x3xf32>
524}
525
526// MATRIX-LABEL: func @matmul
527// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
528// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
529// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
530//      MATRIX:  %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
531//      MATRIX:  %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
532//      MATRIX:  %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
533//      MATRIX:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
534//      MATRIX:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
535//      MATRIX:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
536//      MATRIX:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
537//      MATRIX:  %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
538//      MATRIX:  %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
539//      MATRIX:  %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
540//      MATRIX:  %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
541//      MATRIX:  %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
542//      MATRIX:  %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
543//      MATRIX:  %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
544//      MATRIX:  %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
545//      MATRIX:  %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
546//      MATRIX:  %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
547//      MATRIX:  %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
548//      MATRIX:  %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
549//      MATRIX:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
550//      MATRIX:  %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
551
552// OUTERPRODUCT-LABEL: func @matmul
553// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
554// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
555// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
556//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
557// OUTERPRODUCT-SAME:  : vector<2x4xf32> to vector<4x2xf32>
558//
559//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32>
560//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
561//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
562// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
563//
564//      OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32>
565//      OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
566//      OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
567// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
568//
569//      OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32>
570//      OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
571//      OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
572// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
573//
574//      OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32>
575//      OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
576//      OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
577// OUTERPRODUCT-SAME:  : vector<2xf32>, vector<3xf32>
578//
579//      OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>
580
// NOTE: no RUN line in this file uses the REDUCE prefix, so the directives
// below are currently not exercised by FileCheck.
//
// REDUCE-LABEL: func @matmul
// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//
//      REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
//      REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// REDUCE-SAME:  : vector<4x3xf32> to vector<3x4xf32>
//
//      REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32>
// REDUCE-NEXT: %[[ab00:.*]] = arith.mulf %[[a0]], %[[b0]] : vector<4xf32>
// REDUCE-NEXT: %[[s00:.*]] = vector.reduction "add", %[[ab00]] : vector<4xf32> into f32
// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32>
//
//      ...
//
//      REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32>
// REDUCE-NEXT: %[[ab12:.*]] = arith.mulf %[[a1]], %[[b2]] : vector<4xf32>
// REDUCE-NEXT: %[[s12:.*]] = vector.reduction "add", %[[ab12]] : vector<4xf32> into f32
// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32>
//
//      REDUCE: return %[[r12]] : vector<2x3xf32>
// Contracts a 2x4 lhs with a 4x3 rhs into a 2x3 accumulator using the
// #matmat_trait indexing/iterator description.
func @matmul(%lhs: vector<2x4xf32>, %rhs: vector<4x3xf32>,
             %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %product = vector.contract #matmat_trait %lhs, %rhs, %acc
    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
  return %product : vector<2x3xf32>
}
612
613// CHECK-LABEL: func @broadcast_vec1d_from_scalar
614// CHECK-SAME: %[[A:.*0]]: f32
615// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2xf32>
616// CHECK:      return %[[T0]] : vector<2xf32>
617
// Broadcast of a scalar f32 into a 1-D vector<2xf32>.
func @broadcast_vec1d_from_scalar(%scalar: f32) -> vector<2xf32> {
  %splatted = vector.broadcast %scalar : f32 to vector<2xf32>
  return %splatted : vector<2xf32>
}
622
623// CHECK-LABEL: func @broadcast_vec2d_from_scalar
624// CHECK-SAME: %[[A:.*0]]: f32
625// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
626// CHECK:      return %[[T0]] : vector<2x3xf32>
627
// Broadcast of a scalar f32 into a 2-D vector<2x3xf32>.
func @broadcast_vec2d_from_scalar(%scalar: f32) -> vector<2x3xf32> {
  %splatted = vector.broadcast %scalar : f32 to vector<2x3xf32>
  return %splatted : vector<2x3xf32>
}
632
633// CHECK-LABEL: func @broadcast_vec3d_from_scalar
634// CHECK-SAME: %[[A:.*0]]: f32
635// CHECK:      %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
636// CHECK:      return %[[T0]] : vector<2x3x4xf32>
637
// Broadcast of a scalar f32 into a 3-D vector<2x3x4xf32>.
func @broadcast_vec3d_from_scalar(%scalar: f32) -> vector<2x3x4xf32> {
  %splatted = vector.broadcast %scalar : f32 to vector<2x3x4xf32>
  return %splatted : vector<2x3x4xf32>
}
642
643// CHECK-LABEL: func @broadcast_vec1d_from_vec1d
644// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
645// CHECK:      return %[[A]] : vector<2xf32>
646
// Identity broadcast: source and result types are both vector<2xf32>.
func @broadcast_vec1d_from_vec1d(%src: vector<2xf32>) -> vector<2xf32> {
  %same = vector.broadcast %src : vector<2xf32> to vector<2xf32>
  return %same : vector<2xf32>
}
651
652// CHECK-LABEL: func @broadcast_vec2d_from_vec1d
653// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
654// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
655// CHECK:      %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
656// CHECK:      %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
657// CHECK:      %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
658// CHECK:      return %[[T2]] : vector<3x2xf32>
659
// Adds one leading dimension: vector<2xf32> replicated into vector<3x2xf32>.
func @broadcast_vec2d_from_vec1d(%row: vector<2xf32>) -> vector<3x2xf32> {
  %rows = vector.broadcast %row : vector<2xf32> to vector<3x2xf32>
  return %rows : vector<3x2xf32>
}
664
665// CHECK-LABEL: func @broadcast_vec3d_from_vec1d
666// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
667// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
668// CHECK:      %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
669// CHECK:      %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
670// CHECK:      %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
671// CHECK:      %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
672// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
673// CHECK:      %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
674// CHECK:      %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
675// CHECK:      %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK:      return %[[T6]] : vector<4x3x2xf32>
677
// Adds two leading dimensions: vector<2xf32> replicated into vector<4x3x2xf32>.
func @broadcast_vec3d_from_vec1d(%row: vector<2xf32>) -> vector<4x3x2xf32> {
  %cube = vector.broadcast %row : vector<2xf32> to vector<4x3x2xf32>
  return %cube : vector<4x3x2xf32>
}
682
683// CHECK-LABEL: func @broadcast_vec3d_from_vec2d
684// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32>
685// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
686// CHECK:      %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
687// CHECK:      %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
688// CHECK:      %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
689// CHECK:      %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
690// CHECK:      return %[[T3]] : vector<4x3x2xf32>
691
// Adds one leading dimension: vector<3x2xf32> replicated into vector<4x3x2xf32>.
func @broadcast_vec3d_from_vec2d(%plane: vector<3x2xf32>) -> vector<4x3x2xf32> {
  %cube = vector.broadcast %plane : vector<3x2xf32> to vector<4x3x2xf32>
  return %cube : vector<4x3x2xf32>
}
696
697// CHECK-LABEL: func @broadcast_stretch
698// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
699// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
700// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<4xf32>
701// CHECK:      return %[[T1]] : vector<4xf32>
702
// Stretches the single unit dimension: vector<1xf32> -> vector<4xf32>.
func @broadcast_stretch(%unit: vector<1xf32>) -> vector<4xf32> {
  %stretched = vector.broadcast %unit : vector<1xf32> to vector<4xf32>
  return %stretched : vector<4xf32>
}
707
708// CHECK-LABEL: func @broadcast_stretch_at_start
709// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32>
710// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
711// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32>
712// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32>
713// CHECK:      %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32>
714// CHECK:      %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32>
715// CHECK:      return %[[T3]] : vector<3x4xf32>
716
// Stretches the leading unit dimension: vector<1x4xf32> -> vector<3x4xf32>.
func @broadcast_stretch_at_start(%src: vector<1x4xf32>) -> vector<3x4xf32> {
  %stretched = vector.broadcast %src : vector<1x4xf32> to vector<3x4xf32>
  return %stretched : vector<3x4xf32>
}
721
722// CHECK-LABEL: func @broadcast_stretch_at_end
723// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
724// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
725// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
726// CHECK:      %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
727// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
728// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
729// CHECK:      %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
730// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
731// CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
732// CHECK:      %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
733// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
734// CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
735// CHECK:      %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
736// CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
737// CHECK:      return %[[T15]] : vector<4x3xf32>
738
// Stretches the trailing unit dimension: vector<4x1xf32> -> vector<4x3xf32>.
func @broadcast_stretch_at_end(%src: vector<4x1xf32>) -> vector<4x3xf32> {
  %stretched = vector.broadcast %src : vector<4x1xf32> to vector<4x3xf32>
  return %stretched : vector<4x3xf32>
}
743
744// CHECK-LABEL: func @broadcast_stretch_in_middle
745// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32>
746// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
747// CHECK:      %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
748// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32>
749// CHECK:      %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
750// CHECK:      %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32>
751// CHECK:      %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32>
752// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
753// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32>
754// CHECK:      %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
755// CHECK:      %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32>
756// CHECK:      %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32>
757// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
758// CHECK:      %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32>
759// CHECK:      %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
760// CHECK:      %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32>
761// CHECK:      %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32>
762// CHECK:      %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
763// CHECK:      %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32>
764// CHECK:      %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32>
765// CHECK:      %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32>
766// CHECK:      %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32>
767// CHECK:      %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
768// CHECK:      return %[[T23]] : vector<4x3x2xf32>
769
// Stretches the middle unit dimension: vector<4x1x2xf32> -> vector<4x3x2xf32>.
func @broadcast_stretch_in_middle(%src: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
  %stretched = vector.broadcast %src : vector<4x1x2xf32> to vector<4x3x2xf32>
  return %stretched : vector<4x3x2xf32>
}
774
775// CHECK-LABEL: func @genbool_1d
776// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>
777// CHECK: return %[[T0]] : vector<8xi1>
778
// Constant 1-D mask: leading 4 of 8 lanes set.
func @genbool_1d() -> vector<8xi1> {
  %mask = vector.constant_mask [4] : vector<8xi1>
  return %mask : vector<8xi1>
}
783
784// CHECK-LABEL: func @genbool_2d
785// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
786// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
787// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
788// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
789// CHECK: return %[[T1]] : vector<4x4xi1>
790
// Constant 2-D mask: leading 2x2 region of a 4x4 mask set.
func @genbool_2d() -> vector<4x4xi1> {
  %mask = vector.constant_mask [2, 2] : vector<4x4xi1>
  return %mask : vector<4x4xi1>
}
795
796// CHECK-LABEL: func @genbool_3d
797// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
798// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
799// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
800// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
801// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
802// CHECK: return %[[T1]] : vector<2x3x4xi1>
803
// Constant 3-D mask with per-dimension bounds [1, 1, 3] on a 2x3x4 shape.
func @genbool_3d() -> vector<2x3x4xi1> {
  %mask = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
  return %mask : vector<2x3x4xi1>
}
808
809// CHECK-LABEL: func @genbool_var_1d(
810// CHECK-SAME: %[[A:.*]]: index)
811// CHECK:      %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
812// CHECK:      return %[[T0]] : vector<3xi1>
813
// 1-D mask whose bound is a runtime index value.
func @genbool_var_1d(%bound: index) -> vector<3xi1> {
  %mask = vector.create_mask %bound : vector<3xi1>
  return %mask : vector<3xi1>
}
818
819// CHECK-LABEL: func @genbool_var_2d(
820// CHECK-SAME: %[[A:.*0]]: index,
821// CHECK-SAME: %[[B:.*1]]: index)
822// CHECK:      %[[C1:.*]] = arith.constant dense<false> : vector<3xi1>
823// CHECK:      %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1>
824// CHECK:      %[[c0:.*]] = arith.constant 0 : index
825// CHECK:      %[[c1:.*]] = arith.constant 1 : index
826// CHECK:      %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
827// CHECK:      %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
828// CHECK:      %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
829// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
830// CHECK:      %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
831// CHECK:      %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
832// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
833// CHECK:      return %[[T6]] : vector<2x3xi1>
834
// 2-D mask whose per-dimension bounds are runtime index values.
func @genbool_var_2d(%rows: index, %cols: index) -> vector<2x3xi1> {
  %mask = vector.create_mask %rows, %cols : vector<2x3xi1>
  return %mask : vector<2x3xi1>
}
839
840// CHECK-LABEL: func @genbool_var_3d(
841// CHECK-SAME: %[[A:.*0]]: index,
842// CHECK-SAME: %[[B:.*1]]: index,
843// CHECK-SAME: %[[C:.*2]]: index)
844// CHECK-DAG:  %[[C1:.*]] = arith.constant dense<false> : vector<7xi1>
845// CHECK-DAG:  %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1>
846// CHECK-DAG:  %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1>
847// CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
848// CHECK-DAG:  %[[c1:.*]] = arith.constant 1 : index
849// CHECK:      %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
850// CHECK:      %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
851// CHECK:      %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
852// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
853// CHECK:      %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
854// CHECK:      %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
855// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
856// CHECK:      %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
857// CHECK:      %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
858// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
859// CHECK:      return %[[T9]] : vector<2x1x7xi1>
860
// 3-D mask whose per-dimension bounds are runtime index values.
func @genbool_var_3d(%b0: index, %b1: index, %b2: index) -> vector<2x1x7xi1> {
  %mask = vector.create_mask %b0, %b1, %b2 : vector<2x1x7xi1>
  return %mask : vector<2x1x7xi1>
}
865
// Iteration space: d0 = m (parallel), d1 = n (parallel), d2 = k (reduction).
// Operand indexing: lhs (m, k), rhs (k, n), accumulator (m, n).
#matmat_accesses_0 = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d2, d1)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]
#matmat_trait_0 = {
  indexing_maps = #matmat_accesses_0,
  iterator_types = ["parallel", "parallel", "reduction"]
}
875
876// OUTERPRODUCT-LABEL: func @matmul_0
877// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
878// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
879// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
880//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
881//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
882//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
883//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
884//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
// lhs (m, k) = 2x1, rhs (k, n) = 1x3, accumulator (m, n) = 2x3.
func @matmul_0(%lhs: vector<2x1xf32>, %rhs: vector<1x3xf32>,
               %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_0 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
892
// Iteration space: d0 = m (parallel), d1 = n (parallel), d2 = k (reduction).
// Operand indexing: lhs (m, k), rhs (n, k), accumulator (m, n).
#matmat_accesses_1 = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d1, d2)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]
#matmat_trait_1 = {
  indexing_maps = #matmat_accesses_1,
  iterator_types = ["parallel", "parallel", "reduction"]
}
902
903// OUTERPRODUCT-LABEL: func @matmul_1
904// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
905// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
906// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
907//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
908//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
909//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
910//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
911//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
912//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
// lhs (m, k) = 2x1, rhs (n, k) = 3x1, accumulator (m, n) = 2x3.
func @matmul_1(%lhs: vector<2x1xf32>, %rhs: vector<3x1xf32>,
               %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_1 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
920
// Iteration space: d0 = m (parallel), d1 = n (parallel), d2 = k (reduction).
// Operand indexing: lhs (k, m), rhs (k, n), accumulator (m, n).
#matmat_accesses_2 = [
  affine_map<(d0, d1, d2) -> (d2, d0)>,
  affine_map<(d0, d1, d2) -> (d2, d1)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]
#matmat_trait_2 = {
  indexing_maps = #matmat_accesses_2,
  iterator_types = ["parallel", "parallel", "reduction"]
}
930
931// OUTERPRODUCT-LABEL: func @matmul_2
932// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
933// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
934// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
935//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
936//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
937//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
938//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
// lhs (k, m) = 1x2, rhs (k, n) = 1x3, accumulator (m, n) = 2x3.
func @matmul_2(%lhs: vector<1x2xf32>, %rhs: vector<1x3xf32>,
               %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_2 %lhs, %rhs, %acc
    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
946
// Iteration space: d0 = m (parallel), d1 = n (parallel), d2 = k (reduction).
// Operand indexing: lhs (k, m), rhs (n, k), accumulator (m, n).
#matmat_accesses_3 = [
  affine_map<(d0, d1, d2) -> (d2, d0)>,
  affine_map<(d0, d1, d2) -> (d1, d2)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]
#matmat_trait_3 = {
  indexing_maps = #matmat_accesses_3,
  iterator_types = ["parallel", "parallel", "reduction"]
}
956
957// OUTERPRODUCT-LABEL: func @matmul_3
958// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
959// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
960// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
961//      OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
962//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
963//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
964//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
965//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
// lhs (k, m) = 1x2, rhs (n, k) = 3x1, accumulator (m, n) = 2x3.
func @matmul_3(%lhs: vector<1x2xf32>, %rhs: vector<3x1xf32>,
               %acc: vector<2x3xf32>) -> vector<2x3xf32> {
  %res = vector.contract #matmat_trait_3 %lhs, %rhs, %acc
    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %res : vector<2x3xf32>
}
973
// Iteration space: d0 = m (parallel), d1 = n (parallel), d2 = k (reduction).
// Operand indexing: lhs (m, k), rhs (k, n), transposed accumulator (n, m).
#matmat_accesses_4 = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d2, d1)>,
  affine_map<(d0, d1, d2) -> (d1, d0)>
]
#matmat_trait_4 = {
  indexing_maps = #matmat_accesses_4,
  iterator_types = ["parallel", "parallel", "reduction"]
}
983
984// OUTERPRODUCT-LABEL: func @matmul_4
985// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
986// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
987// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
988//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
989//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
990//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
991//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
992//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
// lhs (m, k) = 2x1, rhs (k, n) = 1x3, transposed accumulator (n, m) = 3x2.
func @matmul_4(%lhs: vector<2x1xf32>, %rhs: vector<1x3xf32>,
               %acc: vector<3x2xf32>) -> vector<3x2xf32> {
  %res = vector.contract #matmat_trait_4 %lhs, %rhs, %acc
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %res : vector<3x2xf32>
}
1000
// Transposed-accumulator case with lhs (m, k), rhs (n, k), accumulator (n, m).
// The expected lowering transposes both operands so each extracted row runs
// along k, and passes the rhs row first to vector.outerproduct so the result
// is laid out as (n, m).
#matmat_accesses_5 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_5 = {
  indexing_maps = #matmat_accesses_5,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_5
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
//      OUTERPRODUCT-DAG: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      OUTERPRODUCT-DAG: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<3x1xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
1027
// Transposed-accumulator case with lhs (k, m), rhs (k, n), accumulator (n, m).
// Both operands are already indexed with k outermost; the rhs row is expected
// to be the first vector.outerproduct operand so the result is laid out as
// (n, m).
#matmat_accesses_6 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_6 = {
  indexing_maps = #matmat_accesses_6,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_6
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func @matmul_6(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
1054
// Transposed-accumulator case with lhs (k, m), rhs (n, k), accumulator (n, m).
// Only the rhs needs a transpose to run along k; the transposed rhs row is
// expected to be the first vector.outerproduct operand so the result is laid
// out as (n, m).
#matmat_accesses_7 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_7 = {
  indexing_maps = #matmat_accesses_7,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_7
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
//      OUTERPRODUCT-DAG: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
//      OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
//      OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func @matmul_7(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<3x1xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
1081
1082// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
1083// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
1084// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
1085// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
1086//      FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
// 4x4 * 4x4 contraction (#matmat_trait_0); the expectation above is that the
// filtered outer-product lowering leaves this vector.contract in place.
func @matmul_4_filtered(%lhs: vector<4x4xf32>, %rhs: vector<4x4xf32>,
                        %acc: vector<4x4xf32>) -> vector<4x4xf32> {
  %res = vector.contract #matmat_trait_0 %lhs, %rhs, %acc
    : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  return %res : vector<4x4xf32>
}
1094
1095// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
1096// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
1097// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
1098// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
1099//      FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
// 3x4 * 4x4 contraction (#matmat_trait_0).
// NOTE(review): despite the name, the expectation above is that
// vector.contract survives the filtered lowering here too; confirm against
// the filter predicate used by the test pass.
func @matmul_4_not_filtered(%lhs: vector<3x4xf32>, %rhs: vector<4x4xf32>,
                            %acc: vector<3x4xf32>) -> vector<3x4xf32> {
  %res = vector.contract #matmat_trait_0 %lhs, %rhs, %acc
    : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
  return %res : vector<3x4xf32>
}
1107