//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

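/// Emits an 8x32-bit blend as a `vblendps` Intel-syntax inline-asm op: if bit
/// i of `mask` is set, lane i of the result is taken from `v2`, otherwise
/// from `v1`. E.g. mask 0xCC expands the template below into the asm string
/// "vblendps $0, $1, $2, 0xcc".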
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: the constraint parser is brittle; no whitespace!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

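/// Emulates `_mm256_unpacklo_ps`: interleaves the two low f32 elements of
/// each 128-bit lane, i.e. returns {a0, b0, a1, b1, a4, b4, a5, b5}.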
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

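/// Emulates `_mm256_unpackhi_ps`: interleaves the two high f32 elements of
/// each 128-bit lane, i.e. returns {a2, b2, a3, b3, a6, b6, a7, b7}.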
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}

/// Emulates `_mm256_shuffle_ps`. The 8-bit `mask` holds four 2-bit fields
/// b01, b23, b45, b67 (low bits first). Within each 128-bit lane the result
/// is {a[b01], a[b23], b[b45], b[b67]}, with the same selection applied to
/// both lanes.
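/// E.g. mask 0x44 (b01 = 0, b23 = 1, b45 = 0, b67 = 1) yields, per lane,
/// {a[0], a[1], b[0], b[1]}.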
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// Emulates `_mm256_permute2f128_ps`. The 8-bit `mask` holds two 4-bit
/// controls: mask[0:3] selects the low 128 bits of the result and mask[4:7]
/// the high 128 bits. Each control picks one 128-bit half of the inputs:
///   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255]
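/// E.g. mask 0x30 (low control 0, high control 3) produces
/// {a[0:127], b[128:255]}, i.e. shuffle indices {0, 1, 2, 3, 12, 13, 14, 15}.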
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, lane i of the result is taken from `v1`,
/// otherwise from `v2`.
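/// E.g. mask 0x0F yields {v2[0], v2[1], v2[2], v2[3], v1[4], v1[5], v1[6],
/// v1[7]}.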
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
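/// On return, `vs` holds the transposed 8x4 matrix two rows at a time: vs[i]
/// contains rows 2*i and 2*i+1 concatenated into a single vector<8xf32>; the
/// rewrite pattern below reinterprets this layout via shape_casts.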
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
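/// This is the classic three-stage 8x8 f32 transpose: unpack lo/hi pairs,
/// shuffle and blend 4-element groups, then swap 128-bit halves with
/// permute2f128. The blends go through inline asm (mm256BlendPsAsm),
/// presumably to guarantee that `vblendps` is selected rather than a more
/// expensive shuffle sequence.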
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Rewrites AVX2-specific 2-D vector.transpose ops for the supported shapes,
/// depending on the enabled `TransposeLoweringOptions`.
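/// E.g. with `lower8x8xf32_` enabled, this pattern rewrites
///
///   %1 = vector.transpose %0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
///
/// into the extract / transpose8x8xf32 / insert sequence built below.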
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto transposedType =
          VectorType::get({n, m}, op.getVectorType().getElementType());
      Value res = ib.create<arith::ConstantOp>(
          op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);
      // In the 4x8 case, the transposed result is still shaped 4x8 and must
      // be reinterpreted as 8x4 via shape_casts.
      if (m == 4) {
        res = ib.create<vector::ShapeCastOp>(flattenedType, res);
        res = ib.create<vector::ShapeCastOp>(transposedType, res);
      }

      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};
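
/// A typical way to enable these patterns from a pass (hypothetical usage
/// sketch; assumes the option builder helpers declared in
/// X86Vector/Transforms.h and the greedy pattern rewrite driver):
///
///   RewritePatternSet patterns(ctx);
///   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
///       patterns,
///       x86vector::avx2::LoweringOptions().setTransposeOptions(
///           x86vector::avx2::TransposeLoweringOptions().lower8x8xf32()),
///       /*benefit=*/10);
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));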
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}