//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), asmVals, asmStr, asmCstr, false, false, asmDialectAttr);
  return asmOp.getResult(0);
}

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

///                             a    a   b   b     a     a     b    b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                                 0:127    |        128:255
///                             b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
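
// A worked example of the mask decoding above (illustrative only, assuming
// MaskHelper::extractShuffle reads the four 2-bit fields from least to most
// significant):
//   mask = 0xE4 = 0b11'10'01'00  =>  b01 = 0, b23 = 1, b45 = 2, b67 = 3
//   shuffleMask = {0, 1, 10, 11, 4, 5, 14, 15}
// i.e. each 128-bit half of the result takes its first two f32s from v1 and
// its last two from v2, matching the semantics of _mm256_shuffle_ps.
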
// imm[0:1] out of imm[0:3] selects the lower 128 bits of the result:
//       0             1            2             3
//    a[0:127]    a[128:255]    b[0:127]    b[128:255]
// imm[0:1] out of imm[4:7] selects the upper 128 bits of the result from the
// same four choices.
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}
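
// Usage sketch (illustrative only; `ib` is an ImplicitLocOpBuilder and
// a0..a3 are hypothetical vector<8xf32> values holding the rows of a 4x8
// matrix):
//   SmallVector<Value> vs = {a0, a1, a2, a3};
//   transpose4x8xf32(ib, vs);
// On return, vs[0]..vs[3] concatenated and read row-major as an 8x4 matrix
// form the transpose of the input, which is why the pattern below
// shape_casts the reassembled vector<4x8xf32> through vector<32xf32> to
// vector<8x4xf32>.
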
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
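
// A worked example of the blend masks above (illustrative only; this assumes
// MaskHelper::blend<b0, ..., b7>() puts b0 in the least significant bit):
//   MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>() == 0xCC
// Bits 2, 3, 6 and 7 are set, so the blend keeps elements 0, 1, 4, 5 of its
// first operand and takes elements 2, 3, 6, 7 from its second operand; the
// equivalent vector.shuffle mask is {0, 1, 10, 11, 4, 5, 14, 15}. This is
// the classic unpack/shuffle/blend AVX2 8x8 transpose recipe, with the
// blends emitted via mm256BlendPsAsm, presumably to force selection of
// vblendps.
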
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // In the 4x8 case the transposed form is still a 4x8 vector and needs
      // to be reinterpreted as 8x4 via shape_casts.
      if (m == 4) {
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
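
// Example wiring from a pass (a sketch; `MyLoweringPass` and the greedy
// driver call are hypothetical context, and the fluent options builders are
// assumed to be the ones declared next to this transform in
// X86Vector/Transforms.h):
//   void MyLoweringPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//         patterns,
//         x86vector::avx2::LoweringOptions().setTransposeOptions(
//             x86vector::avx2::TransposeLoweringOptions()
//                 .lower4x8xf32()
//                 .lower8x8xf32()),
//         /*benefit=*/10);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }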