1// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s
2
3// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
4
5// CHECK-LABEL: func @add4x2
6//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
7// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
8// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>
9// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
10// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
11// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
12// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>
13// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[A2]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
14// CHECK-NEXT: return %[[VEC1]] : vector<4x2xf32>
15
// Input for the 4x2 unrolling test: one elementwise add of a 4x2 vector with
// itself. The expectations above describe two 2x2 adds (row offsets 0 and 2)
// whose results are reassembled with vector.insert_strided_slice; the return
// line uses the captured VEC1 value (rather than re-defining it) so that the
// returned operand is tied to the last insert.
16func.func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
17  %1 = arith.addf %0, %0: vector<4x2xf32>
18  return %1: vector<4x2xf32>
19}
20
21// CHECK-LABEL: func @add4x4
22//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
23// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
24
25// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>
26
27// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
28// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
29
30// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>
31
32// CHECK-NEXT: %[[S5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
33// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
34// CHECK-NEXT: %[[A3:.*]] = arith.addf %[[S5]], %[[S6]] : vector<2x2xf32>
35
36// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
37// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
38// CHECK-NEXT: %[[A4:.*]] = arith.addf %[[S7]], %[[S8]] : vector<2x2xf32>
39
40// CHECK-NEXT: %[[S9:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
41// CHECK-NEXT: %[[A5:.*]] = arith.addf %[[S9]], %[[A1]] : vector<2x2xf32>
42// CHECK-NEXT: %[[R1:.*]] = vector.insert_strided_slice %[[A5]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
43
44
45// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
46// CHECK-NEXT: %[[A6:.*]] = arith.addf %[[S11]], %[[A2]] : vector<2x2xf32>
47// CHECK-NEXT: %[[R2:.*]] = vector.insert_strided_slice %[[A6]], %[[R1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
48
49// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
50// CHECK-NEXT: %[[A7:.*]] = arith.addf %[[S13]], %[[A3]] : vector<2x2xf32>
51// CHECK-NEXT: %[[R3:.*]] = vector.insert_strided_slice %[[A7]], %[[R2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
52
53// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
54// CHECK-NEXT: %[[A8:.*]] = arith.addf %[[S15]], %[[A4]] : vector<2x2xf32>
55// CHECK-NEXT: %[[R4:.*]] = vector.insert_strided_slice %[[A8]], %[[R3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
56
57// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>
58
// Input for the 4x4 unrolling test: two chained elementwise adds, where the
// second add consumes the result of the first together with the second
// argument. The expectations above describe 2x2 tiles at offsets
// [0, 0], [0, 2], [2, 0] and [2, 2].
59func.func @add4x4(%lhs : vector<4x4xf32>, %rhs : vector<4x4xf32>) -> vector<4x4xf32> {
60  %sum = arith.addf %lhs, %rhs : vector<4x4xf32>
61  %total = arith.addf %rhs, %sum : vector<4x4xf32>
62  return %total : vector<4x4xf32>
63}
64
// Indexing maps for a contraction over the iteration space (m, n, k), listed
// in vector.contract operand order: lhs is indexed by (m, k), rhs by (k, n)
// and the accumulator by (m, n).
65#contraction_accesses0 = [
66  affine_map<(m, n, k) -> (m, k)>,
67  affine_map<(m, n, k) -> (k, n)>,
68  affine_map<(m, n, k) -> (m, n)>
69]
// Trait passed to vector.contract: the first two iteration dimensions are
// parallel and the last one is the reduction dimension.
70#contraction_trait0 = {
71  iterator_types = ["parallel", "parallel", "reduction"],
72  indexing_maps = #contraction_accesses0
73}
74
75// CHECK-LABEL: func @contraction4x4_ijk
76
77// Reducing output vector [0, 0]
78
79//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
80// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
81// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
82// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
83// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
84
85// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
86
87// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
88// CHECK-NEXT: %[[S8:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
89// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
90// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
91
92// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[R1S00]], %[[S8]], %[[S9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
93
94// CHECK-NEXT: %[[S10:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
95// CHECK-NEXT: %[[S12:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
96// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
97// CHECK-NEXT: %[[S13:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
98// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S10]], %[[S11]], %[[R2S00]], %[[S12]], %[[S13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
99
100// Reducing output vector [0, 2]
101// CHECK-NEXT: %[[S14:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
102// CHECK-NEXT: %[[S17:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
103// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
104// CHECK-NEXT: %[[S18:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
105// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
106// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S14]], %[[S15]], %[[S16]], %[[S17]], %[[S18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
107
108// CHECK-NEXT: %[[S19:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
109// CHECK-NEXT: %[[S21:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
110// CHECK-NEXT: %[[S20:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
111// CHECK-NEXT: %[[S22:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
112// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S19]], %[[S20]], %[[R1S02]], %[[S21]], %[[S22]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
113
114// CHECK-NEXT: %[[S23:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
115// CHECK-NEXT: %[[S25:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
116// CHECK-NEXT: %[[S24:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
117// CHECK-NEXT: %[[S26:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
118// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S23]], %[[S24]], %[[R2S02]], %[[S25]], %[[S26]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
119
120// Reducing output vector [2, 0]
121
122// CHECK-NEXT: %[[S27:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
123// CHECK-NEXT: %[[S30:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
124// CHECK-NEXT: %[[S28:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
125// CHECK-NEXT: %[[S31:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
126// CHECK-NEXT: %[[S29:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
127// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S27]], %[[S28]], %[[S29]], %[[S30]], %[[S31]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
128
129// CHECK-NEXT: %[[S32:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
130// CHECK-NEXT: %[[S34:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
131// CHECK-NEXT: %[[S33:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
132// CHECK-NEXT: %[[S35:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
133// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S32]], %[[S33]], %[[R1S20]], %[[S34]], %[[S35]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
134
135// CHECK-NEXT: %[[S36:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
136// CHECK-NEXT: %[[S38:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
137// CHECK-NEXT: %[[S37:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
138// CHECK-NEXT: %[[S39:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
139// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S36]], %[[S37]], %[[R2S20]], %[[S38]], %[[S39]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
140
141// Reducing output vector [2, 2]
142
143// CHECK-NEXT: %[[S40:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
144// CHECK-NEXT: %[[S43:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
145// CHECK-NEXT: %[[S41:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
146// CHECK-NEXT: %[[S44:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
147// CHECK-NEXT: %[[S42:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
148// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S40]], %[[S41]], %[[S42]], %[[S43]], %[[S44]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
149
150// CHECK-NEXT: %[[S45:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
151// CHECK-NEXT: %[[S47:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
152// CHECK-NEXT: %[[S46:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
153// CHECK-NEXT: %[[S48:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
154// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S45]], %[[S46]], %[[R1S22]], %[[S47]], %[[S48]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
155
156// CHECK-NEXT: %[[S49:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
157// CHECK-NEXT: %[[S51:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
158// CHECK-NEXT: %[[S50:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
159// CHECK-NEXT: %[[S52:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
160// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[S49]], %[[S50]], %[[R2S22]], %[[S51]], %[[S52]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
161
162// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
163// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R3S02]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
164// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R3S20]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
165// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[R3S22]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
166
167// CHECK-NEXT:  return %[[VEC4]] : vector<4x4xf32>
168
// Input for the (i, j, k) contraction unrolling test: a masked
// vector.contract of a 4x6 lhs and a 6x4 rhs accumulated into a 4x4 vector,
// using #contraction_trait0. Both masks are built with extents equal to the
// full operand shapes. %arg3 is not referenced in the body.
169func.func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
170                         %arg2 : vector<4x4xf32>, %arg3 : index)
171                         -> (vector<4x4xf32>) {
172  %lhs_mask = vector.constant_mask [4, 6] : vector<4x6xi1>
173  %rhs_mask = vector.constant_mask [6, 4] : vector<6x4xi1>
174  %result = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhs_mask, %rhs_mask
175      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
176
177  return %result : vector<4x4xf32>
178}
179
// Indexing maps for a contraction over the iteration space (m, k, n), listed
// in vector.contract operand order: lhs is indexed by (m, k), rhs by (k, n)
// and the accumulator by (m, n).
180#contraction_accesses1 = [
181  affine_map<(m, k, n) -> (m, k)>,
182  affine_map<(m, k, n) -> (k, n)>,
183  affine_map<(m, k, n) -> (m, n)>
184]
// Trait passed to vector.contract: the middle iteration dimension is the
// reduction, the outer two are parallel.
185#contraction_trait1 = {
186  iterator_types = ["parallel", "reduction", "parallel"],
187  indexing_maps = #contraction_accesses1
188}
189
190// CHECK-LABEL: func @contraction4x4_ikj
191
192// Reducing output vector [0, 0]
193
194//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
195// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
196// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
197// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
198// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
199// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
200
201// Reducing output vector [0, 2]
202
203// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
204// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
205// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
206// CHECK-NEXT: %[[S10:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
207// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
208// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S6]], %[[S7]], %[[S8]], %[[S9]], %[[S10]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
209
210// Reducing output vector [2, 0]
211
212// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
213// CHECK-NEXT: %[[S14:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
214// CHECK-NEXT: %[[S12:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
215// CHECK-NEXT: %[[S15:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
216// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
217// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S11]], %[[S12]], %[[S13]], %[[S14]], %[[S15]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
218
219// Reducing output vector [2, 2]
220
221// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
222// CHECK-NEXT: %[[S19:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
223// CHECK-NEXT: %[[S17:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
224// CHECK-NEXT: %[[S20:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
225// CHECK-NEXT: %[[S18:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
226// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[S16]], %[[S17]], %[[S18]], %[[S19]], %[[S20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
227
228// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
229// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
230// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
231// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
232// CHECK-NEXT:  return %[[VEC3]] : vector<4x4xf32>
233
// Input for the (i, k, j) contraction unrolling test: a masked
// vector.contract of a 4x2 lhs and a 2x4 rhs accumulated into a 4x4 vector,
// using #contraction_trait1 (reduction on the middle iteration dimension).
// Both masks are built with extents equal to the full operand shapes.
// %arg3 is not referenced in the body.
// The expectations above match the sliced operands with wildcards instead of
// hard-coded argument names, like the other expectations in this test.
234func.func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
235                         %arg2 : vector<4x4xf32>, %arg3 : index)
236                         -> (vector<4x4xf32>) {
237  %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
238  %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
239  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
240      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
241
242  return %0 : vector<4x4xf32>
243}
244
245// CHECK-LABEL: func @contraction4x4_ikj_xfer_read
246
247// CHECK-DAG:      %[[C2:.*]] = arith.constant 2 : index
248// CHECK-DAG:      %[[C0:.*]] = arith.constant 0 : index
249
250// Check that the LHS, RHS and accumulator vector.transfer_read ops are each split into 2x2 reads.
251
252//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
253// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
254
255// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
256// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
257
258// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
259// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
260// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
261// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
262
263// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
264// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
265// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
266// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
267
268// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
269// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
270// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
271// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
272// CHECK-NEXT: return
273
// Input for the transfer-splitting test on memrefs: the three contraction
// operands are loaded with full-size vector.transfer_read ops, contracted
// with #contraction_trait1 (no masks), and the 4x4 result is stored back
// into %arg2 with vector.transfer_write.
274func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
275                                   %arg1 : memref<2x4xf32>,
276                                   %arg2 : memref<4x4xf32>) {
  // Index 0 is used for every transfer; 0.0 is the transfer_read padding value.
277  %c0 = arith.constant 0 : index
278  %cf0 = arith.constant 0.0 : f32
279
  // LHS: whole 4x2 buffer read from [0, 0] with an identity permutation map.
280  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
281    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
282      : memref<4x2xf32>, vector<4x2xf32>
283
  // RHS: whole 2x4 buffer read from [0, 0].
284  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
285    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
286    : memref<2x4xf32>, vector<2x4xf32>
287
  // Accumulator: whole 4x4 buffer read from [0, 0].
288  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
289    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
290      : memref<4x4xf32>, vector<4x4xf32>
291
292  %3 = vector.contract #contraction_trait1 %0, %1, %2
293      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
294
  // Store the result over the accumulator buffer it was read from.
295  vector.transfer_write %3, %arg2[%c0, %c0]
296    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
297      : vector<4x4xf32>, memref<4x4xf32>
298  return
299}
300
301// TODO: Update test with VTR split transform.
302// CHECK-LABEL: func @vector_transfers
303// CHECK-COUNT-8: vector.transfer_read
304// CHECK-COUNT-4: arith.addf
305// CHECK-COUNT-4: vector.transfer_write
306
// Input for the transfer count test: allocates three dynamically sized 2-D
// buffers and, in an affine loop nest stepping by 4 in both dimensions, reads
// a 4x4 vector from each of the first two buffers, adds them, and writes the
// sum to the third buffer.
307func.func @vector_transfers(%arg0: index, %arg1: index) {
  // Padding value for both transfer_read ops.
308  %cst = arith.constant 0.000000e+00 : f32
309  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
310  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
311  %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  // NOTE(review): %cst_0 and %cst_1 are not referenced anywhere in this function.
312  %cst_0 = arith.constant 1.000000e+00 : f32
313  %cst_1 = arith.constant 2.000000e+00 : f32
314  affine.for %arg2 = 0 to %arg0 step 4 {
315    affine.for %arg3 = 0 to %arg1 step 4 {
316      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
317      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
318      %6 = arith.addf %4, %5 : vector<4x4xf32>
319      vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
320    }
321  }
322  return
323}
324
325// CHECK-LABEL: func @cancelling_shape_cast_ops
326//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
327//       CHECK: return %[[A0]] : vector<2x4xf32>
// Input for the shape_cast folding test: flattens a 2x4 vector to 8 elements
// and immediately casts it back to 2x4. The expectations above require the
// function to return its argument directly.
328func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
329  %flat = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
330  %restored = vector.shape_cast %flat : vector<8xf32> to vector<2x4xf32>
331  return %restored : vector<2x4xf32>
332}
333
334// CHECK-LABEL: func @elementwise_unroll
335//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
336//       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
337//       CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
338//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
339//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
340//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
341//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
342//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
343//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
344//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
345//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
346//       CHECK:   %[[CMP0:.*]] = arith.cmpf ult, %[[VT0]], %[[VT4]] : vector<2x2xf32>
347//       CHECK:   %[[CMP1:.*]] = arith.cmpf ult, %[[VT1]], %[[VT5]] : vector<2x2xf32>
348//       CHECK:   %[[CMP2:.*]] = arith.cmpf ult, %[[VT2]], %[[VT6]] : vector<2x2xf32>
349//       CHECK:   %[[CMP3:.*]] = arith.cmpf ult, %[[VT3]], %[[VT7]] : vector<2x2xf32>
350//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
351//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
352//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
353//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
354//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
355//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
356//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
357//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
358//       CHECK:   %[[SEL0:.*]] = arith.select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
359//       CHECK:   %[[SEL1:.*]] = arith.select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
360//       CHECK:   %[[SEL2:.*]] = arith.select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
361//       CHECK:   %[[SEL3:.*]] = arith.select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
362//       CHECK:   vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
363//       CHECK:   vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
364//       CHECK:   vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
365//       CHECK:   vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// Input for the elementwise unrolling test: reads 4x4 vectors from %arg0 and
// %arg1, compares them with arith.cmpf `ult`, selects between a second pair
// of reads of the same buffers using that comparison result, and writes the
// selected vector back into %arg0.
366func.func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
  // Index 0 is used for every transfer; 0.0 is the transfer_read padding value.
