1 //===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements vector.transpose rewrites as AVX patterns for particular
10 // sizes of interest.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/X86Vector/Transforms.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/Support/Format.h"
22 #include "llvm/Support/FormatVariadic.h"
23
24 using namespace mlir;
25 using namespace mlir::vector;
26 using namespace mlir::x86vector;
27 using namespace mlir::x86vector::avx2;
28 using namespace mlir::x86vector::avx2::inline_asm;
29 using namespace mlir::x86vector::avx2::intrin;
30
/// Emit an 8xf32 blend as raw `vblendps` inline assembly (Intel dialect).
/// If bit i of `mask` is clear, result lane i is taken from `v1`, otherwise
/// from `v2`. Going through inline asm (instead of vector.shuffle) pins the
/// backend to the single-instruction vblendps form.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  // $0/$1/$2 are LLVM's result/operand placeholders; {0} is llvm::formatv's
  // slot for the immediate blend mask substituted below.
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  // Render the mask as a two-digit hex immediate (e.g. 0xcc).
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}
47
mm256UnpackLoPs(ImplicitLocOpBuilder & b,Value v1,Value v2)48 Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
49 Value v1, Value v2) {
50 return b.create<vector::ShuffleOp>(
51 v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
52 }
53
mm256UnpackHiPs(ImplicitLocOpBuilder & b,Value v1,Value v2)54 Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
55 Value v1, Value v2) {
56 return b.create<vector::ShuffleOp>(
57 v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
58 }
///                            a  a  b  b  a  a  b  b
/// Takes an 8 bit mask, 2 bit for each position of a[0, 4) **and** b[0, 4):
///                                0:127    |           128:255
///                            b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
mm256ShufflePs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)63 Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
64 Value v1, Value v2,
65 uint8_t mask) {
66 uint8_t b01, b23, b45, b67;
67 MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
68 SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8,
69 b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
70 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
71 }
72
73 // imm[0:1] out of imm[0:3] is:
74 // 0 1 2 3
75 // a[0:127] or a[128:255] or b[0:127] or b[128:255] |
76 // a[0:127] or a[128:255] or b[0:127] or b[128:255]
77 // 0 1 2 3
78 // imm[0:1] out of imm[4:7].
mm256Permute2f128Ps(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)79 Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
80 ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
81 SmallVector<int64_t> shuffleMask;
82 auto appendToMask = [&](uint8_t control) {
83 if (control == 0)
84 llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
85 else if (control == 1)
86 llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
87 else if (control == 2)
88 llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
89 else if (control == 3)
90 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
91 else
92 llvm_unreachable("control > 3 : overflow");
93 };
94 uint8_t b03, b47;
95 MaskHelper::extractPermute(mask, b03, b47);
96 appendToMask(b03);
97 appendToMask(b47);
98 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
99 }
100
101 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
mm256BlendPs(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)102 Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
103 Value v1, Value v2,
104 uint8_t mask) {
105 SmallVector<int64_t, 8> shuffleMask;
106 for (int i = 0; i < 8; ++i) {
107 bool isSet = mask & (1 << i);
108 shuffleMask.push_back(!isSet ? i : i + 8);
109 }
110 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
111 }
112
113 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose4x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)114 void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
115 MutableArrayRef<Value> vs) {
116 #ifndef NDEBUG
117 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
118 assert(vs.size() == 4 && "expects 4 vectors");
119 assert(llvm::all_of(ValueRange{vs}.getTypes(),
120 [&](Type t) { return t == vt; }) &&
121 "expects all types to be vector<8xf32>");
122 #endif
123
124 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
125 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
126 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
127 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
128 Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
129 Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
130 Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
131 Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
132 vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
133 vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
134 vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
135 vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
136 }
137
138 /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
transpose8x8xf32(ImplicitLocOpBuilder & ib,MutableArrayRef<Value> vs)139 void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
140 MutableArrayRef<Value> vs) {
141 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
142 (void)vt;
143 assert(vs.size() == 8 && "expects 8 vectors");
144 assert(llvm::all_of(ValueRange{vs}.getTypes(),
145 [&](Type t) { return t == vt; }) &&
146 "expects all types to be vector<8xf32>");
147
148 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
149 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
150 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
151 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
152 Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
153 Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
154 Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
155 Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
156
157 using inline_asm::mm256BlendPsAsm;
158 Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
159 Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
160 Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
161 Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
162
163 Value s0 =
164 mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
165 Value s1 =
166 mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
167 Value s2 =
168 mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
169 Value s3 =
170 mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
171 Value s4 =
172 mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
173 Value s5 =
174 mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
175 Value s6 =
176 mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
177 Value s7 =
178 mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
179
180 vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
181 vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
182 vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
183 vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
184 vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
185 vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
186 vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
187 vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
188 }
189
190 /// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
191 /// should be transposed with each other within the context of their 2D
192 /// transposition slice.
193 ///
194 /// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
195 /// Return true: dim0 and dim1 are transposed within the context of their 2D
196 /// transposition slice ([1, 0]).
197 ///
198 /// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
199 /// Return true: dim0 and dim1 are transposed within the context of their 2D
200 /// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
201 /// transposed within the full context of the transposition.
202 ///
203 /// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
204 /// Return false: dim0 and dim1 are *not* transposed within the context of
205 /// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the
207 /// transposition.
areDimsTransposedIn2DSlice(int64_t dim0,int64_t dim1,ArrayRef<int64_t> transp)208 static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
209 ArrayRef<int64_t> transp) {
210 // Perform a linear scan along the dimensions of the transposed pattern. If
211 // dim0 is found first, dim0 and dim1 are not transposed within the context of
212 // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
213 for (int64_t permDim : transp) {
214 if (permDim == dim0)
215 return false;
216 if (permDim == dim1)
217 return true;
218 }
219
220 llvm_unreachable("Ill-formed transpose pattern");
221 }
222
223 /// Rewrite AVX2-specific vector.transpose, for the supported cases and
224 /// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
225 /// transpose cases and n-D cases that have been decomposed into 2-D
226 /// transposition slices. For example, a 3-D transpose:
227 ///
228 /// %0 = vector.transpose %arg0, [2, 0, 1]
229 /// : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
230 ///
231 /// could be sliced into 2-D transposes by tiling two of its dimensions to one
232 /// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
233 ///
234 /// %0 = vector.transpose %arg0, [2, 0, 1]
235 /// : vector<1x4x8xf32> to vector<8x1x4xf32>
236 ///
237 /// This lowering will analyze the n-D vector.transpose and determine if it's a
238 /// supported 2-D transposition slice where any of the AVX2 patterns can be
239 /// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  /// `loweringOptions` gates which shapes (4x8, 8x8) this pattern rewrites;
  /// `benefit` is forwarded to the pattern infrastructure.
  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    // Collect the indices of all dimensions with extent > 1; exactly two such
    // dimensions means the op is effectively a 2-D transposition slice.
    SmallVector<int64_t> srcGtOneDims;
    for (auto &en : llvm::enumerate(srcType.getShape()))
      if (en.value() > 1)
        srcGtOneDims.push_back(en.index());

    if (srcGtOneDims.size() != 2)
      return rewriter.notifyMatchFailure(op, "Unsupported vector type");

    // Materialize the permutation attribute as a plain integer vector.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // Check whether the two source vector dimensions that are greater than one
    // must be transposed with each other so that we can apply one of the 2-D
    // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
    if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
      return rewriter.notifyMatchFailure(
          op, "Not applicable to this transpose permutation");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    auto srcShape = srcType.getShape();
    int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];

    // Deferred so the gating on (m, n) and the lowering options below happens
    // before any IR is created.
    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector (via a flat 1-D intermediate, since shape_cast goes
      // through the linearized form).
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors in place (m is guaranteed to be 4 or 8
      // by the option checks below).
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector, starting from an all-zero seed.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose their dimensions and retrieve its original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    // Only the shapes enabled in the lowering options are rewritten; anything
    // else falls through to other lowering patterns.
    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};
333
/// Registers the AVX2-specialized TransposeOpLowering pattern on `patterns`,
/// gated by `options` and ranked by `benefit`.
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
338