1 //===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements vector.transpose rewrites as AVX patterns for particular 10 // sizes of interest. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Vector/VectorOps.h" 15 #include "mlir/Dialect/X86Vector/Transforms.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 20 using namespace mlir; 21 using namespace mlir::vector; 22 using namespace mlir::x86vector; 23 using namespace mlir::x86vector::avx2; 24 25 Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, 26 Value v2) { 27 return b.create<vector::ShuffleOp>( 28 v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13}); 29 } 30 31 Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, 32 Value v2) { 33 return b.create<vector::ShuffleOp>( 34 v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15}); 35 } 36 /// a a b b a a b b 37 /// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4): 38 /// 0:127 | 128:255 39 /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4 40 Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, 41 Value v2, int8_t mask) { 42 int8_t b01, b23, b45, b67; 43 MaskHelper::extractShuffle(mask, b01, b23, b45, b67); 44 SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8, 45 b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4}; 46 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 47 } 48 49 // imm[0:1] out of imm[0:3] is: 50 // 0 1 2 3 51 // a[0:127] or a[128:255] or b[0:127] or b[128:255] | 52 // a[0:127] or a[128:255] or b[0:127] or b[128:255] 53 // 0 1 2 3 54 // imm[0:1] out of imm[4:7]. 55 Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b, 56 Value v1, Value v2, 57 int8_t mask) { 58 SmallVector<int64_t> shuffleMask; 59 auto appendToMask = [&](int8_t control) { 60 if (control == 0) 61 llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3}); 62 else if (control == 1) 63 llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7}); 64 else if (control == 2) 65 llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11}); 66 else if (control == 3) 67 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15}); 68 else 69 llvm_unreachable("control > 3 : overflow"); 70 }; 71 int8_t b03, b47; 72 MaskHelper::extractPermute(mask, b03, b47); 73 appendToMask(b03); 74 appendToMask(b47); 75 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask); 76 } 77 78 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. 79 void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib, 80 MutableArrayRef<Value> vs) { 81 #ifndef NDEBUG 82 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); 83 assert(vs.size() == 4 && "expects 4 vectors"); 84 assert(llvm::all_of(ValueRange{vs}.getTypes(), 85 [&](Type t) { return t == vt; }) && 86 "expects all types to be vector<8xf32>"); 87 #endif 88 89 Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); 90 Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); 91 Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); 92 Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); 93 Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>()); 94 Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>()); 95 Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>()); 96 Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>()); 97 vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>()); 98 vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>()); 99 vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>()); 100 vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>()); 101 } 102 103 /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model. 104 void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib, 105 MutableArrayRef<Value> vs) { 106 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); 107 (void)vt; 108 assert(vs.size() == 8 && "expects 8 vectors"); 109 assert(llvm::all_of(ValueRange{vs}.getTypes(), 110 [&](Type t) { return t == vt; }) && 111 "expects all types to be vector<8xf32>"); 112 113 Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); 114 Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); 115 Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); 116 Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); 117 Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]); 118 Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]); 119 Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]); 120 Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]); 121 Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>()); 122 Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>()); 123 Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>()); 124 Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>()); 125 Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>()); 126 Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>()); 127 Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>()); 128 Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>()); 129 vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>()); 130 vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>()); 131 vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>()); 132 vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>()); 133 vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>()); 134 vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>()); 135 vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>()); 136 vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>()); 137 } 138 139 /// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and 140 /// depending on the `TransposeLoweringOptions`. 141 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { 142 public: 143 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 144 145 TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, 146 int benefit) 147 : OpRewritePattern<vector::TransposeOp>(context, benefit), 148 loweringOptions(loweringOptions) {} 149 150 LogicalResult matchAndRewrite(vector::TransposeOp op, 151 PatternRewriter &rewriter) const override { 152 auto loc = op.getLoc(); 153 154 VectorType srcType = op.getVectorType(); 155 if (srcType.getRank() != 2) 156 return rewriter.notifyMatchFailure(op, "Not a 2-D transpose"); 157 158 SmallVector<int64_t, 4> transp; 159 for (auto attr : op.transp()) 160 transp.push_back(attr.cast<IntegerAttr>().getInt()); 161 if (transp[0] != 1 && transp[1] != 0) 162 return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation"); 163 164 int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); 165 166 auto applyRewrite = [&]() { 167 ImplicitLocOpBuilder ib(loc, rewriter); 168 SmallVector<Value> vs; 169 for (int64_t i = 0; i < m; ++i) 170 vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i)); 171 if (m == 4) 172 transpose4x8xf32(ib, vs); 173 if (m == 8) 174 transpose8x8xf32(ib, vs); 175 auto flattenedType = 176 VectorType::get({n * m}, op.getVectorType().getElementType()); 177 auto transposedType = 178 VectorType::get({n, m}, op.getVectorType().getElementType()); 179 Value res = ib.create<arith::ConstantOp>( 180 op.getVectorType(), ib.getZeroAttr(op.getVectorType())); 181 // The transposed form is still 4x8 and needs to be reinterpreted as 8x4 182 // via shape_casts. 183 for (int64_t i = 0; i < m; ++i) 184 res = ib.create<vector::InsertOp>(vs[i], res, i); 185 if (m == 4) { 186 res = ib.create<vector::ShapeCastOp>(flattenedType, res); 187 res = ib.create<vector::ShapeCastOp>(transposedType, res); 188 } 189 190 rewriter.replaceOp(op, res); 191 return success(); 192 }; 193 194 if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8) 195 return applyRewrite(); 196 if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8) 197 return applyRewrite(); 198 return failure(); 199 } 200 201 private: 202 LoweringOptions loweringOptions; 203 }; 204 205 void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 206 RewritePatternSet &patterns, LoweringOptions options, int benefit) { 207 patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit); 208 } 209