// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s

// Each test below lowers a 2x2 matrix-times-vector `vector.contract` into a
// chain of two `vector.outerproduct` ops that accumulate into the loaded
// result vector. The tests differ in which operand is the matrix, whether the
// matrix is indexed reduction-dimension-first, and the iterator order.

// Matrix indexed (parallel, reduction); vector indexed by the reduction dim.
#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}
// Same layout as #matvec_trait, combining with maxf instead of the default add.
#matvecmax_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"],
  kind = #vector.kind<maxf>
}

// Matrix indexed (reduction, parallel), i.e. already transposed.
#mattransvec_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Vector is the left operand; matrix (right operand) indexed (parallel, reduction).
#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Vector is the left operand; matrix (right operand) indexed (reduction, parallel).
#vecmattrans_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (i)>
]
#vecmattrans_trait = {
  indexing_maps = #vecmattrans_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Same operand roles as #vecmattrans_trait, but the reduction iterator is
// listed first: `i` is reduced and `j` is parallel.
#redpar_vecmattrans_accesses = [
  affine_map<(i, j) -> (i)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>
]
#redpar_vecmattrans_trait = {
  indexing_maps = #redpar_vecmattrans_accesses,
  iterator_types = ["reduction", "parallel"]
}

// Matrix on the left, reduction dim innermost: expect a transpose, then one
// outerproduct per transposed row (= original column) with the matching scalar.
// CHECK-LABEL: func @matvec2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:         %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #matvec_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}

// Same shape as @matvec2x2, but the contraction kind (maxf) must be forwarded
// to every generated outerproduct.
// CHECK-LABEL: func @matvecmax2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:         %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK:         %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                        %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #matvecmax_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}

// Matrix on the left, already indexed reduction-dim-first: no transpose is
// expected; rows are extracted straight from the loaded matrix.
// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                          %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #mattransvec_trait %mat, %vec, %acc : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}

// Vector on the left, matrix on the right with the reduction dim innermost:
// operands are swapped and the matrix is transposed, so the expected output
// matches @matvec2x2.
// CHECK-LABEL: func @vecmat2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[MAT_T:.*]] = vector.transpose %[[MAT]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:         %[[COL0:.*]] = vector.extract %[[MAT_T]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[COL0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         %[[COL1:.*]] = vector.extract %[[MAT_T]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[COL1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmat_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}

// Vector on the left, matrix on the right already indexed reduction-dim-first:
// operands are swapped but no transpose is expected.
// CHECK-LABEL: func @vecmattrans2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                          %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmattrans_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}

// Reduction iterator listed first: the matrix's leading dim is the reduced
// one, so the expected output matches @vecmattrans2x2 (no transpose).
// CHECK-LABEL: func @redpar_vecmattrans2x2
// CHECK-SAME:    %[[MAT_MEM:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME:    %[[VEC_MEM:.*1]]: memref<vector<2xf32>>
// CHECK-SAME:    %[[ACC_MEM:.*2]]: memref<vector<2xf32>>
// CHECK:         %[[MAT:.*]] = memref.load %[[MAT_MEM]][] : memref<vector<2x2xf32>>
// CHECK:         %[[VEC:.*]] = memref.load %[[VEC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ACC:.*]] = memref.load %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         %[[ROW0:.*]] = vector.extract %[[MAT]][0] : vector<2x2xf32>
// CHECK:         %[[X0:.*]] = vector.extract %[[VEC]][0] : vector<2xf32>
// CHECK:         %[[ACC0:.*]] = vector.outerproduct %[[ROW0]], %[[X0]], %[[ACC]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         %[[ROW1:.*]] = vector.extract %[[MAT]][1] : vector<2x2xf32>
// CHECK:         %[[X1:.*]] = vector.extract %[[VEC]][1] : vector<2xf32>
// CHECK:         %[[ACC1:.*]] = vector.outerproduct %[[ROW1]], %[[X1]], %[[ACC0]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:         memref.store %[[ACC1]], %[[ACC_MEM]][] : memref<vector<2xf32>>
// CHECK:         return
func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                                 %arg2: memref<vector<2xf32>>) {
  %mat = memref.load %arg0[] : memref<vector<2x2xf32>>
  %vec = memref.load %arg1[] : memref<vector<2xf32>>
  %acc = memref.load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #redpar_vecmattrans_trait %vec, %mat, %acc : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  memref.store %res, %arg2[] : memref<vector<2xf32>>
  return
}