//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for
// particular sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

///                            a   a    b   b    a    a     b    b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                               0:127      |         128:255
///                          b01 b23 C8 D8   |  b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
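
// Worked example (editor's illustration, not normative): assuming
// MaskHelper::shuffle<b67, b45, b23, b01> packs its template arguments from
// highest 2-bit field to lowest (mirroring _MM_SHUFFLE), the mask
// MaskHelper::shuffle<1, 0, 1, 0>() decodes to b01 = 0, b23 = 1, b45 = 0,
// b67 = 1 and yields the shuffle mask {0, 1, 8, 9, 4, 5, 12, 13}: within each
// 128-bit half, the two low f32s of v1 followed by the two low f32s of v2.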

// imm[0:1] out of imm[0:3] is:
//    0             1             2             3
// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
//                a[0:127] or a[128:255] or b[0:127] or b[128:255]
//    0             1             2             3
// imm[0:1] out of imm[4:7].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}
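
// Editor's illustration of the stages above, tracking positions rather than
// data, for rows A = a0..a7 through D = d0..d7 (a sketch, not normative):
//   t0    = unpacklo(A, B)              = a0 b0 a1 b1 a4 b4 a5 b5
//   t2    = unpacklo(C, D)              = c0 d0 c1 d1 c4 d4 c5 d5
//   s0    = shuffle<1, 0, 1, 0>(t0, t2) = a0 b0 c0 d0 a4 b4 c4 d4
//   s1    = shuffle<3, 2, 3, 2>(t0, t2) = a1 b1 c1 d1 a5 b5 c5 d5
//   vs[0] = permute2f128<2, 0>(s0, s1)  = a0 b0 c0 d0 a1 b1 c1 d1
// i.e. vs[0] holds columns 0 and 1 of the original 4x8 back to back; the
// shape_casts in TransposeOpLowering below reinterpret the four results as
// the rows of the 8x4 transpose.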

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
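
// Editor's illustration of the shuffle + blend stages above, positions only
// (assumes MaskHelper::blend lists bits from lowest element to highest):
//   t0  = a0 b0 a1 b1 a4 b4 a5 b5        t2 = c0 d0 c1 d1 c4 d4 c5 d5
//   sh0 = shuffle<1, 0, 3, 2>(t0, t2)        = a1 b1 c0 d0 a5 b5 c4 d4
//   s0  = blend<0, 0, 1, 1, 0, 0, 1, 1>(t0, sh0) = a0 b0 c0 d0 a4 b4 c4 d4
//   s1  = blend<1, 1, 0, 0, 1, 1, 0, 0>(t2, sh0) = a1 b1 c1 d1 a5 b5 c5 d5
// This reproduces the s-values of the 4x8 case with one shuffle plus two
// blends, which typically have better port availability than shuffles; the
// final permute2f128 pairs then split the 128-bit lanes across the eight
// results.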

/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // In the 4x8 case, the transposed data is still laid out as 4 rows of
      // 8 elements and needs to be reinterpreted as 8x4 via shape_casts.
      if (m == 4) {
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
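
// Editor's usage sketch (not part of this file): enabling these patterns
// from a pass typically looks like the following, using the option helpers
// from X86Vector/Transforms.h; the benefit value and the greedy driver call
// are illustrative assumptions.
//
//   RewritePatternSet patterns(ctx);
//   auto options = x86vector::avx2::LoweringOptions().setTransposeOptions(
//       x86vector::avx2::TransposeLoweringOptions()
//           .lower4x8xf32()
//           .lower8x8xf32());
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, options, /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));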