//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

2434ff8573SNicolas Vasilache using namespace mlir;
2534ff8573SNicolas Vasilache using namespace mlir::vector;
2634ff8573SNicolas Vasilache using namespace mlir::x86vector;
2734ff8573SNicolas Vasilache using namespace mlir::x86vector::avx2;
28b2729fdaSNicolas Vasilache using namespace mlir::x86vector::avx2::inline_asm;
29b2729fdaSNicolas Vasilache using namespace mlir::x86vector::avx2::intrin;
3034ff8573SNicolas Vasilache 
mm256BlendPsAsm(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)31b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
32b2729fdaSNicolas Vasilache     ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
33b2729fdaSNicolas Vasilache   auto asmDialectAttr =
34b2729fdaSNicolas Vasilache       LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
3502b6fb21SMehdi Amini   const auto *asmTp = "vblendps $0, $1, $2, {0}";
3602b6fb21SMehdi Amini   const auto *asmCstr =
3702b6fb21SMehdi Amini       "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
38b2729fdaSNicolas Vasilache   SmallVector<Value> asmVals{v1, v2};
39b2729fdaSNicolas Vasilache   auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
40b2729fdaSNicolas Vasilache   auto asmOp = b.create<LLVM::InlineAsmOp>(
4142398b51SNicolas Vasilache       v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
4242398b51SNicolas Vasilache       /*constraints=*/asmCstr, /*has_side_effects=*/false,
4342398b51SNicolas Vasilache       /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
4442398b51SNicolas Vasilache       /*operand_attrs=*/ArrayAttr());
45b2729fdaSNicolas Vasilache   return asmOp.getResult(0);
46b2729fdaSNicolas Vasilache }
47b2729fdaSNicolas Vasilache 
mm256UnpackLoPs(ImplicitLocOpBuilder & b,Value v1,Value v2)48b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
49b2729fdaSNicolas Vasilache                                                      Value v1, Value v2) {
5034ff8573SNicolas Vasilache   return b.create<vector::ShuffleOp>(
5134ff8573SNicolas Vasilache       v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
5234ff8573SNicolas Vasilache }
5334ff8573SNicolas Vasilache 
mm256UnpackHiPs(ImplicitLocOpBuilder & b,Value v1,Value v2)54b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
55b2729fdaSNicolas Vasilache                                                      Value v1, Value v2) {
5634ff8573SNicolas Vasilache   return b.create<vector::ShuffleOp>(
5734ff8573SNicolas Vasilache       v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
5834ff8573SNicolas Vasilache }
5934ff8573SNicolas Vasilache ///                            a  a   b   b  a  a   b   b
6034ff8573SNicolas Vasilache /// Takes an 8 bit mask, 2 bit for each position of a[0, 3)  **and** b[0, 4):
6134ff8573SNicolas Vasilache ///                                 0:127    |         128:255
6234ff8573SNicolas Vasilache ///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
mm256ShufflePs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)63b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
64b2729fdaSNicolas Vasilache                                                     Value v1, Value v2,
65b2729fdaSNicolas Vasilache                                                     uint8_t mask) {
66b2729fdaSNicolas Vasilache   uint8_t b01, b23, b45, b67;
6734ff8573SNicolas Vasilache   MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
6834ff8573SNicolas Vasilache   SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
6934ff8573SNicolas Vasilache                                    b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
7034ff8573SNicolas Vasilache   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
7134ff8573SNicolas Vasilache }
7234ff8573SNicolas Vasilache 
7334ff8573SNicolas Vasilache // imm[0:1] out of imm[0:3] is:
7434ff8573SNicolas Vasilache //    0             1           2             3
7534ff8573SNicolas Vasilache // a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
7634ff8573SNicolas Vasilache //          a[0:127] or a[128:255] or b[0:127] or b[128:255]
7734ff8573SNicolas Vasilache //             0             1           2             3
7834ff8573SNicolas Vasilache // imm[0:1] out of imm[4:7].
mm256Permute2f128Ps(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)79b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
80b2729fdaSNicolas Vasilache     ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
8134ff8573SNicolas Vasilache   SmallVector<int64_t> shuffleMask;
82b2729fdaSNicolas Vasilache   auto appendToMask = [&](uint8_t control) {
8334ff8573SNicolas Vasilache     if (control == 0)
8434ff8573SNicolas Vasilache       llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
8534ff8573SNicolas Vasilache     else if (control == 1)
8634ff8573SNicolas Vasilache       llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
8734ff8573SNicolas Vasilache     else if (control == 2)
8834ff8573SNicolas Vasilache       llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
8934ff8573SNicolas Vasilache     else if (control == 3)
9034ff8573SNicolas Vasilache       llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
9134ff8573SNicolas Vasilache     else
9234ff8573SNicolas Vasilache       llvm_unreachable("control > 3 : overflow");
9334ff8573SNicolas Vasilache   };
94b2729fdaSNicolas Vasilache   uint8_t b03, b47;
9534ff8573SNicolas Vasilache   MaskHelper::extractPermute(mask, b03, b47);
9634ff8573SNicolas Vasilache   appendToMask(b03);
9734ff8573SNicolas Vasilache   appendToMask(b47);
9834ff8573SNicolas Vasilache   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
9934ff8573SNicolas Vasilache }
10034ff8573SNicolas Vasilache 
101b2729fdaSNicolas Vasilache /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
mm256BlendPs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)102b2729fdaSNicolas Vasilache Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
103b2729fdaSNicolas Vasilache                                                   Value v1, Value v2,
104b2729fdaSNicolas Vasilache                                                   uint8_t mask) {
105b2729fdaSNicolas Vasilache   SmallVector<int64_t, 8> shuffleMask;
106b2729fdaSNicolas Vasilache   for (int i = 0; i < 8; ++i) {
107b2729fdaSNicolas Vasilache     bool isSet = mask & (1 << i);
108b2729fdaSNicolas Vasilache     shuffleMask.push_back(!isSet ? i : i + 8);
109b2729fdaSNicolas Vasilache   }
110b2729fdaSNicolas Vasilache   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
111b2729fdaSNicolas Vasilache }
112b2729fdaSNicolas Vasilache 
11334ff8573SNicolas Vasilache /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose4x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)11434ff8573SNicolas Vasilache void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
11534ff8573SNicolas Vasilache                                              MutableArrayRef<Value> vs) {
11634ff8573SNicolas Vasilache #ifndef NDEBUG
117f04a1237SBenjamin Kramer   auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
11834ff8573SNicolas Vasilache   assert(vs.size() == 4 && "expects 4 vectors");
11934ff8573SNicolas Vasilache   assert(llvm::all_of(ValueRange{vs}.getTypes(),
12034ff8573SNicolas Vasilache                       [&](Type t) { return t == vt; }) &&
12134ff8573SNicolas Vasilache          "expects all types to be vector<8xf32>");
12234ff8573SNicolas Vasilache #endif
12334ff8573SNicolas Vasilache 
12402b6fb21SMehdi Amini   Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
12502b6fb21SMehdi Amini   Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
12602b6fb21SMehdi Amini   Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
12702b6fb21SMehdi Amini   Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
12802b6fb21SMehdi Amini   Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
12902b6fb21SMehdi Amini   Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
13002b6fb21SMehdi Amini   Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
13102b6fb21SMehdi Amini   Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
13202b6fb21SMehdi Amini   vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
13302b6fb21SMehdi Amini   vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
13402b6fb21SMehdi Amini   vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
13502b6fb21SMehdi Amini   vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
13634ff8573SNicolas Vasilache }
13734ff8573SNicolas Vasilache 
13834ff8573SNicolas Vasilache /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose8x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)13934ff8573SNicolas Vasilache void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
14034ff8573SNicolas Vasilache                                              MutableArrayRef<Value> vs) {
14134ff8573SNicolas Vasilache   auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
14234ff8573SNicolas Vasilache   (void)vt;
14334ff8573SNicolas Vasilache   assert(vs.size() == 8 && "expects 8 vectors");
14434ff8573SNicolas Vasilache   assert(llvm::all_of(ValueRange{vs}.getTypes(),
14534ff8573SNicolas Vasilache                       [&](Type t) { return t == vt; }) &&
14634ff8573SNicolas Vasilache          "expects all types to be vector<8xf32>");
14734ff8573SNicolas Vasilache 
14802b6fb21SMehdi Amini   Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
14902b6fb21SMehdi Amini   Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
15002b6fb21SMehdi Amini   Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
15102b6fb21SMehdi Amini   Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
15202b6fb21SMehdi Amini   Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
15302b6fb21SMehdi Amini   Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
15402b6fb21SMehdi Amini   Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
15502b6fb21SMehdi Amini   Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
156b2729fdaSNicolas Vasilache 
157b2729fdaSNicolas Vasilache   using inline_asm::mm256BlendPsAsm;
15802b6fb21SMehdi Amini   Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
15902b6fb21SMehdi Amini   Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
16002b6fb21SMehdi Amini   Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
16102b6fb21SMehdi Amini   Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
162b2729fdaSNicolas Vasilache 
16302b6fb21SMehdi Amini   Value s0 =
16402b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
16502b6fb21SMehdi Amini   Value s1 =
16602b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
16702b6fb21SMehdi Amini   Value s2 =
16802b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
16902b6fb21SMehdi Amini   Value s3 =
17002b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
17102b6fb21SMehdi Amini   Value s4 =
17202b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
17302b6fb21SMehdi Amini   Value s5 =
17402b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
17502b6fb21SMehdi Amini   Value s6 =
17602b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
17702b6fb21SMehdi Amini   Value s7 =
17802b6fb21SMehdi Amini       mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
179b2729fdaSNicolas Vasilache 
18002b6fb21SMehdi Amini   vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
18102b6fb21SMehdi Amini   vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
18202b6fb21SMehdi Amini   vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
18302b6fb21SMehdi Amini   vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
18402b6fb21SMehdi Amini   vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
18502b6fb21SMehdi Amini   vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
18602b6fb21SMehdi Amini   vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
18702b6fb21SMehdi Amini   vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
18834ff8573SNicolas Vasilache }
18934ff8573SNicolas Vasilache 
190d7e0a084SDiego Caballero /// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
191d7e0a084SDiego Caballero /// should be transposed with each other within the context of their 2D
192d7e0a084SDiego Caballero /// transposition slice.
193d7e0a084SDiego Caballero ///
194d7e0a084SDiego Caballero /// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
195d7e0a084SDiego Caballero ///   Return true: dim0 and dim1 are transposed within the context of their 2D
196d7e0a084SDiego Caballero ///   transposition slice ([1, 0]).
197d7e0a084SDiego Caballero ///
198d7e0a084SDiego Caballero /// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
199d7e0a084SDiego Caballero ///   Return true: dim0 and dim1 are transposed within the context of their 2D
200d7e0a084SDiego Caballero ///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
201d7e0a084SDiego Caballero ///   transposed within the full context of the transposition.
202d7e0a084SDiego Caballero ///
203d7e0a084SDiego Caballero /// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
204d7e0a084SDiego Caballero ///   Return false: dim0 and dim1 are *not* transposed within the context of
205d7e0a084SDiego Caballero ///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
206d7e0a084SDiego Caballero ///   and dim1 (1) are transposed within the full context of the of the
207d7e0a084SDiego Caballero ///   transposition.
areDimsTransposedIn2DSlice(int64_t dim0,int64_t dim1,ArrayRef<int64_t> transp)208d7e0a084SDiego Caballero static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
209d7e0a084SDiego Caballero                                        ArrayRef<int64_t> transp) {
210d7e0a084SDiego Caballero   // Perform a linear scan along the dimensions of the transposed pattern. If
211d7e0a084SDiego Caballero   // dim0 is found first, dim0 and dim1 are not transposed within the context of
212d7e0a084SDiego Caballero   // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
213d7e0a084SDiego Caballero   for (int64_t permDim : transp) {
214d7e0a084SDiego Caballero     if (permDim == dim0)
215d7e0a084SDiego Caballero       return false;
216d7e0a084SDiego Caballero     if (permDim == dim1)
217d7e0a084SDiego Caballero       return true;
218d7e0a084SDiego Caballero   }
219d7e0a084SDiego Caballero 
220d7e0a084SDiego Caballero   llvm_unreachable("Ill-formed transpose pattern");
221d7e0a084SDiego Caballero }
222d7e0a084SDiego Caballero 
223d7e0a084SDiego Caballero /// Rewrite AVX2-specific vector.transpose, for the supported cases and
224d7e0a084SDiego Caballero /// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
225d7e0a084SDiego Caballero /// transpose cases and n-D cases that have been decomposed into 2-D
226d7e0a084SDiego Caballero /// transposition slices. For example, a 3-D transpose:
227d7e0a084SDiego Caballero ///
228d7e0a084SDiego Caballero ///   %0 = vector.transpose %arg0, [2, 0, 1]
229d7e0a084SDiego Caballero ///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
230d7e0a084SDiego Caballero ///
231d7e0a084SDiego Caballero /// could be sliced into 2-D transposes by tiling two of its dimensions to one
232d7e0a084SDiego Caballero /// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
233d7e0a084SDiego Caballero ///
234d7e0a084SDiego Caballero ///   %0 = vector.transpose %arg0, [2, 0, 1]
235d7e0a084SDiego Caballero ///      : vector<1x4x8xf32> to vector<8x1x4xf32>
236d7e0a084SDiego Caballero ///
237d7e0a084SDiego Caballero /// This lowering will analyze the n-D vector.transpose and determine if it's a
238d7e0a084SDiego Caballero /// supported 2-D transposition slice where any of the AVX2 patterns can be
239d7e0a084SDiego Caballero /// applied.
24034ff8573SNicolas Vasilache class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
24134ff8573SNicolas Vasilache public:
24234ff8573SNicolas Vasilache   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
24334ff8573SNicolas Vasilache 
TransposeOpLowering(LoweringOptions loweringOptions,MLIRContext * context,int benefit)24434ff8573SNicolas Vasilache   TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
24534ff8573SNicolas Vasilache                       int benefit)
24634ff8573SNicolas Vasilache       : OpRewritePattern<vector::TransposeOp>(context, benefit),
24734ff8573SNicolas Vasilache         loweringOptions(loweringOptions) {}
24834ff8573SNicolas Vasilache 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const24934ff8573SNicolas Vasilache   LogicalResult matchAndRewrite(vector::TransposeOp op,
25034ff8573SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
25134ff8573SNicolas Vasilache     auto loc = op.getLoc();
25234ff8573SNicolas Vasilache 
253d7e0a084SDiego Caballero     // Check if the source vector type is supported. AVX2 patterns can only be
254875bbce9SDiego Caballero     // applied to f32 vector types with two dimensions greater than one.
25534ff8573SNicolas Vasilache     VectorType srcType = op.getVectorType();
256875bbce9SDiego Caballero     if (!srcType.getElementType().isF32())
257875bbce9SDiego Caballero       return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
258875bbce9SDiego Caballero 
259d7e0a084SDiego Caballero     SmallVector<int64_t> srcGtOneDims;
260d7e0a084SDiego Caballero     for (auto &en : llvm::enumerate(srcType.getShape()))
261d7e0a084SDiego Caballero       if (en.value() > 1)
262d7e0a084SDiego Caballero         srcGtOneDims.push_back(en.index());
263d7e0a084SDiego Caballero 
264d7e0a084SDiego Caballero     if (srcGtOneDims.size() != 2)
265d7e0a084SDiego Caballero       return rewriter.notifyMatchFailure(op, "Unsupported vector type");
26634ff8573SNicolas Vasilache 
26734ff8573SNicolas Vasilache     SmallVector<int64_t, 4> transp;
2687c38fd60SJacques Pienaar     for (auto attr : op.getTransp())
26934ff8573SNicolas Vasilache       transp.push_back(attr.cast<IntegerAttr>().getInt());
27034ff8573SNicolas Vasilache 
271d7e0a084SDiego Caballero     // Check whether the two source vector dimensions that are greater than one
272d7e0a084SDiego Caballero     // must be transposed with each other so that we can apply one of the 2-D
273d7e0a084SDiego Caballero     // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
274d7e0a084SDiego Caballero     if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
275d7e0a084SDiego Caballero       return rewriter.notifyMatchFailure(
276d7e0a084SDiego Caballero           op, "Not applicable to this transpose permutation");
277d7e0a084SDiego Caballero 
278d7e0a084SDiego Caballero     // Retrieve the sizes of the two dimensions greater than one to be
279d7e0a084SDiego Caballero     // transposed.
280d7e0a084SDiego Caballero     auto srcShape = srcType.getShape();
281d7e0a084SDiego Caballero     int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
28234ff8573SNicolas Vasilache 
28334ff8573SNicolas Vasilache     auto applyRewrite = [&]() {
28434ff8573SNicolas Vasilache       ImplicitLocOpBuilder ib(loc, rewriter);
28534ff8573SNicolas Vasilache       SmallVector<Value> vs;
286d7e0a084SDiego Caballero 
287d7e0a084SDiego Caballero       // Reshape the n-D input vector with only two dimensions greater than one
288d7e0a084SDiego Caballero       // to a 2-D vector.
289d7e0a084SDiego Caballero       auto flattenedType =
290d7e0a084SDiego Caballero           VectorType::get({n * m}, op.getVectorType().getElementType());
291d7e0a084SDiego Caballero       auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
292d7e0a084SDiego Caballero       auto reshInput =
2937c38fd60SJacques Pienaar           ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
294d7e0a084SDiego Caballero       reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);
295d7e0a084SDiego Caballero 
296d7e0a084SDiego Caballero       // Extract 1-D vectors from the higher-order dimension of the input
297d7e0a084SDiego Caballero       // vector.
29834ff8573SNicolas Vasilache       for (int64_t i = 0; i < m; ++i)
299d7e0a084SDiego Caballero         vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));
300d7e0a084SDiego Caballero 
301d7e0a084SDiego Caballero       // Transpose set of 1-D vectors.
30234ff8573SNicolas Vasilache       if (m == 4)
30334ff8573SNicolas Vasilache         transpose4x8xf32(ib, vs);
30434ff8573SNicolas Vasilache       if (m == 8)
30534ff8573SNicolas Vasilache         transpose8x8xf32(ib, vs);
306d7e0a084SDiego Caballero 
307d7e0a084SDiego Caballero       // Insert transposed 1-D vectors into the higher-order dimension of the
308d7e0a084SDiego Caballero       // output vector.
309d7e0a084SDiego Caballero       Value res = ib.create<arith::ConstantOp>(reshInputType,
310d7e0a084SDiego Caballero                                                ib.getZeroAttr(reshInputType));
31134ff8573SNicolas Vasilache       for (int64_t i = 0; i < m; ++i)
31234ff8573SNicolas Vasilache         res = ib.create<vector::InsertOp>(vs[i], res, i);
31334ff8573SNicolas Vasilache 
314d7e0a084SDiego Caballero       // The output vector still has the shape of the input vector (e.g., 4x8).
315d7e0a084SDiego Caballero       // We have to transpose their dimensions and retrieve its original rank
316d7e0a084SDiego Caballero       // (e.g., 1x8x1x4x1).
317d7e0a084SDiego Caballero       res = ib.create<vector::ShapeCastOp>(flattenedType, res);
318d7e0a084SDiego Caballero       res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
31934ff8573SNicolas Vasilache       rewriter.replaceOp(op, res);
32034ff8573SNicolas Vasilache       return success();
32134ff8573SNicolas Vasilache     };
32234ff8573SNicolas Vasilache 
32334ff8573SNicolas Vasilache     if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
32434ff8573SNicolas Vasilache       return applyRewrite();
32534ff8573SNicolas Vasilache     if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
32634ff8573SNicolas Vasilache       return applyRewrite();
32734ff8573SNicolas Vasilache     return failure();
32834ff8573SNicolas Vasilache   }
32934ff8573SNicolas Vasilache 
33034ff8573SNicolas Vasilache private:
33134ff8573SNicolas Vasilache   LoweringOptions loweringOptions;
33234ff8573SNicolas Vasilache };
33334ff8573SNicolas Vasilache 
populateSpecializedTransposeLoweringPatterns(RewritePatternSet & patterns,LoweringOptions options,int benefit)33434ff8573SNicolas Vasilache void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
33534ff8573SNicolas Vasilache     RewritePatternSet &patterns, LoweringOptions options, int benefit) {
33634ff8573SNicolas Vasilache   patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
33734ff8573SNicolas Vasilache }
338