//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for
// particular sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

/// Models _mm256_shuffle_ps. Takes an 8-bit mask, 2 bits for each output
/// position; within each 128-bit half of the result, the first two f32s are
/// selected from `v1` (a) and the last two from `v2` (b):
///
///   source:       a    a     b      b        a      a       b       b
///   index:       b01  b23  b45+8  b67+8  |  b01+4  b23+4  b45+12  b67+12
///                       0:127            |            128:255
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
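
// Worked example (illustrative): mask = 0x44 decomposes into b01 = 0,
// b23 = 1, b45 = 0, b67 = 1, so mm256ShufflePs emits:
//
//   %r = vector.shuffle %v1, %v2 [0, 1, 8, 9, 4, 5, 12, 13]
//       : vector<8xf32>, vector<8xf32>
//
// i.e., each 128-bit half of the result takes its first two f32s from `v1`
// and its last two from `v2`.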

/// Models _mm256_permute2f128_ps: the 256-bit result is assembled from two
/// 128-bit halves, each picked by a 2-bit control. mask[0:1] (out of the
/// nibble mask[0:3]) selects the lower half; mask[4:5] (out of the nibble
/// mask[4:7]) selects the upper half:
///
///   control:     0           1           2           3
///   half:     a[0:127]  a[128:255]   b[0:127]   b[128:255]
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}
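
// For reference, the sequence above mirrors the classic AVX "C intrinsics"
// 4x8 transpose idiom it is modeled on. Illustrative C sketch only; it
// assumes MaskHelper::shuffle<...>() matches _MM_SHUFFLE(...) and
// MaskHelper::permute<hi, lo>() encodes (hi << 4) | lo:
//
//   #include <immintrin.h>
//   __m256 t0 = _mm256_unpacklo_ps(v0, v1);
//   __m256 t1 = _mm256_unpackhi_ps(v0, v1);
//   __m256 t2 = _mm256_unpacklo_ps(v2, v3);
//   __m256 t3 = _mm256_unpackhi_ps(v2, v3);
//   __m256 s0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
//   __m256 s1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
//   __m256 s2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
//   __m256 s3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
//   v0 = _mm256_permute2f128_ps(s0, s1, 0x20);
//   v1 = _mm256_permute2f128_ps(s2, s3, 0x20);
//   v2 = _mm256_permute2f128_ps(s0, s1, 0x31);
//   v3 = _mm256_permute2f128_ps(s2, s3, 0x31);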

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
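
// Worked example for the blend masks above (illustrative; assumes
// MaskHelper::blend<b0, ..., b7>() maps its first template argument to bit 0):
// blend<0, 0, 1, 1, 0, 0, 1, 1>() sets bits 2, 3, 6 and 7, i.e. mask 0xCC.
// mm256BlendPs with that mask keeps f32s 0, 1, 4, 5 from the first operand
// and takes 2, 3, 6, 7 from the second:
//
//   %r = vector.shuffle %v1, %v2 [0, 1, 10, 11, 4, 5, 14, 15]
//       : vector<8xf32>, vector<8xf32>
//
// mm256BlendPsAsm pins the same selection to a single `vblendps` instruction
// via inline asm rather than going through a generic vector.shuffle.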

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
///   transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
///   Return false: dim0 and dim1 are *not* transposed within the context of
///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
///   and dim1 (1) are transposed within the full context of the
///   transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context
  // of their 2D slice. Otherwise, 'dim1' is found first and they are
  // transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}
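
// Worked example (illustrative): with transp = {2, 0, 1}, as in Example 3
// above:
//
//   areDimsTransposedIn2DSlice(0, 1, transp); // false: 0 is found before 1.
//   areDimsTransposedIn2DSlice(1, 2, transp); // true: 2 is found before 1.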

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with exactly two dimensions greater than
    // one.
    VectorType srcType = op.getVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    SmallVector<int64_t> srcGtOneDims;
    for (auto &en : llvm::enumerate(srcType.getShape()))
      if (en.value() > 1)
        srcGtOneDims.push_back(en.index());

    if (srcGtOneDims.size() != 2)
      return rewriter.notifyMatchFailure(op, "Unsupported vector type");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // Check whether the two source vector dimensions that are greater than
    // one must be transposed with each other so that we can apply one of the
    // 2-D AVX2 transpose patterns. Otherwise, these patterns are not
    // applicable.
    if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
      return rewriter.notifyMatchFailure(
          op, "Not applicable to this transpose permutation");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    auto srcShape = srcType.getShape();
    int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than
      // one to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.vector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose the set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert the transposed 1-D vectors into the higher-order dimension of
      // the output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The intermediate result still has the shape of the input slice (e.g.,
      // 4x8), but its elements are already laid out in transposed row-major
      // order. Flatten it and shape-cast to the n-D result type (e.g.,
      // 1x8x1x4x1) to restore the original rank.
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
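
// Typical usage (illustrative sketch: the greedy driver and the
// `LoweringOptions`/`TransposeLoweringOptions` builder spellings are
// assumptions based on the fields referenced above, not defined in this
// file):
//
//   RewritePatternSet patterns(func.getContext());
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns,
//       x86vector::avx2::LoweringOptions().setTransposeOptions(
//           x86vector::avx2::TransposeLoweringOptions()
//               .lower4x8xf32()
//               .lower8x8xf32()),
//       /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));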