// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D

// Distribution of a single 1-D elementwise add. The expected output keeps the
// original vector<32xf32> add, extracts a vector<1xf32> slice of each operand
// with vector.extract_map indexed by %id, adds the two slices, and merges the
// partial result back into the full-size value with vector.insert_map.
// CHECK-LABEL: func @distribute_vector_add
//  CHECK-SAME: (%[[ID:.*]]: index
//  CHECK-NEXT:    %[[ADDV:.*]] = arith.addf %{{.*}}, %{{.*}} : vector<32xf32>
//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<1xf32>
//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
//  CHECK-NEXT:    return %[[INS]] : vector<32xf32>
func.func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
  %0 = arith.addf %A, %B : vector<32xf32>
  return %0: vector<32xf32>
}

// -----

// Distribution through a chain of two elementwise ops (math.exp feeding
// arith.addf). The expected output keeps both full-size ops, applies math.exp
// directly to the extracted vector<1xf32> slice of the first operand, adds it
// to the extracted slice of the second operand, and inserts only the final
// partial result back with vector.insert_map.
// CHECK-LABEL: func @distribute_vector_add_exp
//  CHECK-SAME: (%[[ID:.*]]: index
//  CHECK-NEXT:    %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
//  CHECK-NEXT:    %[[ADDV:.*]] = arith.addf %[[EXPV]], %{{.*}} : vector<32xf32>
//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
//  CHECK-NEXT:    %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXC]], %[[EXB]] : vector<1xf32>
//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
//  CHECK-NEXT:    return %[[INS]] : vector<32xf32>
func.func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
  %C = math.exp %A : vector<32xf32>
  %0 = arith.addf %C, %B : vector<32xf32>
  return %0: vector<32xf32>
}

// -----

// Distribution of adds whose operands come from vector.transfer_read and whose
// result goes to vector.transfer_write. In the expected output the
// vector<32xf32> transfers at index 0 become vector<1xf32> transfers indexed
// directly by %id, and the sequence from the second read through the return
// consists only of 1-element reads, adds and the final write.
// CHECK-LABEL: func @vector_add_read_write
//  CHECK-SAME: (%[[ID:.*]]: index
//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
//  CHECK-NEXT:    %[[ADD1:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<1xf32>
//  CHECK-NEXT:    %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
//  CHECK-NEXT:    %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[EXC]] : vector<1xf32>
//  CHECK-NEXT:    vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] {{.*}} : vector<1xf32>, memref<32xf32>
//  CHECK-NEXT:    return
func.func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %acc = arith.addf %a, %b: vector<32xf32>
  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
  %d = arith.addf %acc, %c: vector<32xf32>
  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
  return
}

// -----

// Distribution where the vector is larger than the multiplicity: the
// vector<64xf32> transfers become vector<2xf32> transfers, and each transfer
// index is produced by an affine.apply that multiplies %id by 2 (one
// affine.apply per transfer in the expected output).
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>

//       CHECK: func @vector_add_cycle
//  CHECK-SAME: (%[[ID:.*]]: index
//       CHECK:    %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]]], %{{.*}} : memref<64xf32>, vector<2xf32>
//  CHECK-NEXT:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2xf32>
//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] {{.*}} : vector<2xf32>, memref<64xf32>
//  CHECK-NEXT:    return
func.func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
  %acc = arith.addf %a, %b: vector<64xf32>
  vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
  return
}

// -----

// Negative test: nothing must be rewritten when the vector size is not a
// multiple of the distribution multiplicity. The vector<16xf32> transfers and
// add below are expected to survive unchanged, still indexed by the constant
// 0 rather than by %id.
// CHECK-LABEL: func @vector_negative_test
//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
//       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<16xf32>
//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
//  CHECK-NEXT:    return
func.func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
  %acc = arith.addf %a, %b: vector<16xf32>
  vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
  return
}

// -----

