//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

/// Emits an LLVM inline-asm `vblendps` (Intel dialect) blending the f32 lanes
/// of `v1` and `v2`: bit i of the 8-bit immediate `mask` selects lane i from
/// `v2` when set, from `v1` otherwise. The immediate is formatted directly
/// into the asm string; the op is marked side-effect free.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  // "$0" is the result, "$1"/"$2" the inputs; "{0}" is substituted below with
  // the hex-formatted immediate.
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

/// Emulates `_mm256_unpacklo_ps` as a vector.shuffle: interleaves the two low
/// f32 lanes of each 128-bit half of `v1` and `v2` (indices >= 8 refer to
/// lanes of `v2` in the concatenated shuffle index space).
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

/// Emulates `_mm256_unpackhi_ps` as a vector.shuffle: interleaves the two
/// high f32 lanes of each 128-bit half of `v1` and `v2`.
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

/// Emulates `_mm256_shuffle_ps`. The 8-bit `mask` holds four 2-bit lane
/// selectors b01, b23, b45, b67. Within each 128-bit half of the result,
/// positions 0/1 take lanes b01/b23 from `v1` and positions 2/3 take lanes
/// b45/b67 from `v2` — hence the +8 offsets into `v2`'s half of the
/// concatenated index space, and the +4 offsets for the high 128-bit half.
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

// Emulates `_mm256_permute2f128_ps`: each 2-bit field of the immediate picks
// one 128-bit half of the concatenated <v1, v2>:
//   0 -> v1[0:127], 1 -> v1[128:255], 2 -> v2[0:127], 3 -> v2[128:255]
// imm[0:1] selects the result's low 128 bits, imm[4:5] its high 128 bits.
79 Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps( 80 ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { 81 SmallVector<int64_t> shuffleMask; 82 auto appendToMask = [&](uint8_t control) { 83 if (control == 0) 84 llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3}); 85 else if (control == 1) 86 llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7}); 87 else if (control == 2) 88 llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11}); 89 else if (control == 3) 90 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15}); 91 else 92 llvm_unreachable("control > 3 : overflow"); 93 }; 94 uint8_t b03, b47; 95 MaskHelper::extractPermute(mask, b03, b47); 96 appendToMask(b03); 97 appendToMask(b47); 98 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 99 } 100 101 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. 102 Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, 103 Value v1, Value v2, 104 uint8_t mask) { 105 SmallVector<int64_t, 8> shuffleMask; 106 for (int i = 0; i < 8; ++i) { 107 bool isSet = mask & (1 << i); 108 shuffleMask.push_back(!isSet ? i : i + 8); 109 } 110 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 111 } 112 113 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. 
114 void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib, 115 MutableArrayRef<Value> vs) { 116 #ifndef NDEBUG 117 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); 118 assert(vs.size() == 4 && "expects 4 vectors"); 119 assert(llvm::all_of(ValueRange{vs}.getTypes(), 120 [&](Type t) { return t == vt; }) && 121 "expects all types to be vector<8xf32>"); 122 #endif 123 124 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]); 125 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]); 126 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]); 127 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]); 128 Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>()); 129 Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>()); 130 Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>()); 131 Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>()); 132 vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>()); 133 vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>()); 134 vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>()); 135 vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>()); 136 } 137 138 /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model. 
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  // `vt` is only consumed by the asserts below; (void) silences the unused
  // warning in release builds.
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  // Stage 1: pairwise interleave of adjacent rows.
  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  // Stage 2: swap the 2-element pairs within each 128-bit half
  // (shuffle<1, 0, 3, 2>), then blend the shuffled vectors with the unpacked
  // ones so each result holds 4-element column pieces. The inline-asm blend
  // is used here (instead of the vector.shuffle emulation) — see the
  // mm256BlendPsAsm helper.
  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  // Stage 3: cross-lane 128-bit permutes assemble the final column vectors.
  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context of
  // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  // Both dims must appear in a well-formed permutation, so the loop above
  // always returns.
  llvm_unreachable("Ill-formed transpose pattern");
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with exactly two dimensions greater than
    // one.
    VectorType srcType = op.getVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    // Collect the indices of the dimensions with extent > 1; all remaining
    // dimensions are unit and can be collapsed away by shape casts below.
    SmallVector<int64_t> srcGtOneDims;
    for (auto &en : llvm::enumerate(srcType.getShape()))
      if (en.value() > 1)
        srcGtOneDims.push_back(en.index());

    if (srcGtOneDims.size() != 2)
      return rewriter.notifyMatchFailure(op, "Unsupported vector type");

    // Materialize the permutation attribute as a plain index vector.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // Check whether the two source vector dimensions that are greater than one
    // must be transposed with each other so that we can apply one of the 2-D
    // AVX2 transpose patterns. Otherwise, these patterns are not applicable.
    if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
      return rewriter.notifyMatchFailure(
          op, "Not applicable to this transpose permutation");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    auto srcShape = srcType.getShape();
    int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector (via a flat 1-D intermediate, which preserves element
      // order).
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors. `m` is guaranteed to be 4 or 8 by the
      // guards at the bottom of matchAndRewrite.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector, starting from an all-zero vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose their dimensions and retrieve its original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    // Only the 4x8 and 8x8 cases are supported, each individually gated by
    // the lowering options.
    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  // Controls which AVX2 transpose shapes this pattern is allowed to lower.
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}