134ff8573SNicolas Vasilache //===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
234ff8573SNicolas Vasilache //
334ff8573SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
434ff8573SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
534ff8573SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
634ff8573SNicolas Vasilache //
734ff8573SNicolas Vasilache //===----------------------------------------------------------------------===//
834ff8573SNicolas Vasilache //
934ff8573SNicolas Vasilache // This file implements vector.transpose rewrites as AVX patterns for particular
1034ff8573SNicolas Vasilache // sizes of interest.
1134ff8573SNicolas Vasilache //
1234ff8573SNicolas Vasilache //===----------------------------------------------------------------------===//
1334ff8573SNicolas Vasilache
14*eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15b2729fdaSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1734ff8573SNicolas Vasilache #include "mlir/Dialect/X86Vector/Transforms.h"
1834ff8573SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
1934ff8573SNicolas Vasilache #include "mlir/IR/Matchers.h"
2034ff8573SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
21b2729fdaSNicolas Vasilache #include "llvm/Support/Format.h"
22b2729fdaSNicolas Vasilache #include "llvm/Support/FormatVariadic.h"
2334ff8573SNicolas Vasilache
2434ff8573SNicolas Vasilache using namespace mlir;
2534ff8573SNicolas Vasilache using namespace mlir::vector;
2634ff8573SNicolas Vasilache using namespace mlir::x86vector;
2734ff8573SNicolas Vasilache using namespace mlir::x86vector::avx2;
28b2729fdaSNicolas Vasilache using namespace mlir::x86vector::avx2::inline_asm;
29b2729fdaSNicolas Vasilache using namespace mlir::x86vector::avx2::intrin;
3034ff8573SNicolas Vasilache
mm256BlendPsAsm(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)31b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
32b2729fdaSNicolas Vasilache ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
33b2729fdaSNicolas Vasilache auto asmDialectAttr =
34b2729fdaSNicolas Vasilache LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
3502b6fb21SMehdi Amini const auto *asmTp = "vblendps $0, $1, $2, {0}";
3602b6fb21SMehdi Amini const auto *asmCstr =
3702b6fb21SMehdi Amini "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
38b2729fdaSNicolas Vasilache SmallVector<Value> asmVals{v1, v2};
39b2729fdaSNicolas Vasilache auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
40b2729fdaSNicolas Vasilache auto asmOp = b.create<LLVM::InlineAsmOp>(
4142398b51SNicolas Vasilache v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
4242398b51SNicolas Vasilache /*constraints=*/asmCstr, /*has_side_effects=*/false,
4342398b51SNicolas Vasilache /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
4442398b51SNicolas Vasilache /*operand_attrs=*/ArrayAttr());
45b2729fdaSNicolas Vasilache return asmOp.getResult(0);
46b2729fdaSNicolas Vasilache }
47b2729fdaSNicolas Vasilache
mm256UnpackLoPs(ImplicitLocOpBuilder & b,Value v1,Value v2)48b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
49b2729fdaSNicolas Vasilache Value v1, Value v2) {
5034ff8573SNicolas Vasilache return b.create<vector::ShuffleOp>(
5134ff8573SNicolas Vasilache v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
5234ff8573SNicolas Vasilache }
5334ff8573SNicolas Vasilache
mm256UnpackHiPs(ImplicitLocOpBuilder & b,Value v1,Value v2)54b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
55b2729fdaSNicolas Vasilache Value v1, Value v2) {
5634ff8573SNicolas Vasilache return b.create<vector::ShuffleOp>(
5734ff8573SNicolas Vasilache v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
5834ff8573SNicolas Vasilache }
5934ff8573SNicolas Vasilache /// a a b b a a b b
6034ff8573SNicolas Vasilache /// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
6134ff8573SNicolas Vasilache /// 0:127 | 128:255
6234ff8573SNicolas Vasilache /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
mm256ShufflePs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)63b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
64b2729fdaSNicolas Vasilache Value v1, Value v2,
65b2729fdaSNicolas Vasilache uint8_t mask) {
66b2729fdaSNicolas Vasilache uint8_t b01, b23, b45, b67;
6734ff8573SNicolas Vasilache MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
6834ff8573SNicolas Vasilache SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8,
6934ff8573SNicolas Vasilache b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
7034ff8573SNicolas Vasilache return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
7134ff8573SNicolas Vasilache }
7234ff8573SNicolas Vasilache
7334ff8573SNicolas Vasilache // imm[0:1] out of imm[0:3] is:
7434ff8573SNicolas Vasilache // 0 1 2 3
7534ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255] |
7634ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255]
7734ff8573SNicolas Vasilache // 0 1 2 3
7834ff8573SNicolas Vasilache // imm[0:1] out of imm[4:7].
mm256Permute2f128Ps(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)79b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
80b2729fdaSNicolas Vasilache ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
8134ff8573SNicolas Vasilache SmallVector<int64_t> shuffleMask;
82b2729fdaSNicolas Vasilache auto appendToMask = [&](uint8_t control) {
8334ff8573SNicolas Vasilache if (control == 0)
8434ff8573SNicolas Vasilache llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
8534ff8573SNicolas Vasilache else if (control == 1)
8634ff8573SNicolas Vasilache llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
8734ff8573SNicolas Vasilache else if (control == 2)
8834ff8573SNicolas Vasilache llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
8934ff8573SNicolas Vasilache else if (control == 3)
9034ff8573SNicolas Vasilache llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
9134ff8573SNicolas Vasilache else
9234ff8573SNicolas Vasilache llvm_unreachable("control > 3 : overflow");
9334ff8573SNicolas Vasilache };
94b2729fdaSNicolas Vasilache uint8_t b03, b47;
9534ff8573SNicolas Vasilache MaskHelper::extractPermute(mask, b03, b47);
9634ff8573SNicolas Vasilache appendToMask(b03);
9734ff8573SNicolas Vasilache appendToMask(b47);
9834ff8573SNicolas Vasilache return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
9934ff8573SNicolas Vasilache }
10034ff8573SNicolas Vasilache
101b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
mm256BlendPs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)102b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
103b2729fdaSNicolas Vasilache Value v1, Value v2,
104b2729fdaSNicolas Vasilache uint8_t mask) {
105b2729fdaSNicolas Vasilache SmallVector<int64_t, 8> shuffleMask;
106b2729fdaSNicolas Vasilache for (int i = 0; i < 8; ++i) {
107b2729fdaSNicolas Vasilache bool isSet = mask & (1 << i);
108b2729fdaSNicolas Vasilache shuffleMask.push_back(!isSet ? i : i + 8);
109b2729fdaSNicolas Vasilache }
110b2729fdaSNicolas Vasilache return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
111b2729fdaSNicolas Vasilache }
112b2729fdaSNicolas Vasilache
11334ff8573SNicolas Vasilache /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose4x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)11434ff8573SNicolas Vasilache void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
11534ff8573SNicolas Vasilache MutableArrayRef<Value> vs) {
11634ff8573SNicolas Vasilache #ifndef NDEBUG
117f04a1237SBenjamin Kramer auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
11834ff8573SNicolas Vasilache assert(vs.size() == 4 && "expects 4 vectors");
11934ff8573SNicolas Vasilache assert(llvm::all_of(ValueRange{vs}.getTypes(),
12034ff8573SNicolas Vasilache [&](Type t) { return t == vt; }) &&
12134ff8573SNicolas Vasilache "expects all types to be vector<8xf32>");
12234ff8573SNicolas Vasilache #endif
12334ff8573SNicolas Vasilache
12402b6fb21SMehdi Amini Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
12502b6fb21SMehdi Amini Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
12602b6fb21SMehdi Amini Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
12702b6fb21SMehdi Amini Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
12802b6fb21SMehdi Amini Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
12902b6fb21SMehdi Amini Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
13002b6fb21SMehdi Amini Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
13102b6fb21SMehdi Amini Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
13202b6fb21SMehdi Amini vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
13302b6fb21SMehdi Amini vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
13402b6fb21SMehdi Amini vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
13502b6fb21SMehdi Amini vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
13634ff8573SNicolas Vasilache }
13734ff8573SNicolas Vasilache
13834ff8573SNicolas Vasilache /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose8x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)13934ff8573SNicolas Vasilache void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
14034ff8573SNicolas Vasilache MutableArrayRef<Value> vs) {
14134ff8573SNicolas Vasilache auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
14234ff8573SNicolas Vasilache (void)vt;
14334ff8573SNicolas Vasilache assert(vs.size() == 8 && "expects 8 vectors");
14434ff8573SNicolas Vasilache assert(llvm::all_of(ValueRange{vs}.getTypes(),
14534ff8573SNicolas Vasilache [&](Type t) { return t == vt; }) &&
14634ff8573SNicolas Vasilache "expects all types to be vector<8xf32>");
14734ff8573SNicolas Vasilache
14802b6fb21SMehdi Amini Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
14902b6fb21SMehdi Amini Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
15002b6fb21SMehdi Amini Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
15102b6fb21SMehdi Amini Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
15202b6fb21SMehdi Amini Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
15302b6fb21SMehdi Amini Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
15402b6fb21SMehdi Amini Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
15502b6fb21SMehdi Amini Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
156b2729fdaSNicolas Vasilache
157b2729fdaSNicolas Vasilache using inline_asm::mm256BlendPsAsm;
15802b6fb21SMehdi Amini Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
15902b6fb21SMehdi Amini Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
16002b6fb21SMehdi Amini Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
16102b6fb21SMehdi Amini Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
162b2729fdaSNicolas Vasilache
16302b6fb21SMehdi Amini Value s0 =
16402b6fb21SMehdi Amini mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
16502b6fb21SMehdi Amini Value s1 =
16602b6fb21SMehdi Amini mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
16702b6fb21SMehdi Amini Value s2 =
16802b6fb21SMehdi Amini mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
16902b6fb21SMehdi Amini Value s3 =
17002b6fb21SMehdi Amini mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
17102b6fb21SMehdi Amini Value s4 =
17202b6fb21SMehdi Amini mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
17302b6fb21SMehdi Amini Value s5 =
17402b6fb21SMehdi Amini mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
17502b6fb21SMehdi Amini Value s6 =
17602b6fb21SMehdi Amini mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
17702b6fb21SMehdi Amini Value s7 =
17802b6fb21SMehdi Amini mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
179b2729fdaSNicolas Vasilache
18002b6fb21SMehdi Amini vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
18102b6fb21SMehdi Amini vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
18202b6fb21SMehdi Amini vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
18302b6fb21SMehdi Amini vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
18402b6fb21SMehdi Amini vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
18502b6fb21SMehdi Amini vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
18602b6fb21SMehdi Amini vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
18702b6fb21SMehdi Amini vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
18834ff8573SNicolas Vasilache }
18934ff8573SNicolas Vasilache
190d7e0a084SDiego Caballero /// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
191d7e0a084SDiego Caballero /// should be transposed with each other within the context of their 2D
192d7e0a084SDiego Caballero /// transposition slice.
193d7e0a084SDiego Caballero ///
194d7e0a084SDiego Caballero /// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
195d7e0a084SDiego Caballero /// Return true: dim0 and dim1 are transposed within the context of their 2D
196d7e0a084SDiego Caballero /// transposition slice ([1, 0]).
197d7e0a084SDiego Caballero ///
198d7e0a084SDiego Caballero /// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
199d7e0a084SDiego Caballero /// Return true: dim0 and dim1 are transposed within the context of their 2D
200d7e0a084SDiego Caballero /// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
201d7e0a084SDiego Caballero /// transposed within the full context of the transposition.
202d7e0a084SDiego Caballero ///
203d7e0a084SDiego Caballero /// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
204d7e0a084SDiego Caballero /// Return false: dim0 and dim1 are *not* transposed within the context of
205d7e0a084SDiego Caballero /// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
206d7e0a084SDiego Caballero /// and dim1 (1) are transposed within the full context of the of the
207d7e0a084SDiego Caballero /// transposition.
areDimsTransposedIn2DSlice(int64_t dim0,int64_t dim1,ArrayRef<int64_t> transp)208d7e0a084SDiego Caballero static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
209d7e0a084SDiego Caballero ArrayRef<int64_t> transp) {
210d7e0a084SDiego Caballero // Perform a linear scan along the dimensions of the transposed pattern. If
211d7e0a084SDiego Caballero // dim0 is found first, dim0 and dim1 are not transposed within the context of
212d7e0a084SDiego Caballero // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
213d7e0a084SDiego Caballero for (int64_t permDim : transp) {
214d7e0a084SDiego Caballero if (permDim == dim0)
215d7e0a084SDiego Caballero return false;
216d7e0a084SDiego Caballero if (permDim == dim1)
217d7e0a084SDiego Caballero return true;
218d7e0a084SDiego Caballero }
219d7e0a084SDiego Caballero
220d7e0a084SDiego Caballero llvm_unreachable("Ill-formed transpose pattern");
221d7e0a084SDiego Caballero }
222d7e0a084SDiego Caballero
223d7e0a084SDiego Caballero /// Rewrite AVX2-specific vector.transpose, for the supported cases and
224d7e0a084SDiego Caballero /// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
225d7e0a084SDiego Caballero /// transpose cases and n-D cases that have been decomposed into 2-D
226d7e0a084SDiego Caballero /// transposition slices. For example, a 3-D transpose:
227d7e0a084SDiego Caballero ///
228d7e0a084SDiego Caballero /// %0 = vector.transpose %arg0, [2, 0, 1]
229d7e0a084SDiego Caballero /// : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
230d7e0a084SDiego Caballero ///
231d7e0a084SDiego Caballero /// could be sliced into 2-D transposes by tiling two of its dimensions to one
232d7e0a084SDiego Caballero /// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
233d7e0a084SDiego Caballero ///
234d7e0a084SDiego Caballero /// %0 = vector.transpose %arg0, [2, 0, 1]
235d7e0a084SDiego Caballero /// : vector<1x4x8xf32> to vector<8x1x4xf32>
236d7e0a084SDiego Caballero ///
237d7e0a084SDiego Caballero /// This lowering will analyze the n-D vector.transpose and determine if it's a
238d7e0a084SDiego Caballero /// supported 2-D transposition slice where any of the AVX2 patterns can be
239d7e0a084SDiego Caballero /// applied.
24034ff8573SNicolas Vasilache class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
24134ff8573SNicolas Vasilache public:
24234ff8573SNicolas Vasilache using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
24334ff8573SNicolas Vasilache
TransposeOpLowering(LoweringOptions loweringOptions,MLIRContext * context,int benefit)24434ff8573SNicolas Vasilache TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
24534ff8573SNicolas Vasilache int benefit)
24634ff8573SNicolas Vasilache : OpRewritePattern<vector::TransposeOp>(context, benefit),
24734ff8573SNicolas Vasilache loweringOptions(loweringOptions) {}
24834ff8573SNicolas Vasilache
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const24934ff8573SNicolas Vasilache LogicalResult matchAndRewrite(vector::TransposeOp op,
25034ff8573SNicolas Vasilache PatternRewriter &rewriter) const override {
25134ff8573SNicolas Vasilache auto loc = op.getLoc();
25234ff8573SNicolas Vasilache
253d7e0a084SDiego Caballero // Check if the source vector type is supported. AVX2 patterns can only be
254875bbce9SDiego Caballero // applied to f32 vector types with two dimensions greater than one.
25534ff8573SNicolas Vasilache VectorType srcType = op.getVectorType();
256875bbce9SDiego Caballero if (!srcType.getElementType().isF32())
257875bbce9SDiego Caballero return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
258875bbce9SDiego Caballero
259d7e0a084SDiego Caballero SmallVector<int64_t> srcGtOneDims;
260d7e0a084SDiego Caballero for (auto &en : llvm::enumerate(srcType.getShape()))
261d7e0a084SDiego Caballero if (en.value() > 1)
262d7e0a084SDiego Caballero srcGtOneDims.push_back(en.index());
263d7e0a084SDiego Caballero
264d7e0a084SDiego Caballero if (srcGtOneDims.size() != 2)
265d7e0a084SDiego Caballero return rewriter.notifyMatchFailure(op, "Unsupported vector type");
26634ff8573SNicolas Vasilache
26734ff8573SNicolas Vasilache SmallVector<int64_t, 4> transp;
2687c38fd60SJacques Pienaar for (auto attr : op.getTransp())
26934ff8573SNicolas Vasilache transp.push_back(attr.cast<IntegerAttr>().getInt());
27034ff8573SNicolas Vasilache
271d7e0a084SDiego Caballero // Check whether the two source vector dimensions that are greater than one
272d7e0a084SDiego Caballero // must be transposed with each other so that we can apply one of the 2-D
273d7e0a084SDiego Caballero // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
274d7e0a084SDiego Caballero if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
275d7e0a084SDiego Caballero return rewriter.notifyMatchFailure(
276d7e0a084SDiego Caballero op, "Not applicable to this transpose permutation");
277d7e0a084SDiego Caballero
278d7e0a084SDiego Caballero // Retrieve the sizes of the two dimensions greater than one to be
279d7e0a084SDiego Caballero // transposed.
280d7e0a084SDiego Caballero auto srcShape = srcType.getShape();
281d7e0a084SDiego Caballero int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
28234ff8573SNicolas Vasilache
28334ff8573SNicolas Vasilache auto applyRewrite = [&]() {
28434ff8573SNicolas Vasilache ImplicitLocOpBuilder ib(loc, rewriter);
28534ff8573SNicolas Vasilache SmallVector<Value> vs;
286d7e0a084SDiego Caballero
287d7e0a084SDiego Caballero // Reshape the n-D input vector with only two dimensions greater than one
288d7e0a084SDiego Caballero // to a 2-D vector.
289d7e0a084SDiego Caballero auto flattenedType =
290d7e0a084SDiego Caballero VectorType::get({n * m}, op.getVectorType().getElementType());
291d7e0a084SDiego Caballero auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
292d7e0a084SDiego Caballero auto reshInput =
2937c38fd60SJacques Pienaar ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
294d7e0a084SDiego Caballero reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);
295d7e0a084SDiego Caballero
296d7e0a084SDiego Caballero // Extract 1-D vectors from the higher-order dimension of the input
297d7e0a084SDiego Caballero // vector.
29834ff8573SNicolas Vasilache for (int64_t i = 0; i < m; ++i)
299d7e0a084SDiego Caballero vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));
300d7e0a084SDiego Caballero
301d7e0a084SDiego Caballero // Transpose set of 1-D vectors.
30234ff8573SNicolas Vasilache if (m == 4)
30334ff8573SNicolas Vasilache transpose4x8xf32(ib, vs);
30434ff8573SNicolas Vasilache if (m == 8)
30534ff8573SNicolas Vasilache transpose8x8xf32(ib, vs);
306d7e0a084SDiego Caballero
307d7e0a084SDiego Caballero // Insert transposed 1-D vectors into the higher-order dimension of the
308d7e0a084SDiego Caballero // output vector.
309d7e0a084SDiego Caballero Value res = ib.create<arith::ConstantOp>(reshInputType,
310d7e0a084SDiego Caballero ib.getZeroAttr(reshInputType));
31134ff8573SNicolas Vasilache for (int64_t i = 0; i < m; ++i)
31234ff8573SNicolas Vasilache res = ib.create<vector::InsertOp>(vs[i], res, i);
31334ff8573SNicolas Vasilache
314d7e0a084SDiego Caballero // The output vector still has the shape of the input vector (e.g., 4x8).
315d7e0a084SDiego Caballero // We have to transpose their dimensions and retrieve its original rank
316d7e0a084SDiego Caballero // (e.g., 1x8x1x4x1).
317d7e0a084SDiego Caballero res = ib.create<vector::ShapeCastOp>(flattenedType, res);
318d7e0a084SDiego Caballero res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
31934ff8573SNicolas Vasilache rewriter.replaceOp(op, res);
32034ff8573SNicolas Vasilache return success();
32134ff8573SNicolas Vasilache };
32234ff8573SNicolas Vasilache
32334ff8573SNicolas Vasilache if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
32434ff8573SNicolas Vasilache return applyRewrite();
32534ff8573SNicolas Vasilache if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
32634ff8573SNicolas Vasilache return applyRewrite();
32734ff8573SNicolas Vasilache return failure();
32834ff8573SNicolas Vasilache }
32934ff8573SNicolas Vasilache
33034ff8573SNicolas Vasilache private:
33134ff8573SNicolas Vasilache LoweringOptions loweringOptions;
33234ff8573SNicolas Vasilache };
33334ff8573SNicolas Vasilache
populateSpecializedTransposeLoweringPatterns(RewritePatternSet & patterns,LoweringOptions options,int benefit)33434ff8573SNicolas Vasilache void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
33534ff8573SNicolas Vasilache RewritePatternSet &patterns, LoweringOptions options, int benefit) {
33634ff8573SNicolas Vasilache patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
33734ff8573SNicolas Vasilache }
338