//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;
/// Emit a `vblendps` through `llvm.inline_asm` (Intel syntax), formatting the
/// immediate `mask` directly into the asm template string.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  auto asmTp = "vblendps $0, $1, $2, {0}";
  auto asmCstr = "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), asmVals, asmStr, asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, asmDialectAttr);
  return asmOp.getResult(0);
}
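
// For illustration: with mask = 0xcc, the `formatv` call above produces the
// Intel-syntax string "vblendps $0, $1, $2, 0xcc", where $0 binds the "=x"
// result register and $1/$2 bind v1/v2 per the constraint string.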

/// Interleave the low 64 bits of each 128-bit lane of v1 and v2, as
/// `vunpcklps` does: per lane, {v1[0], v2[0], v1[1], v2[1]}.
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

/// Interleave the high 64 bits of each 128-bit lane of v1 and v2, as
/// `vunpckhps` does: per lane, {v1[2], v2[2], v1[3], v2[3]}.
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
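
// For illustration (read off the shuffle masks above): with a = a0..a7 and
// b = b0..b7,
//   mm256UnpackLoPs(a, b) == a0 b0 a1 b1 | a4 b4 a5 b5
//   mm256UnpackHiPs(a, b) == a2 b2 a3 b3 | a6 b6 a7 b7
// i.e. each 128-bit half is interleaved independently.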

/// Emulates the AVX2 `vshufps` on two vector<8xf32>. The 8-bit `mask` holds a
/// 2-bit index for each result position; within each 128-bit half, the first
/// two positions select from v1 and the last two from v2:
///   result[0:127]   = {v1[b01], v1[b23], v2[b45], v2[b67]}
///   result[128:255] = the same selection applied to the high halves (+4).
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
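
// Worked example (assuming MaskHelper's packing in Transforms.h, where b01
// fills the lowest two bits): MaskHelper::shuffle<1, 0, 1, 0>() yields mask
// 0x44, which expands to the vector.shuffle mask {0, 1, 8, 9, 4, 5, 12, 13},
// i.e. per half {v1[0], v1[1], v2[0], v2[1]}, matching
// _mm256_shuffle_ps(v1, v2, 0x44).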

/// Emulates the AVX2 `vperm2f128` on two vector<8xf32>. Each 128-bit half of
/// the result is chosen by a 2-bit control:
///   0 -> v1[0:127],  1 -> v1[128:255],  2 -> v2[0:127],  3 -> v2[128:255]
/// imm[0:1] selects the low half of the result, imm[4:5] the high half.
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
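
// Worked example: MaskHelper::permute<3, 1>() yields mask 0x31 (b03 = 1,
// b47 = 3), so the shuffle mask becomes {4, 5, 6, 7, 12, 13, 14, 15}: the
// high half of v1 followed by the high half of v2, as
// _mm256_permute2f128_ps(v1, v2, 0x31) would produce.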
95 
96 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
97 Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
98                                                   Value v1, Value v2,
99                                                   uint8_t mask) {
100   SmallVector<int64_t, 8> shuffleMask;
101   for (int i = 0; i < 8; ++i) {
102     bool isSet = mask & (1 << i);
103     shuffleMask.push_back(!isSet ? i : i + 8);
104   }
105   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
106 }
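
// Worked example: MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>() yields mask
// 0xcc, so the loop above builds {0, 1, 10, 11, 4, 5, 14, 15}: elements 0-1
// and 4-5 from v1, elements 2-3 and 6-7 from v2.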

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

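  // Sketch of the data movement below, with input rows A = a0..a7, B, C, D:
  //   unpack:  T0 = a0 b0 a1 b1 | a4 b4 a5 b5   (and similarly T1..T3)
  //   shuffle: S0 = a0 b0 c0 d0 | a4 b4 c4 d4   (and similarly S1..S3)
  //   permute: vs[0] = a0 b0 c0 d0 a1 b1 c1 d1, ..., vs[3] = a6 .. d7
  // so the concatenation of vs is the 8x4 transpose in row-major order; the
  // rewrite pattern below reinterprets it via shape_casts.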
  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

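  // Stage 1: interleave adjacent row pairs within each 128-bit lane.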
  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

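  // Stage 2: one shuffle plus two complementary blends per group recombines
  // the interleaved lanes; the inline-asm blend is used instead of the
  // shuffle-based mm256BlendPs, presumably to pin instruction selection to an
  // actual vblendps.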
  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value S0 =
      mm256BlendPsAsm(ib, T0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S1 =
      mm256BlendPsAsm(ib, T2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S2 =
      mm256BlendPsAsm(ib, T1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S3 =
      mm256BlendPsAsm(ib, T3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S4 =
      mm256BlendPsAsm(ib, T4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S5 =
      mm256BlendPsAsm(ib, T6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value S6 =
      mm256BlendPsAsm(ib, T5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value S7 =
      mm256BlendPsAsm(ib, T7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

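  // Stage 3: select 128-bit halves across group pairs to assemble the final
  // rows: <2, 0> pairs the low halves (rows 0-3), <3, 1> the high halves
  // (rows 4-7).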
  vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
}

/// Rewrite AVX2-specific 2-D vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      if (m == 4) {
        // The transposed form is still 4x8 and needs to be reinterpreted as
        // 8x4 via shape_casts.
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

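// A minimal usage sketch (assuming the `LoweringOptions` /
// `TransposeLoweringOptions` builder helpers declared in
// mlir/Dialect/X86Vector/Transforms.h and the standard greedy driver):
//
//   RewritePatternSet patterns(ctx);
//   auto options = x86vector::avx2::LoweringOptions().setTransposeOptions(
//       x86vector::avx2::TransposeLoweringOptions().lower4x8xf32()
//           .lower8x8xf32());
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, options, /*benefit=*/10);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));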
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}