// Distribution of a 3-D elementwise add using two ids. The expected output
// extracts vector<2x4x1xf32> slices from the vector<64x4x32xf32> operands with
// vector.extract_map indexed by (%id0, %id1); the middle dimension keeps its
// full size of 4. The partial sum is merged back with vector.insert_map.
// CHECK-LABEL: func @distribute_vector_add_3d
//  CHECK-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index
//  CHECK-NEXT:    %[[ADDV:.*]] = arith.addf %{{.*}}, %{{.*}} : vector<64x4x32xf32>
//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID0]], %[[ID1]]] : vector<2x4x1xf32> into vector<64x4x32xf32>
//  CHECK-NEXT:    return %[[INS]] : vector<64x4x32xf32>
func.func @distribute_vector_add_3d(%id0 : index, %id1 : index,
  %A: vector<64x4x32xf32>, %B: vector<64x4x32xf32>) -> vector<64x4x32xf32> {
  %0 = arith.addf %A, %B : vector<64x4x32xf32>
  return %0: vector<64x4x32xf32>
}

// -----

// 3-D distribution folded into transfers with the default (identity)
// permutation. The expected vector<2x4x1xf32> transfers index the first memref
// dimension with an affine.apply of %id0 (s0 * 2), leave the middle index at
// the constant 0, and use %id1 directly for the last dimension.
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>

//       CHECK: func @vector_add_transfer_3d
//  CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
//       CHECK:    %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]], %[[C0]], %[[ID_1]]] {{.*}} : vector<2x4x1xf32>, memref<64x64x64xf32>
//  CHECK-NEXT:    return
func.func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32>,
  %B: memref<64x64x64xf32>, %C: memref<64x64x64xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
  %b = vector.transfer_read %B[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
  %acc = arith.addf %a, %b: vector<64x4x32xf32>
  vector.transfer_write %acc, %C[%c0, %c0, %c0]: vector<64x4x32xf32>, memref<64x64x64xf32>
  return
}

// -----

// 3-D distribution folded into transfers that carry explicit permutation maps,
// two of which contain broadcast (constant 0) results. The expected output
// keeps each permutation_map attribute unchanged, shrinks the vectors to
// vector<2x4x1xf32>, and places the ids (or an affine.apply of %id0 computing
// s0 * 2) on the memref indices listed in the expected transfers below.
#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>

// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>

//       CHECK: func @vector_add_transfer_permutation
//  CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
//       CHECK:    %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
//  CHECK-NEXT:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
//  CHECK-NEXT:    %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_0]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ADD:.*]] = arith.addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
//  CHECK-NEXT:    %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
//  CHECK-NEXT:    vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
//  CHECK-NEXT:    return
func.func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
  %B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
  %b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
  %acc = arith.addf %a, %b: vector<64x4x32xf32>
  vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
  return
}

// -----

// 2-D distribution through vector.contract, verified only by the RUN line that
// passes distribution-multiplicity=32,4 and selects the 2-D check prefix. The
// expected transfers shrink 64x4 to 2x4 (lhs), 64x4 to 16x4 (rhs) and 64x64 to
// 2x16 (accumulator and the extra addend); the contraction and the add then
// operate on those smaller vectors. Memref, index and padding operands are
// matched with wildcards so the test does not depend on the SSA value names
// chosen by the printer.
// CHECK2D-LABEL: vector_add_contract
//       CHECK2D:   %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x4xf32>
//       CHECK2D:   %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<16x4xf32>
//       CHECK2D:   %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
//       CHECK2D:   %[[E:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
//       CHECK2D:   %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
//       CHECK2D:   %[[R:.+]] = arith.addf %[[D]], %[[E]] : vector<2x16xf32>
//       CHECK2D:   vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
func.func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
  %B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                         affine_map<(d0, d1, d2) -> (d1, d2)>,
                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
  %e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  %r = arith.addf %d, %e : vector<64x64xf32>
  vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
  return
}