367  %c0 = arith.constant 0 : index
368  %cf0 = arith.constant 0.0 : f32
369  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
370  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
371  %cond = arith.cmpf ult, %0, %1 : vector<4x4xf32>
372  // The vector transfer split pattern only supports a single user right now,
  // which is presumably why the operands are read again below instead of
  // reusing %0 and %1.
373  %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
374  %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
375  %4 = arith.select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
376  vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
377  return
378}
379
// Check that vector.transfer read/write are split based on contract unrolling.
// The label below anchors these expectations to this function, and the two
// index constants are captured locally so the patterns do not depend on
// variables that happened to be bound while matching an earlier test.
// CHECK-LABEL: func @contraction4x4_ikj_xfer_read_tensor
//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

// The 4x2 LHS is read as two 2x2 tiles.
//       CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
//  CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>

// The 2x4 RHS is read as two 2x2 tiles.
//  CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
//  CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>

// The 4x4 accumulator is read as four 2x2 tiles.
//  CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
//  CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
//  CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
//  CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>

// One 2x2 contraction per accumulator tile. The affine map aliases are matched
// with a numeric wildcard because their numbering is chosen by the printer.
//  CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}, #map{{[0-9]*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
//  CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}, #map{{[0-9]*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
//  CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}, #map{{[0-9]*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
//  CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}, #map{{[0-9]*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Each result tile is written into the tensor produced by the previous write.
//  CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
//  CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
//  CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
//  CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
//  CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>

