//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

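/// Emits an x86 `vblendps` as LLVM inline assembly (Intel syntax), blending
/// `v1` and `v2` under the immediate `mask`.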
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

/// Within each 128-bit lane, interleaves the low two f32 elements of `v1` and
/// `v2` (the equivalent of `_mm256_unpacklo_ps`).
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

/// Within each 128-bit lane, interleaves the high two f32 elements of `v1`
/// and `v2` (the equivalent of `_mm256_unpackhi_ps`).
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
/// Takes an 8-bit mask with 2 bits per result position: the two low fields
/// (b01, b23) select elements from `v1` and the two high fields (b45, b67)
/// select elements from `v2`. The selection is applied to each 128-bit half
/// independently, yielding the lane layout:
///                        0:127              |            128:255
///             b01  b23  b45+8  b67+8        |  b01+4  b23+4  b45+12  b67+12
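///
/// For instance, with MaskHelper::shuffle<b67, b45, b23, b01> as declared in
/// X86Vector's Transforms.h, `MaskHelper::shuffle<3, 2, 1, 0>()` (b01 = 0,
/// b23 = 1, b45 = 2, b67 = 3) expands to the vector.shuffle mask
/// {0, 1, 10, 11, 4, 5, 14, 15}.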
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

// Each 128-bit half of the result is selected by a 2-bit control:
//    0             1             2             3
// a[0:127]    a[128:255]     b[0:127]     b[128:255]
// imm[0:3] (of which only the low 2 bits may be set) selects the low half of
// the result; imm[4:7] selects the high half.
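//
// For instance, with MaskHelper::permute<b47, b03>, `MaskHelper::permute<2, 0>()`
// picks a[0:127] for the low half and b[0:127] for the high half, i.e. the
// vector.shuffle mask {0, 1, 2, 3, 8, 9, 10, 11}.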
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
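///
/// For instance, with MaskHelper::blend<b0, ..., b7>,
/// `MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>()` (0b11001100) keeps positions
/// 0, 1, 4, 5 from `v1` and takes positions 2, 3, 6, 7 from `v2`, i.e. the
/// vector.shuffle mask {0, 1, 10, 11, 4, 5, 14, 15}.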
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
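///
/// Illustration (derived from the shuffle sequence below): with input rows
/// A = a0..a7, B, C, D, the result packs two transposed columns per vector:
///   vs[0] = {a0, b0, c0, d0, a1, b1, c1, d1}
///   vs[1] = {a2, b2, c2, d2, a3, b3, c3, d3}
///   vs[2] = {a4, b4, c4, d4, a5, b5, c5, d5}
///   vs[3] = {a6, b6, c6, d6, a7, b7, c7, d7}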
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
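///
/// Illustration (derived from the shuffle sequence below): with input rows
/// A = a0..a7 through H = h0..h7, the result is the full transpose, i.e.
/// vs[i] = {a_i, b_i, c_i, d_i, e_i, f_i, g_i, h_i}.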
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
///   Return true: dim0 and dim1 are transposed within the context of their 2D
///   transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
///   transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
///   Return false: dim0 and dim1 are *not* transposed within the context of
///   their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
///   and dim1 (1) are transposed within the full context of the
///   transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context of
  // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    SmallVector<int64_t> srcGtOneDims;
    for (auto &en : llvm::enumerate(srcType.getShape()))
      if (en.value() > 1)
        srcGtOneDims.push_back(en.index());

    if (srcGtOneDims.size() != 2)
      return rewriter.notifyMatchFailure(op, "Unsupported vector type");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // Check whether the two source vector dimensions that are greater than one
    // must be transposed with each other so that we can apply one of the 2-D
    // AVX2 transpose patterns. Otherwise, these patterns are not applicable.
    if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
      return rewriter.notifyMatchFailure(
          op, "Not applicable to this transpose permutation");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    auto srcShape = srcType.getShape();
    int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
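      // E.g., for the 1x4x8 case from the class comment, this is
      // vector<1x4x8xf32> -> vector<32xf32> -> vector<4x8xf32>.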
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The result `res` still has the 2-D shape of the reshaped input (e.g.,
      // 4x8), even though it now holds the transposed elements. Shape-cast it
      // back to the n-D result type of the op (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

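/// Example usage (a sketch, not part of this file): a pass would typically
/// register these patterns and drive them with the greedy rewriter. Only the
/// `LoweringOptions` fields referenced above are assumed here; `funcOp` and
/// the benefit value are placeholders.
///
///   RewritePatternSet patterns(ctx);
///   x86vector::avx2::LoweringOptions options;
///   options.transposeOptions.lower4x8xf32_ = true;
///   options.transposeOptions.lower8x8xf32_ = true;
///   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
///       patterns, options, /*benefit=*/10);
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));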
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}