//===- AVXTranspose.cpp - Lower Vector transpose to AVX ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for
// particular sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  auto asmTp = "vblendps $0, $1, $2, {0}";
  auto asmCstr = "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), asmVals, asmStr, asmCstr, false, false, asmDialectAttr);
  return asmOp.getResult(0);
}

/// Models the semantics of `_mm256_unpacklo_ps` as a vector.shuffle.
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

/// Models the semantics of `_mm256_unpackhi_ps` as a vector.shuffle.
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

/// Each 128-bit half of the result takes 2 elements from `a` (= v1) followed
/// by 2 elements from `b` (= v2):
///                    a      a      b      b   |    a      a       b       b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                              0:127          |          128:255
///                   b01    b23   b45+8  b67+8 | b01+4  b23+4  b45+12  b67+12
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
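
// Worked example (illustrative only; assumes MaskHelper::extractShuffle
// splits the immediate into 2-bit fields from the low bits up): with
// mask = 0b01001110, i.e. MaskHelper::shuffle<1, 0, 3, 2>() as used in
// transpose8x8xf32 below, we get b01 = 2, b23 = 3, b45 = 0, b67 = 1, so the
// resulting vector.shuffle mask is {2, 3, 8, 9, 6, 7, 12, 13}: the same
// per-128-bit-lane selection that vshufps performs with that immediate.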
// The lower 2 bits of the immediate's low nibble (imm[0:1]) select the
// result's low 128 bits; the lower 2 bits of its high nibble (imm[4:5])
// select the high 128 bits. Each 2-bit control picks one of:
//       0             1             2             3
//   a[0:127]     a[128:255]     b[0:127]     b[128:255]
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
}
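
// For reference, transpose4x8xf32 above mirrors the classical C-intrinsics
// sequence below (a sketch assuming <immintrin.h> semantics and that
// MaskHelper::permute<hi, lo>() encodes `hi << 4 | lo`; variable names are
// illustrative):
//   __m256 t0 = _mm256_unpacklo_ps(v0, v1);
//   __m256 t1 = _mm256_unpackhi_ps(v0, v1);
//   __m256 t2 = _mm256_unpacklo_ps(v2, v3);
//   __m256 t3 = _mm256_unpackhi_ps(v2, v3);
//   __m256 s0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
//   __m256 s1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
//   __m256 s2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
//   __m256 s3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
//   v0 = _mm256_permute2f128_ps(s0, s1, 0x20); // permute<2, 0>
//   v1 = _mm256_permute2f128_ps(s2, s3, 0x20);
//   v2 = _mm256_permute2f128_ps(s0, s1, 0x31); // permute<3, 1>
//   v3 = _mm256_permute2f128_ps(s2, s3, 0x31);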
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value S0 =
      mm256BlendPsAsm(ib, T0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S1 =
      mm256BlendPsAsm(ib, T2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S2 =
      mm256BlendPsAsm(ib, T1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S3 =
      mm256BlendPsAsm(ib, T3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S4 =
      mm256BlendPsAsm(ib, T4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S5 =
      mm256BlendPsAsm(ib, T6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S6 =
      mm256BlendPsAsm(ib, T5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S7 =
      mm256BlendPsAsm(ib, T7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
}
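
// A note on the blend masks above (the bit ordering of MaskHelper::blend is
// an assumption derived from the mm256BlendPs documentation, with bits listed
// from position 0 upwards): blend<0, 0, 1, 1, 0, 0, 1, 1>() encodes
// 0b11001100, so S0 keeps elements 0, 1, 4, 5 of T0 and takes elements
// 2, 3, 6, 7 from sh0; the complementary blend<1, 1, 0, 0, 1, 1, 0, 0>()
// encodes 0b00110011 and selects the opposite halves. Routing the blends
// through inline asm (mm256BlendPsAsm) presumably pins instruction selection
// to vblendps rather than leaving the choice to LLVM's generic shuffle
// matching.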
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    // Reject anything that is not the (1, 0) permutation.
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // The transposed form is still 4x8 and needs to be reinterpreted as
      // 8x4 via shape_casts.
      if (m == 4) {
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
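
// Typical wiring from a conversion pass (a sketch; the pass boilerplate and
// the `avx2Options` configuration are illustrative, only the populate call is
// this file's API):
//   RewritePatternSet patterns(&getContext());
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, avx2Options, /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));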