1// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s 2// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX 3// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT 4// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT 5 6#dotp_accesses = [ 7 affine_map<(i) -> (i)>, 8 affine_map<(i) -> (i)>, 9 affine_map<(i) -> ()> 10] 11#dotp_trait = { 12 indexing_maps = #dotp_accesses, 13 iterator_types = ["reduction"] 14} 15 16// CHECK-LABEL: func @extract_contract1 17// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, 18// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, 19// CHECK-SAME: %[[C:.*2]]: f32 20// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> 21// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32 22// CHECK: %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32 23// CHECK: return %[[ACC]] : f32 24 25func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { 26 %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 27 : vector<4xf32>, vector<4xf32> into f32 28 return %0 : f32 29} 30 31// CHECK-LABEL: func @extract_contract1_int 32// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, 33// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, 34// CHECK-SAME: %[[C:.*2]]: i32 35// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> 36// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32 37// CHECK: %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32 38// CHECK: return %[[ACC]] : i32 39 40func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { 41 %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 42 : vector<4xi32>, vector<4xi32> into i32 43 return %0 : i32 44} 45 46#matvec_accesses = [ 47 affine_map<(i, j) -> (i, j)>, 48 affine_map<(i, j) -> (j)>, 49 affine_map<(i, j) -> (i)> 50] 51#matvec_trait = { 52 indexing_maps = #matvec_accesses, 53 iterator_types = ["parallel", "reduction"] 54} 55 56// CHECK-LABEL: func @extract_contract2 57// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, 58// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, 59// CHECK-SAME: %[[C:.*2]]: vector<2xf32> 60// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> 61// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> 62// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32> 63// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 64// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> 65// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> 66// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> 67// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 68// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> 69// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> 70// CHECK: return %[[T10]] : vector<2xf32> 71 72func @extract_contract2(%arg0: vector<2x3xf32>, 73 %arg1: vector<3xf32>, 74 %arg2: vector<2xf32>) -> vector<2xf32> { 75 %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 76 : vector<2x3xf32>, vector<3xf32> into vector<2xf32> 77 return %0 : vector<2xf32> 78} 79 80// CHECK-LABEL: func @extract_contract2_int 81// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, 82// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, 83// CHECK-SAME: %[[C:.*2]]: vector<2xi32> 84// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> 85// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> 86// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32> 87// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32 88// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> 89// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> 90// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> 91// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32 92// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> 93// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> 94// CHECK: return %[[T10]] : vector<2xi32> 95func @extract_contract2_int(%arg0: vector<2x3xi32>, 96 %arg1: vector<3xi32>, 97 %arg2: vector<2xi32>) -> vector<2xi32> { 98 %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 99 : vector<2x3xi32>, vector<3xi32> into vector<2xi32> 100 return %0 : vector<2xi32> 101} 102 103#vecmat_accesses = [ 104 affine_map<(i, j) -> (j)>, 105 affine_map<(i, j) -> (i, j)>, 106 affine_map<(i, j) -> (i)> 107] 108#vecmat_trait = { 109 indexing_maps = #vecmat_accesses, 110 iterator_types = ["parallel", "reduction"] 111} 112 113// CHECK-LABEL: func @extract_contract3 114// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, 115// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, 116// CHECK-SAME: %[[C:.*2]]: vector<2xf32> 117// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> 118// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> 119// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32> 120// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 121// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> 122// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> 123// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32> 124// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 125// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> 126// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> 127// CHECK: return %[[T10]] : vector<2xf32> 128 129func @extract_contract3(%arg0: vector<3xf32>, 130 %arg1: vector<2x3xf32>, 131 %arg2: vector<2xf32>) -> vector<2xf32> { 132 %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 133 : vector<3xf32>, vector<2x3xf32> into vector<2xf32> 134 return %0 : vector<2xf32> 135} 136 137#matmat_accesses = [ 138 affine_map<(i, j, k) -> (i, k)>, 139 affine_map<(i, j, k) -> (k, j)>, 140 affine_map<(i, j, k) -> (i, j)> 141] 142#matmat_trait = { 143 indexing_maps = #matmat_accesses, 144 iterator_types = ["parallel", "parallel", "reduction"] 145} 146 147// CHECK-LABEL: func @extract_contract4 148// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, 149// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, 150// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> 151// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> 152// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> 153// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> 154// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> 155// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> 156// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xf32> into f32 157// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> 158// 159// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> 160// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> 161// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xf32> into f32 162// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> 163// 164// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> 165// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> 166// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> 167// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xf32> into f32 168// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> 169// 170// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> 171// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> 172// CHECK: %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xf32> into f32 173// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> 174// 175// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> 176// CHECK: return %[[T52]] : vector<2x2xf32> 177 178func @extract_contract4(%arg0: vector<2x2xf32>, 179 %arg1: vector<2x2xf32>, 180 %arg2: vector<2x2xf32>) -> vector<2x2xf32> { 181 %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 182 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> 183 return %0 : vector<2x2xf32> 184} 185 186#contraction2d_accesses = [ 187 affine_map<(i, j) -> (i, j)>, 188 affine_map<(i, j) -> (i, j)>, 189 affine_map<(i, j) -> ()> 190] 191#contraction2d_trait = { 192 indexing_maps = #contraction2d_accesses, 193 iterator_types = ["reduction", "reduction"] 194} 195 196// CHECK-LABEL: func @full_contract1 197// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, 198// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, 199// CHECK-SAME: %[[C:.*2]]: f32 200// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> 201// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> 202// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> 203// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 204// CHECK: %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32 205// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> 206// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> 207// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> 208// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 209// CHECK: %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32 210// CHECK: return %[[T9]] : f32 211 212func @full_contract1(%arg0: vector<2x3xf32>, 213 %arg1: vector<2x3xf32>, 214 %arg2: f32) -> f32 { 215 %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 216 : vector<2x3xf32>, vector<2x3xf32> into f32 217 return %0 : f32 218} 219 220#contraction2d_trans_accesses = [ 221 affine_map<(i, j) -> (i, j)>, 222 affine_map<(i, j) -> (j, i)>, 223 affine_map<(i, j) -> ()> 224] 225#contraction2d_trans_trait = { 226 indexing_maps = #contraction2d_trans_accesses, 227 iterator_types = ["reduction", "reduction"] 228} 229 230// CHECK-LABEL: func @full_contract2 231// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, 232// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>, 233// CHECK-SAME: %[[C:.*2]]: f32 234// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32> 235// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> 236// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> 237// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> 238// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> 239// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> 240// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> 241// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> 242// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> 243// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32 244// CHECK: %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32 245// 246// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> 247// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf 248// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> 249// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> 250// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> 251// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> 252// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> 253// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> 254// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32 255// CHECK: %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32 256// CHECK: return %[[ACC1]] : f32 257 258func @full_contract2(%arg0: vector<2x3xf32>, 259 %arg1: vector<3x2xf32>, 260 %arg2: f32) -> f32 { 261 %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 262 : vector<2x3xf32>, vector<3x2xf32> into f32 263 return %0 : f32 264} 265 266// CHECK-LABEL: func @outerproduct_noacc 267// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, 268// CHECK-SAME: %[[B:.*1]]: vector<3xf32> 269// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> 270// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> 271// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32> 272// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> 273// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> 274// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> 275// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32> 276// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> 277// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> 278// CHECK: return %[[T7]] : vector<2x3xf32> 279 280func @outerproduct_noacc(%arg0: vector<2xf32>, 281 %arg1: vector<3xf32>) -> vector<2x3xf32> { 282 %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> 283 return %0: vector<2x3xf32> 284} 285 286// CHECK-LABEL: func @outerproduct_acc 287// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, 288// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, 289// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> 290// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> 291// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> 292// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32> 293// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> 294// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> 295// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> 296// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> 297// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> 298// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> 299// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> 300// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> 301// CHECK: return %[[T9]] : vector<2x3xf32> 302 303func @outerproduct_acc(%arg0: vector<2xf32>, 304 %arg1: vector<3xf32>, 305 %arg2: vector<2x3xf32>) -> vector<2x3xf32> { 306 %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> 307 return %0: vector<2x3xf32> 308} 309 310// CHECK-LABEL: func @outerproduct_noacc_int 311// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, 312// CHECK-SAME: %[[B:.*1]]: vector<3xi32> 313// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> 314// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> 315// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32> 316// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> 317// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> 318// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> 319// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xi32> 320// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> 321// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> 322// CHECK: return %[[T7]] : vector<2x3xi32> 323func @outerproduct_noacc_int(%arg0: vector<2xi32>, 324 %arg1: vector<3xi32>) -> vector<2x3xi32> { 325 %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> 326 return %0: vector<2x3xi32> 327} 328 329// CHECK-LABEL: func @outerproduct_acc_int 330// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, 331// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, 332// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> 333// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> 334// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> 335// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32> 336// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> 337// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> 338// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> 339// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> 340// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> 341// CHECK: %[[T7:.*]] = splat %[[T6]] : vector<3xi32> 342// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> 343// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> 344// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> 345// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> 346// CHECK: return %[[T11]] : vector<2x3xi32> 347func @outerproduct_acc_int(%arg0: vector<2xi32>, 348 %arg1: vector<3xi32>, 349 %arg2: vector<2x3xi32>) -> vector<2x3xi32> { 350 %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> 351 return %0: vector<2x3xi32> 352} 353 354// CHECK-LABEL: func @axpy_fp( 355// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, 356// CHECK-SAME: %[[B:.*1]]: f32) 357// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32> 358// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> 359// CHECK: return %[[T1]] : vector<16xf32> 360func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { 361 %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32 362 return %0: vector<16xf32> 363} 364 365// CHECK-LABEL: func @axpy_fp_add( 366// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, 367// CHECK-SAME: %[[B:.*1]]: f32, 368// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) 369// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32> 370// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> 371// CHECK: return %[[T1]] : vector<16xf32> 372func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { 373 %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 374 return %0: vector<16xf32> 375} 376 377// CHECK-LABEL: func @axpy_int( 378// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, 379// CHECK-SAME: %[[B:.*1]]: i32) 380// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32> 381// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> 382// CHECK: return %[[T1]] : vector<16xi32> 383func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { 384 %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 385 return %0: vector<16xi32> 386} 387 388// CHECK-LABEL: func @axpy_int_add( 389// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, 390// CHECK-SAME: %[[B:.*1]]: i32, 391// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) 392// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32> 393// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> 394// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> 395// CHECK: return %[[T2]] : vector<16xi32> 396func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { 397 %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 398 return %0: vector<16xi32> 399} 400 401// CHECK-LABEL: func @nop_shape_cast 402// CHECK-SAME: %[[A:.*]]: vector<16xf32> 403// CHECK: return %[[A]] : vector<16xf32> 404 405func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { 406 %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> 407 return %0 : vector<16xf32> 408} 409 410// CHECK-LABEL: func @cancel_shape_cast 411// FIXME: PR49590 412// HECK-SAME: %[[A:.*]]: vector<16xf32> 413// HECK: return %[[A]] : vector<16xf32> 414 415func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { 416 %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> 417 %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> 418 return %1 : vector<16xf32> 419} 420 421// Shape up and downcasts for 2-D vectors, for supporting conversion to 422// llvm.matrix operations 423// CHECK-LABEL: func @shape_casts 424func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { 425 // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> 426 // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> 427 // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32> 428 // 429 // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] 430 // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 431 // 432 // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> 433 // 434 // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] 435 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> 436 // 437 %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> 438 // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> 439 %r0 = arith.addf %0, %0: vector<4xf32> 440 // 441 // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] 442 // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : 443 // CHECK-SAME: vector<4xf32> to vector<2xf32> 444 // 445 // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : 446 // CHECK-SAME: vector<2xf32> into vector<2x2xf32> 447 // 448 // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] 449 // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : 450 // CHECK-SAME: vector<4xf32> to vector<2xf32> 451 // 452 // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : 453 // CHECK-SAME: vector<2xf32> into vector<2x2xf32> 454 // 455 %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> 456 // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> 457 return %r0, %1 : vector<4xf32>, vector<2x2xf32> 458} 459 460// CHECK-LABEL: func @shape_cast_2d2d 461// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> 462// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> 463// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32> 464// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32> 465// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32> 466// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> 467// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32> 468// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32> 469// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32> 470// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> 471// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32> 472// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> 473// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32> 474// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> 475// CHECK: return %[[T11]] : vector<2x3xf32> 476 477func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { 478 %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> 479 return %s : vector<2x3xf32> 480} 481 482// CHECK-LABEL: func @shape_cast_3d1d 483// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> 484// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32> 485// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32> 486// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32> 487// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32> 488// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32> 489// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32> 490// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32> 491// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32> 492// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32> 493// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32> 494// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32> 495// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32> 496// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32> 497// CHECK: return %[[T11]] : vector<6xf32> 498 499func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { 500 %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32> 501 return %s : vector<6xf32> 502} 503 504// CHECK-LABEL: func @shape_cast_1d3d 505// CHECK-SAME: %[[A:.*]]: vector<6xf32> 506// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32> 507// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32> 508// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32> 509// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32> 510// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32> 511// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32> 512// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32> 513// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32> 514// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32> 515// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32> 516// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32> 517// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32> 518// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32> 519// CHECK: return %[[T11]] : vector<2x1x3xf32> 520 521func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { 522 %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> 523 return %s : vector<2x1x3xf32> 524} 525 526// MATRIX-LABEL: func @matmul 527// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, 528// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, 529// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 530// MATRIX: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> 531// MATRIX: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32> 532// MATRIX: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> 533// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> 534// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> 535// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> 536// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> 537// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> 538// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> 539// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> 540// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> 541// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> 542// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32> 543// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> 544// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> 545// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32> 546// MATRIX: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> 547// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32> 548// MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> 549// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> 550// MATRIX: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> 551 552// OUTERPRODUCT-LABEL: func @matmul 553// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, 554// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, 555// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 556// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 557// OUTERPRODUCT-SAME: : vector<2x4xf32> to vector<4x2xf32> 558// 559// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> 560// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> 561// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 562// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> 563// 564// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> 565// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> 566// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] 567// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> 568// 569// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> 570// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> 571// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] 572// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> 573// 574// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> 575// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> 576// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] 577// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> 578// 579// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> 580 581// REDUCE-LABEL: func @matmul 582// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, 583// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, 584// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 585// 586// REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> 587// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] 588// REDUCE-SAME: : vector<4x3f32> to vector<3x4xf32> 589// 590// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> 591// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32> 592// REDUCE-NEXT: %[[ab00:.*]] = mul %[[a0]], %[[b0]] : vector<4xf32> 593// REDUCE-NEXT: %[[s00:.*]] = vector.reduction "add", %[[ab00]] : vector<4xf32> into f32 594// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32> 595// 596// ... 597// 598// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> 599// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32> 600// REDUCE-NEXT: %[[ab12:.*]] = mul %[[a1]], %[[b02]] : vector<4xf32> 601// REDUCE-NEXT: %[[s12:.*]] = vector.reduction "add", %[[ab12]] : vector<4xf32> into f32 602// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32> 603// 604// REDUCE: return %[[c3]] : vector<2x3xf32> 605func @matmul(%arg0: vector<2x4xf32>, 606 %arg1: vector<4x3xf32>, 607 %arg2: vector<2x3xf32>) -> vector<2x3xf32> { 608 %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 609 : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> 610 return %0 : vector<2x3xf32> 611} 612 613// CHECK-LABEL: func @broadcast_vec1d_from_scalar 614// CHECK-SAME: %[[A:.*0]]: f32 615// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32> 616// CHECK: return %[[T0]] : vector<2xf32> 617 618func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { 619 %0 = vector.broadcast %arg0 : f32 to vector<2xf32> 620 return %0 : vector<2xf32> 621} 622 623// CHECK-LABEL: func @broadcast_vec2d_from_scalar 624// CHECK-SAME: %[[A:.*0]]: f32 625// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32> 626// CHECK: return %[[T0]] : vector<2x3xf32> 627 628func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { 629 %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> 630 return %0 : vector<2x3xf32> 631} 632 633// CHECK-LABEL: func @broadcast_vec3d_from_scalar 634// CHECK-SAME: %[[A:.*0]]: f32 635// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32> 636// CHECK: return %[[T0]] : vector<2x3x4xf32> 637 638func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { 639 %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> 640 return %0 : vector<2x3x4xf32> 641} 642 643// CHECK-LABEL: func @broadcast_vec1d_from_vec1d 644// CHECK-SAME: %[[A:.*0]]: vector<2xf32> 645// CHECK: return %[[A]] : vector<2xf32> 646 647func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { 648 %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> 649 return %0 : vector<2xf32> 650} 651 652// CHECK-LABEL: func @broadcast_vec2d_from_vec1d 653// CHECK-SAME: %[[A:.*0]]: vector<2xf32> 654// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> 655// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> 656// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> 657// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> 658// CHECK: return %[[T2]] : vector<3x2xf32> 659 660func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { 661 %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> 662 return %0 : vector<3x2xf32> 663} 664 665// CHECK-LABEL: func @broadcast_vec3d_from_vec1d 666// CHECK-SAME: %[[A:.*0]]: vector<2xf32> 667// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> 668// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> 669// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> 670// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> 671// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> 672// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> 673// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> 674// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> 675// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> 676// CHECK: return %[[T6]] : vector<4x3x2xf32> 677 678func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { 679 %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> 680 return %0 : vector<4x3x2xf32> 681} 682 683// CHECK-LABEL: func @broadcast_vec3d_from_vec2d 684// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> 685// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> 686// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> 687// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> 688// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> 689// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> 690// CHECK: return %[[T3]] : vector<4x3x2xf32> 691 692func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { 693 %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> 694 return %0 : vector<4x3x2xf32> 695} 696 697// CHECK-LABEL: func @broadcast_stretch 698// CHECK-SAME: %[[A:.*0]]: vector<1xf32> 699// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> 700// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<4xf32> 701// CHECK: return %[[T1]] : vector<4xf32> 702 703func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { 704 %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> 705 return %0 : vector<4xf32> 706} 707 708// CHECK-LABEL: func @broadcast_stretch_at_start 709// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> 710// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> 711// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> 712// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> 713// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> 714// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> 715// CHECK: return %[[T3]] : vector<3x4xf32> 716 717func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { 718 %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> 719 return %0 : vector<3x4xf32> 720} 721 722// CHECK-LABEL: func @broadcast_stretch_at_end 723// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> 724// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> 725// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> 726// CHECK: %[[T2:.*]] = splat %[[T0]] : vector<3xf32> 727// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> 728// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> 729// CHECK: %[[T6:.*]] = splat %[[T4]] : vector<3xf32> 730// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> 731// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> 732// CHECK: %[[T10:.*]] = splat %[[T8]] : vector<3xf32> 733// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> 734// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> 735// CHECK: %[[T14:.*]] = splat %[[T12]] : vector<3xf32> 736// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> 737// CHECK: return %[[T15]] : vector<4x3xf32> 738 739func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { 740 %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> 741 return %0 : vector<4x3xf32> 742} 743 744// CHECK-LABEL: func @broadcast_stretch_in_middle 745// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> 746// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> 747// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> 748// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> 749// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 750// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> 751// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> 752// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> 753// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> 754// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 755// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> 756// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> 757// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> 758// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> 759// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 760// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> 761// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> 762// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> 763// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> 764// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 765// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> 766// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> 767// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> 768// CHECK: return %[[T23]] : vector<4x3x2xf32> 769 770func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { 771 %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> 772 return %0 : vector<4x3x2xf32> 773} 774 775// CHECK-LABEL: func @genbool_1d 776// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> 777// CHECK: return %[[T0]] : vector<8xi1> 778 779func @genbool_1d() -> vector<8xi1> { 780 %0 = vector.constant_mask [4] : vector<8xi1> 781 return %0 : vector<8xi1> 782} 783 784// CHECK-LABEL: func @genbool_2d 785// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> 786// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1> 787// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> 788// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> 789// CHECK: return %[[T1]] : vector<4x4xi1> 790 791func @genbool_2d() -> vector<4x4xi1> { 792 %v = vector.constant_mask [2, 2] : vector<4x4xi1> 793 return %v: vector<4x4xi1> 794} 795 796// CHECK-LABEL: func @genbool_3d 797// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> 798// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1> 799// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1> 800// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> 801// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> 802// CHECK: return %[[T1]] : vector<2x3x4xi1> 803 804func @genbool_3d() -> vector<2x3x4xi1> { 805 %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> 806 return %v: vector<2x3x4xi1> 807} 808 809// CHECK-LABEL: func @genbool_var_1d( 810// CHECK-SAME: %[[A:.*]]: index) 811// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> 812// CHECK: return %[[T0]] : vector<3xi1> 813 814func @genbool_var_1d(%arg0: index) -> vector<3xi1> { 815 %0 = vector.create_mask %arg0 : vector<3xi1> 816 return %0 : vector<3xi1> 817} 818 819// CHECK-LABEL: func @genbool_var_2d( 820// CHECK-SAME: %[[A:.*0]]: index, 821// CHECK-SAME: %[[B:.*1]]: index) 822// CHECK: %[[C1:.*]] = arith.constant dense<false> : vector<3xi1> 823// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1> 824// CHECK: %[[c0:.*]] = arith.constant 0 : index 825// CHECK: %[[c1:.*]] = arith.constant 1 : index 826// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> 827// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index 828// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> 829// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> 830// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index 831// CHECK: %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> 832// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> 833// CHECK: return %[[T6]] : vector<2x3xi1> 834 835func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { 836 %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> 837 return %0 : vector<2x3xi1> 838} 839 840// CHECK-LABEL: func @genbool_var_3d( 841// CHECK-SAME: %[[A:.*0]]: index, 842// CHECK-SAME: %[[B:.*1]]: index, 843// CHECK-SAME: %[[C:.*2]]: index) 844// CHECK-DAG: %[[C1:.*]] = arith.constant dense<false> : vector<7xi1> 845// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1> 846// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1> 847// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index 848// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index 849// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> 850// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index 851// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> 852// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> 853// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index 854// CHECK: %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> 855// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> 856// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index 857// CHECK: %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> 858// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> 859// CHECK: return %[[T9]] : vector<2x1x7xi1> 860 861func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { 862 %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> 863 return %0 : vector<2x1x7xi1> 864} 865 866#matmat_accesses_0 = [ 867 affine_map<(m, n, k) -> (m, k)>, 868 affine_map<(m, n, k) -> (k, n)>, 869 affine_map<(m, n, k) -> (m, n)> 870] 871#matmat_trait_0 = { 872 indexing_maps = #matmat_accesses_0, 873 iterator_types = ["parallel", "parallel", "reduction"] 874} 875 876// OUTERPRODUCT-LABEL: func @matmul_0 877// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 878// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 879// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 880// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 881// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 882// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 883// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 884// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> 885func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) 886-> vector<2x3xf32> 887{ 888 %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 889 : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32> 890 return %0 : vector<2x3xf32> 891} 892 893#matmat_accesses_1 = [ 894 affine_map<(m, n, k) -> (m, k)>, 895 affine_map<(m, n, k) -> (n, k)>, 896 affine_map<(m, n, k) -> (m, n)> 897] 898#matmat_trait_1 = { 899 indexing_maps = #matmat_accesses_1, 900 iterator_types = ["parallel", "parallel", "reduction"] 901} 902 903// OUTERPRODUCT-LABEL: func @matmul_1 904// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 905// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, 906// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 907// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 908// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] 909// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 910// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> 911// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 912// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> 913func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) 914-> vector<2x3xf32> 915{ 916 %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 917 : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32> 918 return %0 : vector<2x3xf32> 919} 920 921#matmat_accesses_2 = [ 922 affine_map<(m, n, k) -> (k, m)>, 923 affine_map<(m, n, k) -> (k, n)>, 924 affine_map<(m, n, k) -> (m, n)> 925] 926#matmat_trait_2 = { 927 indexing_maps = #matmat_accesses_2, 928 iterator_types = ["parallel", "parallel", "reduction"] 929} 930 931// OUTERPRODUCT-LABEL: func @matmul_2 932// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, 933// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 934// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 935// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> 936// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 937// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 938// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> 939func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) 940-> vector<2x3xf32> 941{ 942 %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 943 : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32> 944 return %0 : vector<2x3xf32> 945} 946 947#matmat_accesses_3 = [ 948 affine_map<(m, n, k) -> (k, m)>, 949 affine_map<(m, n, k) -> (n, k)>, 950 affine_map<(m, n, k) -> (m, n)> 951] 952#matmat_trait_3 = { 953 indexing_maps = #matmat_accesses_3, 954 iterator_types = ["parallel", "parallel", "reduction"] 955} 956 957// OUTERPRODUCT-LABEL: func @matmul_3 958// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, 959// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, 960// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 961// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] 962// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> 963// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> 964// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 965// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> 966func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) 967-> vector<2x3xf32> 968{ 969 %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 970 : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32> 971 return %0 : vector<2x3xf32> 972} 973 974#matmat_accesses_4 = [ 975 affine_map<(m, n, k) -> (m, k)>, 976 affine_map<(m, n, k) -> (k, n)>, 977 affine_map<(m, n, k) -> (n, m)> 978] 979#matmat_trait_4 = { 980 indexing_maps = #matmat_accesses_4, 981 iterator_types = ["parallel", "parallel", "reduction"] 982} 983 984// OUTERPRODUCT-LABEL: func @matmul_4 985// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 986// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 987// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> 988// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 989// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 990// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 991// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] 992// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> 993func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) 994-> vector<3x2xf32> 995{ 996 %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 997 : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> 998 return %0 : vector<3x2xf32> 999} 1000 1001#matmat_accesses_5 = [ 1002 affine_map<(m, n, k) -> (m, k)>, 1003 affine_map<(m, n, k) -> (k, n)>, 1004 affine_map<(m, n, k) -> (n, m)> 1005] 1006#matmat_trait_5 = { 1007 indexing_maps = #matmat_accesses_5, 1008 iterator_types = ["parallel", "parallel", "reduction"] 1009} 1010 1011// OUTERPRODUCT-LABEL: func @matmul_5 1012// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 1013// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 1014// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> 1015// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 1016// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 1017// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 1018// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] 1019// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> 1020func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) 1021-> vector<3x2xf32> 1022{ 1023 %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 1024 : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> 1025 return %0 : vector<3x2xf32> 1026} 1027 1028#matmat_accesses_6 = [ 1029 affine_map<(m, n, k) -> (m, k)>, 1030 affine_map<(m, n, k) -> (k, n)>, 1031 affine_map<(m, n, k) -> (n, m)> 1032] 1033#matmat_trait_6 = { 1034 indexing_maps = #matmat_accesses_6, 1035 iterator_types = ["parallel", "parallel", "reduction"] 1036} 1037 1038// OUTERPRODUCT-LABEL: func @matmul_6 1039// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 1040// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 1041// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> 1042// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 1043// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 1044// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 1045// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] 1046// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> 1047func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) 1048-> vector<3x2xf32> 1049{ 1050 %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 1051 : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> 1052 return %0 : vector<3x2xf32> 1053} 1054 1055#matmat_accesses_7 = [ 1056 affine_map<(m, n, k) -> (m, k)>, 1057 affine_map<(m, n, k) -> (k, n)>, 1058 affine_map<(m, n, k) -> (n, m)> 1059] 1060#matmat_trait_7 = { 1061 indexing_maps = #matmat_accesses_7, 1062 iterator_types = ["parallel", "parallel", "reduction"] 1063} 1064 1065// OUTERPRODUCT-LABEL: func @matmul_7 1066// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 1067// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 1068// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> 1069// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 1070// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 1071// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 1072// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] 1073// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> 1074func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) 1075-> vector<3x2xf32> 1076{ 1077 %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 1078 : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> 1079 return %0 : vector<3x2xf32> 1080} 1081 1082// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered 1083// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>, 1084// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, 1085// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32> 1086// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] 1087func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>) 1088-> vector<4x4xf32> 1089{ 1090 %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 1091 : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> 1092 return %0 : vector<4x4xf32> 1093} 1094 1095// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered 1096// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>, 1097// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, 1098// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32> 1099// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] 1100func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>) 1101-> vector<3x4xf32> 1102{ 1103 %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 1104 : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32> 1105 return %0 : vector<3x4xf32> 1106} 1107