// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL

// Each RUN line above runs -test-vector-contraction-lowering with a different
// option and verifies the output with its own FileCheck prefix; the run with
// no option uses the plain CHECK prefix.

// Dot product: both operands are indexed by the single reduction dimension
// and the accumulator/result is a scalar.
#dotp_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#dotp_trait = {
  indexing_maps = #dotp_accesses,
  iterator_types = ["reduction"]
}

// CHECK-LABEL: func @extract_contract1
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
// CHECK: return %[[R]] : f32

func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
    : vector<4xf32>, vector<4xf32> into f32
  return %0 : f32
}

// CHECK-LABEL: func @extract_contract1_int
// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
// CHECK-SAME: %[[C:.*2]]: i32
// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
// CHECK: return %[[R]] : i32

func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
    : vector<4xi32>, vector<4xi32> into i32
  return %0 : i32
}

// Matrix-vector product: lhs (i, j) times rhs (j) accumulates into (i);
// j is the reduction dimension.
#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) ->
(j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// CHECK-LABEL: func @extract_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
// CHECK: return %[[T10]] : vector<2xf32>

func.func @extract_contract2(%arg0: vector<2x3xf32>,
                             %arg1: vector<3xf32>,
                             %arg2: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @extract_contract2_int
// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xi32> into i32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32>
// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xi32> into i32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
// CHECK: return %[[T10]] : vector<2xi32>
func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
                                 %arg1: vector<3xi32>,
                                 %arg2: vector<2xi32>) -> vector<2xi32> {
  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
  return %0 : vector<2xi32>
}

// Vector-matrix product: lhs (j) and rhs (i, j) accumulate into (i);
// j is the reduction dimension.
#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// CHECK-LABEL: func @extract_contract3
// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
// CHECK: return %[[T10]] : vector<2xf32>

func.func @extract_contract3(%arg0: vector<3xf32>,
                             %arg1: vector<2x3xf32>,
                             %arg2: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Matrix-matrix product: lhs (i, k) times rhs (k, j) accumulates into (i, j);
// k is the reduction dimension.
#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// The rhs is transposed once, then each result element is a row-times-row
// product reduced to a scalar and inserted; the accumulator is added last.
// CHECK-LABEL: func @extract_contract4
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
//
// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
//
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
//
// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32>
// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
//
// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
// CHECK: return %[[T52]] : vector<2x2xf32>

func.func @extract_contract4(%arg0: vector<2x2xf32>,
                             %arg1: vector<2x2xf32>,
                             %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  return %0 : vector<2x2xf32>
}

// Full contraction of two identically indexed 2-D operands into a scalar.
#contraction2d_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> ()>
]
#contraction2d_trait = {
  indexing_maps = #contraction2d_accesses,
  iterator_types = ["reduction", "reduction"]
}

// CHECK-LABEL: func @full_contract1
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
// CHECK: return %[[T8]] : f32

func.func @full_contract1(%arg0: vector<2x3xf32>,
                          %arg1: vector<2x3xf32>,
                          %arg2: f32) -> f32 {
  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<2x3xf32> into f32
  return %0 : f32
}

// Full contraction into a scalar where the rhs is indexed transposed (j, i).
#contraction2d_trans_accesses = [
affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> ()>
]
#contraction2d_trans_trait = {
  indexing_maps = #contraction2d_trans_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Each column of the transposed rhs is gathered element by element into a
// vector<3xf32> before being multiplied with the matching lhs row.
// CHECK-LABEL: func @full_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
//
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32>
// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
// CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK: return %[[T23]] : f32

func.func @full_contract2(%arg0:
vector<2x3xf32>,
                          %arg1: vector<3x2xf32>,
                          %arg2: f32) -> f32 {
  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<3x2xf32> into f32
  return %0 : f32
}

// vector.outerproduct lowering, float and integer, with and without an
// accumulator: each lhs element is splatted and multiplied with the rhs row.
// CHECK-LABEL: func @outerproduct_noacc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>

func.func @outerproduct_noacc(%arg0: vector<2xf32>,
                              %arg1: vector<3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// CHECK-LABEL: func @outerproduct_acc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T9]] : vector<2x3xf32>

func.func @outerproduct_acc(%arg0: vector<2xf32>,
                            %arg1: vector<3xf32>,
                            %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// CHECK-LABEL: func @outerproduct_noacc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
                                  %arg1: vector<3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// Integer accumulation is lowered to arith.muli + arith.addi (no vector.fma).
// CHECK-LABEL: func @outerproduct_acc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T11]] : vector<2x3xi32>
func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
                                %arg1: vector<3xi32>,
                                %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// vector.outerproduct with a scalar rhs operand (the @axpy_* tests).
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// CHECK-LABEL: func @axpy_fp_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 :
vector<16xf32>) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// CHECK-LABEL: func @axpy_int_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// A shape_cast to the same type is removed entirely.
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>

func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
  return %0 : vector<16xf32>
}

