1// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s
2
// Operand access pattern for A(i, j) * x(j) -> b(i), where (i, j) = (d0, d1).
#matvec_accesses = [
  affine_map<(d0, d1) -> (d0, d1)>,  // A
  affine_map<(d0, d1) -> (d1)>,      // x
  affine_map<(d0, d1) -> (d0)>       // b
]
// d0 is a parallel (output) dimension, d1 is reduced.
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}
// Same maps and iterators, but combining with `maxf` instead of the default.
#matvecmax_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"],
  kind = #vector.kind<maxf>
}
17
// A is read transposed: A(j, i) * x(j) -> b(i), where (i, j) = (d0, d1).
#mattransvec_accesses = [
  affine_map<(d0, d1) -> (d1, d0)>,  // A
  affine_map<(d0, d1) -> (d1)>,      // x
  affine_map<(d0, d1) -> (d0)>       // b
]
// d0 is a parallel (output) dimension, d1 is reduced.
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}
27
// Vector as the lhs operand: x(j) * A(i, j) -> b(i), where (i, j) = (d0, d1).
#vecmat_accesses = [
  affine_map<(d0, d1) -> (d1)>,      // x
  affine_map<(d0, d1) -> (d0, d1)>,  // A
  affine_map<(d0, d1) -> (d0)>       // b
]
// d0 is a parallel (output) dimension, d1 is reduced.
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}
37
// Vector lhs with a transposed matrix: x(j) * A(j, i) -> b(i),
// where (i, j) = (d0, d1).
#vecmattrans_accesses = [
  affine_map<(d0, d1) -> (d1)>,      // x
  affine_map<(d0, d1) -> (d1, d0)>,  // A
  affine_map<(d0, d1) -> (d0)>       // b
]
// d0 is a parallel (output) dimension, d1 is reduced.
#vecmattrans_trait = {
  indexing_maps = #vecmattrans_accesses,
  iterator_types = ["parallel", "reduction"]
}
47
// Reduction dimension listed first: x(i) * A(i, j) -> b(j),
// where (i, j) = (d0, d1).
#redpar_vecmattrans_accesses = [
  affine_map<(d0, d1) -> (d0)>,      // x
  affine_map<(d0, d1) -> (d0, d1)>,  // A
  affine_map<(d0, d1) -> (d1)>       // b
]
// d0 is reduced, d1 is a parallel (output) dimension.
#redpar_vecmattrans_trait = {
  indexing_maps = #redpar_vecmattrans_accesses,
  iterator_types = ["reduction", "parallel"]
}
57
// Matrix-vector product b(i) += sum_j A(i, j) * x(j). The lowering is
// expected to transpose A, then emit one additive outer product per
// reduction index j, chaining each result into the next accumulator.
// CHECK-LABEL: func @matvec2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[SUM0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[SUM1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[SUM0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: memref.store %[[SUM1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>,
                     %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #matvec_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
83
// Matrix-vector contraction using the `maxf` combining kind. The expected
// lowering has the same shape as the additive case (transpose A, then one
// outer product per reduction index), with `maxf` on every outer product.
// CHECK-LABEL: func @matvecmax2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[RED0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[RED1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[RED0]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: memref.store %[[RED1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>,
                        %arg1: memref<vector<2xf32>>,
                        %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #matvecmax_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
109
// Transposed-matrix times vector: b(i) += sum_j A(j, i) * x(j). Because A
// is indexed (j, i), the expected lowering extracts rows of the loaded A
// directly and feeds them to the outer products.
// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[SUM0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[SUM1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[SUM0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: memref.store %[[SUM1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>,
                          %arg1: memref<vector<2xf32>>,
                          %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #mattransvec_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
134
// Vector times matrix: b(i) += sum_j x(j) * A(i, j), with the vector passed
// as the lhs contraction operand. The expected lowering still transposes A
// and emits one additive outer product per reduction index j.
// CHECK-LABEL: func @vecmat2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[SUM0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[SUM1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[SUM0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: memref.store %[[SUM1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>,
                     %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmat_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
160
// Vector times transposed matrix: b(i) += sum_j x(j) * A(j, i). Because A
// is indexed (j, i), the expected lowering extracts rows of the loaded A
// directly and feeds them to the outer products.
// CHECK-LABEL: func @vecmattrans2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[SUM0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[SUM1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[SUM0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: memref.store %[[SUM1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>,
                          %arg1: memref<vector<2xf32>>,
                          %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmattrans_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
185
// Reduction dimension listed first (iterators [reduction, parallel]):
// b(j) += sum_i x(i) * A(i, j). The expected lowering extracts rows of the
// loaded A directly and feeds them to the outer products.
// CHECK-LABEL: func @redpar_vecmattrans2x2
// CHECK-SAME: %[[MAT_REF:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[VEC_REF:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[ACC_REF:.*2]]: memref<vector<2xf32>>
// CHECK: %[[MAT:.*]] = memref.load %[[MAT_REF]][] : memref<vector<2x2xf32>>
// CHECK: %[[VEC:.*]] = memref.load %[[VEC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ACC:.*]] = memref.load %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK: %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK: %[[SUM0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK: %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK: %[[SUM1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[SUM0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: memref.store %[[SUM1]], %[[ACC_REF]][] : memref<vector<2xf32>>
// CHECK: return
func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>,
                                 %arg1: memref<vector<2xf32>>,
                                 %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #redpar_vecmattrans_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}
210