//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;

Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
                                             Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
                                             Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
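
// Illustrative example (not part of the lowering): with
//   v1 = <a0 a1 a2 a3 a4 a5 a6 a7> and v2 = <b0 b1 b2 b3 b4 b5 b6 b7>,
// mm256UnpackLoPs produces <a0 b0 a1 b1 a4 b4 a5 b5> and mm256UnpackHiPs
// produces <a2 b2 a3 b3 a6 b6 a7 b7>, i.e. the per-128-bit-lane interleave
// performed by the _mm256_unpacklo_ps / _mm256_unpackhi_ps intrinsics.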

/// Takes an 8-bit `mask` made of four 2-bit fields b01, b23, b45, b67 (low
/// bits first). In each 128-bit half of the result, positions 0 and 1 select
/// an element of `v1` and positions 2 and 3 select an element of `v2`:
///                     0:127         |           128:255
///          b01 b23 b45+8 b67+8      |   b01+4 b23+4 b45+8+4 b67+8+4
/// This mirrors the semantics of `_mm256_shuffle_ps`.
Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
                                            Value v2, int8_t mask) {
  int8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
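
// For example, MaskHelper::shuffle<1, 0, 1, 0>() encodes the immediate 0x44
// (i.e. _MM_SHUFFLE(1, 0, 1, 0)): in each 128-bit half the result takes
// elements 0 and 1 of `v1` followed by elements 0 and 1 of `v2`, matching
// _mm256_shuffle_ps(v1, v2, 0x44).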

// Each 128-bit half of the result is chosen independently by a 2-bit control
// field of `mask`: mask[0:2) selects the source of the result's low 128 bits
// and mask[4:6) selects the source of its high 128 bits, where
//   0 -> v1[0:128), 1 -> v1[128:256), 2 -> v2[0:128), 3 -> v2[128:256).
// This mirrors the semantics of `_mm256_permute2f128_ps`.
Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
                                                 Value v1, Value v2,
                                                 int8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](int8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  int8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
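
// For example, MaskHelper::permute<2, 0>() (immediate 0x20) concatenates the
// low 128 bits of `v1` and the low 128 bits of `v2`, while
// MaskHelper::permute<3, 1>() (immediate 0x31) concatenates their high
// halves, matching _mm256_permute2f128_ps with immediates 0x20 and 0x31.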

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
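/// On return, `vs[i]` holds rows 2*i and 2*i+1 of the transposed 8x4 matrix;
/// i.e. the 8x4 result comes back flattened row-major across four
/// vector<8xf32> values, which is why the pattern below reshapes it with
/// shape_casts.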
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
}
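
// For reference, the sequence above follows the classical AVX 8x8 f32
// transpose. With C intrinsics (a sketch, assuming <immintrin.h> and rows
// r0..r7 held in __m256 values), each of the three stages reads:
//
//   __m256 t0 = _mm256_unpacklo_ps(r0, r1);                         // t0..t7
//   __m256 s0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0)); // s0..s7
//   r0 = _mm256_permute2f128_ps(s0, s4, 0x20);                      // r0..r7
//
// where immediate 0x20 concatenates the two low 128-bit halves and 0x31 the
// two high ones, exactly as the MaskHelper::permute<2, 0>() and
// MaskHelper::permute<3, 1>() masks above.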

/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`.
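///
/// For example, with `lower4x8xf32_` enabled, this pattern rewrites
///   %1 = vector.transpose %0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
/// into vector.extract of the rows, the AVX2 shuffle sequence above,
/// vector.insert of the results, and trailing shape_casts to vector<8x4xf32>.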
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // In the 4x8 case, the transposed data still sits in a 4x8 vector and
      // needs to be reinterpreted as 8x4 via shape_casts.
      if (m == 4) {
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
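
// Example client usage (a minimal sketch; `funcOp` and the greedy rewrite
// driver invocation are stand-ins for the surrounding pass):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   x86vector::avx2::TransposeLoweringOptions transposeOptions;
//   transposeOptions.lower4x8xf32_ = transposeOptions.lower8x8xf32_ = true;
//   x86vector::avx2::LoweringOptions options;
//   options.transposeOptions = transposeOptions;
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, options, /*benefit=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));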