// CHECK-LABEL: func @cancel_shape_cast
// FIXME: PR49590. The two checks below are disabled by spelling the prefix
// as HECK; they expect the two inverse shape_casts to cancel so that %arg0
// is returned directly.
// HECK-SAME: %[[A:.*]]: vector<16xf32>
// HECK: return %[[A]] : vector<16xf32>

func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
  return %1 : vector<16xf32>
}

// Shape up-casts and down-casts for 2-D vectors, used to support conversion
// to llvm.matrix operations.
// CHECK-LABEL: func @shape_casts
func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32>
  //
  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
  //
  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32>
  //
  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
  //
  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
  // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
  %r0 = arith.addf %0, %0: vector<4xf32>
  //
  // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
  // CHECK-SAME: vector<4xf32> to vector<2xf32>
  //
  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
  //
  // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
  // CHECK-SAME: vector<4xf32> to vector<2xf32>
  //
  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
  //
  %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}

// General shape_casts are lowered to scalar extract/insert pairs in
// row-major order.
// CHECK-LABEL: func @shape_cast_2d2d
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
// CHECK: return %[[T11]] : vector<2x3xf32>

func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
  return %s : vector<2x3xf32>
}

// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
// CHECK: return %[[T11]] : vector<6xf32>

func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
  return %s : vector<6xf32>
}

// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: return %[[T11]] : vector<2x1x3xf32>

func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}

// @matmul is checked under several lowerings. The MATRIX run flattens both
// operands and uses vector.matrix_multiply. The OUTERPRODUCT run transposes
// the lhs and chains one vector.outerproduct per reduction index.
// MATRIX-LABEL: func @matmul
// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// MATRIX: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
// MATRIX: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
// MATRIX: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
// MATRIX: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
// MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
// MATRIX: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>

// OUTERPRODUCT-LABEL: func @matmul
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT-SAME: : vector<2x4xf32> to vector<4x2xf32>
//
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32>
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32>
//
// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32>
// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32>
//
// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32>
// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32>
//
// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32>
// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32>
//
// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32>

// NOTE(review): none of the RUN lines at the top of this file passes
// --check-prefix=REDUCE, so the REDUCE checks below are not exercised. They
// document the expected per-element reduction lowering, which has the same
// shape as the default-prefix checks of @extract_contract4.
// REDUCE-LABEL: func @matmul
// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//
// REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// REDUCE-SAME: : vector<4x3xf32> to vector<3x4xf32>
//
// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32>
// REDUCE-NEXT: %[[ab00:.*]] = arith.mulf %[[a0]], %[[b0]] : vector<4xf32>
// REDUCE-NEXT: %[[s00:.*]] = vector.reduction <add>, %[[ab00]] : vector<4xf32> into f32
// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32>
//
// ...
//
// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
// REDUCE: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32>
// REDUCE-NEXT: %[[ab12:.*]] = arith.mulf %[[a1]], %[[b2]] : vector<4xf32>
// REDUCE-NEXT: %[[s12:.*]] = vector.reduction <add>, %[[ab12]] : vector<4xf32> into f32
// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32>
//
// REDUCE: %[[RESULT:.*]] = arith.addf %[[r12]], %[[C]] : vector<2x3xf32>
// REDUCE: return %[[RESULT]] : vector<2x3xf32>
func.func @matmul(%arg0: vector<2x4xf32>,
                  %arg1: vector<4x3xf32>,
                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// vector.broadcast lowering: scalar sources become vector.splat, and vector
// sources are replicated with vector.insert.
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>

func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
  %0 = vector.broadcast %arg0
: f32 to vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>

func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
  %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>

func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
  %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32>
  return %0 : vector<2x3x4xf32>
}

// A broadcast to the same type is removed entirely.
// CHECK-LABEL: func @broadcast_vec1d_from_vec1d
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
// CHECK: return %[[A]] : vector<2xf32>

func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func @broadcast_vec2d_from_vec1d
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: return %[[T2]] : vector<3x2xf32>

func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