func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
                                               %arg1 : tensor<2x4xf32>,
                                               %arg2 : tensor<4x4xf32>) ->
  tensor<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // Full-size reads of LHS, RHS and accumulator.
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
    tensor<4x2xf32>, vector<4x2xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
    tensor<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
    tensor<4x4xf32>, vector<4x4xf32>
  // #contraction_trait1 is declared elsewhere in this file.
  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
  // Write the 4x4 result over the accumulator tensor and return it.
  %r = vector.transfer_write %3, %arg2[%c0, %c0]
      : vector<4x4xf32>, tensor<4x4xf32>
  return %r : tensor<4x4xf32>
}
421
// CHECK-LABEL: func @bubble_down_bitcast_in_extract
//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
  // View the four f32 lanes as eight f16 lanes; both extracts below read from
  // this casted vector.
  %as_f16 = vector.bitcast %src : vector<4xf32> to vector<8xf16>

  // f16 lane 3 is expected to be rewritten as: extract f32 lane 1 from the
  // source, cast it to two f16 lanes, then take lane 1 of that pair.
  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
  %lane3 = vector.extract %as_f16[3] : vector<8xf16>

  // f16 lane 4 is expected to map to f32 lane 2, then lane 0 of its f16 pair.
  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
  %lane4 = vector.extract %as_f16[4] : vector<8xf16>

  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
  return %lane3, %lane4: f16, f16
}
437
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
  // Input: cast to f16 first, then slice f16 lanes [4, 8).
  // Expected: slice f32 lanes [2, 4) of the source first, then cast the slice.
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
  %as_f16 = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %upper_half = vector.extract_strided_slice %as_f16 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  // CHECK: return %[[CAST]]
  return %upper_half: vector<4xf16>
}
448
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
//  CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
  // The slice only selects rows [1, 3) and keeps the whole innermost
  // dimension, so the expected output applies the same offsets and sizes to
  // the f32 source and casts the 2x2 result afterwards.
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
  %as_f16 = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
  %rows = vector.extract_strided_slice %as_f16 {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
  // CHECK: return %[[CAST]]
  return %rows: vector<2x4xf16>
}
459
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
  // Negative case: the expected output still has the cast directly followed
  // by the slice. The f16 offset 3 is not a multiple of the two f16 lanes that
  // make up one f32 lane.
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %as_f16 = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %window = vector.extract_strided_slice %as_f16 {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  return %window: vector<4xf16>
}
468
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
  // Negative case: the expected output still has the cast directly followed
  // by the slice. The slice size of 3 f16 lanes is not a multiple of the two
  // f16 lanes that make up one f32 lane.
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %as_f16 = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %window = vector.extract_strided_slice %as_f16 {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
  return %window: vector<3xf16>
}
477
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
//  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
  // Input: two f16 inserts at offsets 0 and 4, followed by one cast to f32.
  // Expected: all three operands are cast to f32 up front (in any order) and
  // the inserts then run on f32 vectors at offsets 0 and 2.
  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
  %low_filled = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
  %both_filled = vector.insert_strided_slice %src2, %low_filled {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
  %as_f32 = vector.bitcast %both_filled: vector<8xf16> to vector<4xf32>
  // CHECK: return %[[INSERT2]]
  return %as_f32: vector<4xf32>
}
492
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
  // Negative case: the expected output still has the insert directly followed
  // by the cast. The f16 offset 3 is not a multiple of the two f16 lanes that
  // make up one f32 lane.
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %patched = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
  %as_f32 = vector.bitcast %patched: vector<8xf16> to vector<4xf32>
  return %as_f32: vector<4xf32>
}
501
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
  // Negative case: the expected output still has the insert directly followed
  // by the cast. Here a rank-2 slice is inserted into a rank-3 destination;
  // judging by the test name, the rank mismatch is presumably what blocks the
  // rewrite.
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %patched = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
  %as_f32 = vector.bitcast %patched: vector<16x4x8xf16> to vector<16x4x4xf32>
  return %as_f32: vector<16x4x4xf32>
}
510