// CHECK-LABEL: func @broadcast_vec3d_from_vec1d
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: return %[[T6]] : vector<4x3x2xf32>

func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
  return %0 : vector<4x3x2xf32>
}

// CHECK-LABEL: func @broadcast_vec3d_from_vec2d
// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32>
// CHECK: return %[[T3]] : vector<4x3x2xf32>

func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
  %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32>
  return %0 : vector<4x3x2xf32>
}

// Broadcasts that stretch a size-1 dimension of the source.
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>

func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
  %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}

// CHECK-LABEL: func @broadcast_stretch_at_start
// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32>
// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32>
// CHECK: return %[[T3]] : vector<3x4xf32>

func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
  %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
  return %0 : vector<3x4xf32>
}

// CHECK-LABEL: func @broadcast_stretch_at_end
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into 
vector<4x3xf32> 729// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> 730// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> 731// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> 732// CHECK: return %[[T15]] : vector<4x3xf32> 733 734func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { 735 %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> 736 return %0 : vector<4x3xf32> 737} 738 739// CHECK-LABEL: func @broadcast_stretch_in_middle 740// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> 741// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> 742// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> 743// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> 744// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 745// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> 746// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> 747// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> 748// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> 749// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 750// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> 751// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> 752// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> 753// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> 754// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 755// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> 756// CHECK: 
%[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> 757// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> 758// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> 759// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> 760// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> 761// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> 762// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> 763// CHECK: return %[[T23]] : vector<4x3x2xf32> 764 765func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { 766 %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> 767 return %0 : vector<4x3x2xf32> 768} 769 770// CHECK-LABEL: func @genbool_1d 771// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> 772// CHECK: return %[[T0]] : vector<8xi1> 773 774func.func @genbool_1d() -> vector<8xi1> { 775 %0 = vector.constant_mask [4] : vector<8xi1> 776 return %0 : vector<8xi1> 777} 778 779// CHECK-LABEL: func @genbool_2d 780// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> 781// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1> 782// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> 783// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> 784// CHECK: return %[[T1]] : vector<4x4xi1> 785 786func.func @genbool_2d() -> vector<4x4xi1> { 787 %v = vector.constant_mask [2, 2] : vector<4x4xi1> 788 return %v: vector<4x4xi1> 789} 790 791// CHECK-LABEL: func @genbool_3d 792// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> 793// CHECK: %[[C2:.*]] = 
arith.constant dense<false> : vector<3x4xi1> 794// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1> 795// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> 796// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> 797// CHECK: return %[[T1]] : vector<2x3x4xi1> 798 799func.func @genbool_3d() -> vector<2x3x4xi1> { 800 %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> 801 return %v: vector<2x3x4xi1> 802} 803 804// CHECK-LABEL: func @genbool_var_1d( 805// CHECK-SAME: %[[A:.*]]: index) 806// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> 807// CHECK: return %[[T0]] : vector<3xi1> 808 809func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> { 810 %0 = vector.create_mask %arg0 : vector<3xi1> 811 return %0 : vector<3xi1> 812} 813 814// CHECK-LABEL: func @genbool_var_2d( 815// CHECK-SAME: %[[A:.*0]]: index, 816// CHECK-SAME: %[[B:.*1]]: index) 817// CHECK: %[[C1:.*]] = arith.constant dense<false> : vector<3xi1> 818// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1> 819// CHECK: %[[c0:.*]] = arith.constant 0 : index 820// CHECK: %[[c1:.*]] = arith.constant 1 : index 821// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> 822// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index 823// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> 824// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> 825// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index 826// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> 827// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> 828// CHECK: return %[[T6]] : vector<2x3xi1> 829 830func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { 831 %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> 832 return %0 : vector<2x3xi1> 833} 834 835// 
CHECK-LABEL: func @genbool_var_3d( 836// CHECK-SAME: %[[A:.*0]]: index, 837// CHECK-SAME: %[[B:.*1]]: index, 838// CHECK-SAME: %[[C:.*2]]: index) 839// CHECK-DAG: %[[C1:.*]] = arith.constant dense<false> : vector<7xi1> 840// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1> 841// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1> 842// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index 843// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index 844// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> 845// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index 846// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> 847// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> 848// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index 849// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> 850// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> 851// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index 852// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> 853// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> 854// CHECK: return %[[T9]] : vector<2x1x7xi1> 855 856func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { 857 %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> 858 return %0 : vector<2x1x7xi1> 859} 860 861// CHECK-LABEL: @contract_one_sided_unit_reduction_dim 862// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) 863// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> 864// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> 865// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> 866// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> 867// 
CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32 868// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> 869// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> 870// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> 871// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32 872// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> 873// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> 874// CHECK: return %[[S]] : vector<2xi32> 875 876func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { 877 %res = vector.contract { 878 indexing_maps = [ 879 affine_map<(d0, d1, d2) -> (d0, d2)>, 880 affine_map<(d0, d1, d2) -> (d1, d2)>, 881 affine_map<(d0, d1, d2) -> (d1)> 882 ], 883 iterator_types = ["reduction", "parallel", "reduction"], 884 kind = #vector.kind<add> 885 } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> 886 return %res : vector<2xi32> 887} 888 889#matmat_accesses_0 = [ 890 affine_map<(m, n, k) -> (m, k)>, 891 affine_map<(m, n, k) -> (k, n)>, 892 affine_map<(m, n, k) -> (m, n)> 893] 894#matmat_trait_0 = { 895 indexing_maps = #matmat_accesses_0, 896 iterator_types = ["parallel", "parallel", "reduction"] 897} 898 899// OUTERPRODUCT-LABEL: func @matmul_0 900// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, 901// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, 902// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> 903// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] 904// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> 905// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> 906// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] 907// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> 
func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
-> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// Mixed precision: the extracted f16 rows are widened to f32 with arith.extf
// before the outer product with the f32 accumulator.
// OUTERPRODUCT-LABEL: func @matmul_0_mixed
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16>
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16>
// OUTERPRODUCT: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
// OUTERPRODUCT: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
-> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// B is stored as (n, k): both operands need a transpose before extraction.
#matmat_accesses_1 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_1 = {
  indexing_maps = #matmat_accesses_1,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_1
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
-> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// A is stored as (k, m): both operands are already k-major, so no transpose.
#matmat_accesses_2 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_2 = {
  indexing_maps = #matmat_accesses_2,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_2
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
-> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// A is (k, m) and B is (n, k): only B needs a transpose.
#matmat_accesses_3 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_3 = {
  indexing_maps = #matmat_accesses_3,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_3
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32>
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
-> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// Transposed result C(n, m): the outer product is emitted with the operands
// swapped (B row first, then A row).
#matmat_accesses_4 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_4 = {
  indexing_maps = #matmat_accesses_4,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_4
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

// NOTE(review): #matmat_accesses_5, _6 and _7 below are identical to
// #matmat_accesses_4, so @matmul_5, @matmul_6 and @matmul_7 repeat the
// @matmul_4 case. Their checks differ only in leaving the order of the two
// extracts unconstrained (DAG directives). Presumably other operand/result
// permutations were intended; confirm before counting them as extra coverage.
#matmat_accesses_5 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_5 = {
  indexing_maps = #matmat_accesses_5,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_5
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

#matmat_accesses_6 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_6 = {
  indexing_maps = #matmat_accesses_6,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_6
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

#matmat_accesses_7 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_7 = {
  indexing_maps = #matmat_accesses_7,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// OUTERPRODUCT-LABEL: func @matmul_7
// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32>
// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32>
// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32>
func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
-> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

// Filtered outer-product lowering. Both tests below expect their
// vector.contract to survive unlowered.
// NOTE(review): the "_not_filtered" name suggests that case was meant to be
// lowered. Confirm against the filter used by
// -test-vector-contraction-lowering=vector-filter-outerproduct.
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
-> vector<4x4xf32>
{
  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
    : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  return %0 : vector<4x4xf32>
}

// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
-> vector<3x4xf32>
{
  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
    : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
  return %0 : vector<3x4xf32>
}

// All reduction dims (d1, d2) have size 1. The parallel-arith lowering
// extracts the innermost vectors and emits a single vector.fma with the
// accumulator.
// PARALLEL-LABEL: func @parrallel_contract_lowering
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// The LHS does not use the parallel dim d0. It is broadcast along d0 and
// transposed into the (d1, d2, d0) layout before the same fma lowering.
// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast
// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// PARALLEL: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// As above, but the RHS is stored as (d1, d0, d2), so it is transposed too.
// The label names this function exactly, rather than matching it by prefix.
// PARALLEL-LABEL: func @parrallel_contract_lowering_transpose
// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// PARALLEL: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// No parallel dims and unit reduction dims: the result is a scalar
// arith.mulf followed by arith.addf with the accumulator.
// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
// PARALLEL: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
// PARALLEL: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
// PARALLEL: return %[[A]] : f32
func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> ()>],
    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
    %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
  return %0 : f32
}