//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
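// For example (illustrative): removing index 1 from the map
// (d0, d1, d2) -> (d2, d1, d0) yields (d0, d1) -> (d1, d0).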
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Helper to create arithmetic operation associated with a kind of contraction.
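/// For example (illustrative): with a floating-point ADD kind and a vector
/// accumulator, the x * y + acc computation is emitted as a single
///   %0 = vector.fma %x, %y, %acc : vector<8xf32>
/// whereas the integer case emits arith.muli followed by the matching
/// arithmetic reduction with the accumulator.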
static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
                                             Value acc,
                                             vector::CombiningKind kind,
                                             PatternRewriter &rewriter,
                                             bool isInt) {
  using vector::CombiningKind;
  Value mul;
  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return Optional<Value>();
    // Special case for fused multiply-add.
    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }
  if (!acc)
    return Optional<Value>(mul);
  return makeArithReduction(rewriter, loc, kind, mul, acc);
}

/// Return the positions of the reductions in the given map.
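/// For example (illustrative): for the map (d0, d1, d2) -> (d1, d2) with
/// iterator types [parallel, parallel, reduction], only the second result (d2)
/// maps to a reduction iterator, so this returns {1}.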
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// llvm::None if the dimension is not in the map results.
static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return llvm::None;
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.getSource().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
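/// For example (illustrative):
///   %0 = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
/// is rewritten into a broadcast of %src to the trailing (n-1)-D shape,
/// followed by vector.insert ops along the leading dimension.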
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
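/// For example (illustrative): for the permutation [1, 0, 2, 3], the trailing
/// identity dimensions 2 and 3 are pruned and `result` is set to [1, 0].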
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
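/// For example (illustrative): transposing a vector<2x3xf32> first shape_casts
/// it to vector<6xf32>, then applies a vector.shuffle with the mask
/// [0, 3, 1, 4, 2, 5], and finally shape_casts the result to vector<3x2xf32>.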
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
                                                   kind, rewriter, isInt);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(op, mult.value());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
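/// For example (illustrative): a contraction whose single reduction dimension
/// has size 1, such as
///   %0 = vector.contract ... %lhs, %rhs, %acc
///        : vector<4x1xf32>, vector<1x8xf32> into vector<4x8xf32>
/// can be rewritten as broadcasts/transposes that align the parallel
/// dimensions of both operands, followed by an elementwise multiply-accumulate
/// on vector<4x8xf32>.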
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: implement masks
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();
    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must be a size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      llvm::Optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      llvm::Optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(
        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
    newRhs = rewriter.create<vector::ExtractOp>(
        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
    Optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
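/// For example (illustrative):
///   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
/// becomes two vector.extract ops (one per row of %a), each followed by a
/// vector.insert_strided_slice into the 1-D result at offsets 0 and 4.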
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
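/// For example (illustrative):
///   %0 = vector.shape_cast %a : vector<8xf32> to vector<2x4xf32>
/// becomes two vector.extract_strided_slice ops of size 4 (at offsets 0 and 4),
/// each inserted into the corresponding row of the 2-D result.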
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vector as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
    SmallVector<Attribute, 4> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction dimension
    // pairs. For example, if the vector.contract had only one reduction
    // iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction dimension pair.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      }
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    // If the compressed maps have a dimension that is not used by either LHS or
    // RHS then the ContractionOp verifier would fail.
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in
/// CombineContractTranspose pattern when elementwise ops are between these
/// operations. Ex:
/// ```
///   %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayAttr, 4> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getTransp());
        srcType = transposeOp.getVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::is_splat(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value, 4> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = extractVector<unsigned>(transposeMaps.front());
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType = VectorType::get(
            srcType.getShape(),
            operand.getType().cast<VectorType>().getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand,
            rewriter.getI64ArrayAttr(invOrder)));
      }
    }

    auto vectorType = VectorType::get(
        srcType.getShape(),
        op->getResultTypes()[0].cast<VectorType>().getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

} // namespace

/// Creates an AddIOp if `isInt` is true, otherwise creates an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true, otherwise creates an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}
1332
1333 namespace mlir {
1334
1335 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1336 /// semantics to:
1337 /// ```
1338 /// %mta = maybe_transpose
1339 /// %mtb = maybe_transpose
1340 /// %flattened_a = vector.shape_cast %mta
1341 /// %flattened_b = vector.shape_cast %mtb
1342 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1343 /// %mtd = vector.shape_cast %flattened_d
1344 /// %d = maybe_untranspose %mtd
1345 /// %e = add %c, %d
1346 /// ```
1347 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1348 //
1349 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1350 /// vector.transpose operations are inserted if the vector.contract op is not a
1351 /// row-major matrix multiply.
1352 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rew) const1353 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1354 PatternRewriter &rew) const {
1355 // TODO: implement masks
1356 if (llvm::size(op.getMasks()) != 0)
1357 return failure();
1358 if (vectorTransformOptions.vectorContractLowering !=
1359 vector::VectorContractLowering::Matmul)
1360 return failure();
1361 if (failed(filter(op)))
1362 return failure();
1363
1364 auto iteratorTypes = op.getIteratorTypes().getValue();
1365 if (!isParallelIterator(iteratorTypes[0]) ||
1366 !isParallelIterator(iteratorTypes[1]) ||
1367 !isReductionIterator(iteratorTypes[2]))
1368 return failure();
1369
1370 Type elementType = op.getLhsType().getElementType();
1371 if (!elementType.isIntOrFloat())
1372 return failure();
1373
1374 Type dstElementType = op.getType();
1375 if (auto vecType = dstElementType.dyn_cast<VectorType>())
1376 dstElementType = vecType.getElementType();
1377 if (elementType != dstElementType)
1378 return failure();
1379
1380 // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1381 // Bail out if the contraction cannot be put in this form.
1382 MLIRContext *ctx = op.getContext();
1383 Location loc = op.getLoc();
1384 AffineExpr m, n, k;
1385 bindDims(rew.getContext(), m, n, k);
1386 // LHS must be A(m, k) or A(k, m).
1387 Value lhs = op.getLhs();
1388 auto lhsMap = op.getIndexingMapsArray()[0];
1389 if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1390 lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1391 else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1392 return failure();
1393
1394 // RHS must be B(k, n) or B(n, k).
1395 Value rhs = op.getRhs();
1396 auto rhsMap = op.getIndexingMapsArray()[1];
1397 if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1398 rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1399 else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1400 return failure();
1401
1402 // At this point lhs and rhs are in row-major.
1403 VectorType lhsType = lhs.getType().cast<VectorType>();
1404 VectorType rhsType = rhs.getType().cast<VectorType>();
1405 int64_t lhsRows = lhsType.getDimSize(0);
1406 int64_t lhsColumns = lhsType.getDimSize(1);
1407 int64_t rhsColumns = rhsType.getDimSize(1);
1408
1409 Type flattenedLHSType =
1410 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1411 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1412
1413 Type flattenedRHSType =
1414 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1415 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1416
1417 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1418 rhsColumns);
1419 mul = rew.create<vector::ShapeCastOp>(
1420 loc,
1421 VectorType::get({lhsRows, rhsColumns},
1422 getElementTypeOrSelf(op.getAcc().getType())),
1423 mul);
1424
1425 // ACC must be C(m, n) or C(n, m).
1426 auto accMap = op.getIndexingMapsArray()[2];
1427 if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1428 mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1429 else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1430 llvm_unreachable("invalid contraction semantics");
1431
1432 Value res =
1433 elementType.isa<IntegerType>()
1434 ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1435 : static_cast<Value>(
1436 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1437
1438 rew.replaceOp(op, res);
1439 return success();
1440 }
1441
1442 namespace {
1443 struct IteratorType {
1444 IteratorType(StringRef strRef) : strRef(strRef) {}
1445 bool isOfType(Attribute attr) const {
1446 auto sAttr = attr.dyn_cast<StringAttr>();
1447 return sAttr && sAttr.getValue() == strRef;
1448 }
1449 StringRef strRef;
1450 };
1451 struct Par : public IteratorType {
1452 Par() : IteratorType(getParallelIteratorTypeName()) {}
1453 };
1454 struct Red : public IteratorType {
1455 Red() : IteratorType(getReductionIteratorTypeName()) {}
1456 };
1457
1458 /// Generate a vector implementation for matmat, matvec and tmatvec.
1459 /// This unrolls outer-products along the reduction dimension.
1460 struct UnrolledOuterProductGenerator
1461 : public StructuredGenerator<vector::ContractionOp> {
1462 UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1463 : StructuredGenerator<vector::ContractionOp>(builder, op),
1464 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1465 res(op.getAcc()), lhsType(op.getLhsType()) {}
1466
1467 Value t(Value v) {
1468 static constexpr std::array<int64_t, 2> perm = {1, 0};
1469 return builder.create<vector::TransposeOp>(loc, v, perm);
1470 }
1471
1472 Value promote(Value v, Type dstElementType) {
1473 Type elementType = v.getType();
1474 auto vecType = elementType.dyn_cast<VectorType>();
1475 if (vecType)
1476 elementType = vecType.getElementType();
1477 if (elementType == dstElementType)
1478 return v;
1479 Type promotedType = dstElementType;
1480 if (vecType)
1481 promotedType = VectorType::get(vecType.getShape(), promotedType);
1482 if (dstElementType.isa<FloatType>())
1483 return builder.create<arith::ExtFOp>(loc, promotedType, v);
1484 return builder.create<arith::ExtSIOp>(loc, promotedType, v);
1485 }
1486
1487 Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1488 assert(reductionSize > 0);
1489 Type resElementType = res.getType().cast<VectorType>().getElementType();
1490 for (int64_t k = 0; k < reductionSize; ++k) {
1491 Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1492 Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1493 a = promote(a, resElementType);
1494 b = promote(b, resElementType);
1495 res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1496 res, kind);
1497 }
1498 return res;
1499 }
1500
1501 /// Two outer parallel, one inner reduction (matmat flavor).
1502 FailureOr<Value> matmat() {
1503 if (!iters({Par(), Par(), Red()}))
1504 return failure();
1505 // Set up the parallel/reduction structure in the right form.
1506 AffineExpr m, n, k;
1507 bindDims(builder.getContext(), m, n, k);
1508 // Classical row-major matmul: Just permute the lhs.
1509 if (layout({{m, k}, {k, n}, {m, n}}))
1510 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1511 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1512 if (layout({{m, k}, {n, k}, {m, n}})) {
1513 Value tlhs = t(lhs);
1514 return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1515 }
1516 // No need to permute anything.
1517 if (layout({{k, m}, {k, n}, {m, n}}))
1518 return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1519 // Just permute the rhs.
1520 if (layout({{k, m}, {n, k}, {m, n}}))
1521 return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1522 // Transposed output: swap RHS and LHS.
1523 // Classical row-major matmul: permute the lhs.
1524 if (layout({{m, k}, {k, n}, {n, m}}))
1525 return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1526 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1527 if (layout({{m, k}, {n, k}, {n, m}})) {
1528 Value trhs = t(rhs);
1529 return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1530 }
1531 if (layout({{k, m}, {k, n}, {n, m}}))
1532 return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1533 if (layout({{k, m}, {n, k}, {n, m}}))
1534 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1535 return failure();
1536 }
1537
1538 /// One outer parallel, one inner reduction (matvec flavor)
1539 FailureOr<Value> matvec() {
1540 if (!iters({Par(), Red()}))
1541 return failure();
1542 AffineExpr m, k;
1543 bindDims(builder.getContext(), m, k);
1544
1545 // Case mat-vec: transpose.
1546 if (layout({{m, k}, {k}, {m}}))
1547 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1548 // Case mat-trans-vec: ready to go.
1549 if (layout({{k, m}, {k}, {m}}))
1550 return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1551 // Case vec-mat: swap and transpose.
1552 if (layout({{k}, {m, k}, {m}}))
1553 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1554 // Case vec-mat-trans: swap and ready to go.
1555 if (layout({{k}, {k, m}, {m}}))
1556 return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1557 return failure();
1558 }
1559
1560 //
1561 // One outer reduction, one inner parallel (tmatvec flavor)
1562 //
1563 FailureOr<Value> tmatvec() {
1564 if (!iters({Red(), Par()}))
1565 return failure();
1566 AffineExpr k, m;
1567 bindDims(builder.getContext(), k, m);
1568
1569 // Case mat-vec: transpose.
1570 if (layout({{m, k}, {k}, {m}}))
1571 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1572 // Case mat-trans-vec: ready to go.
1573 if (layout({{k, m}, {k}, {m}}))
1574 return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1575 // Case vec-mat: swap and transpose.
1576 if (layout({{k}, {m, k}, {m}}))
1577 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1578 // Case vec-mat-trans: swap and ready to go.
1579 if (layout({{k}, {k, m}, {m}}))
1580 return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1581 return failure();
1582 }
1583
1584 private:
1585 vector::CombiningKind kind;
1586 Value lhs, rhs, res;
1587 VectorType lhsType;
1588 };
1589 } // namespace
1590
1591 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1592 /// semantics to a reduction_size-unrolled sequence:
1593 /// ```
1594 /// %at = vector.transpose %a, [1, 0]
1595 /// %bRow0 = vector.extract %b[0]
1596 /// %atRow0 = vector.extract %at[0]
1597 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
1598 /// ...
1599 /// %bRowK = vector.extract %b[K]
1600 /// %atRowK = vector.extract %at[K]
1601 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1602 /// ```
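///
/// For example (an illustrative sketch with concrete shapes), a row-major
/// 4x2 * 2x8 contraction unrolls along the reduction dimension k = 2 into:
/// ```
///    %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///    %a0 = vector.extract %at[0] : vector<2x4xf32>
///    %b0 = vector.extract %b[0] : vector<2x8xf32>
///    %c0 = vector.outerproduct %a0, %b0, %acc : vector<4xf32>, vector<8xf32>
///    %a1 = vector.extract %at[1] : vector<2x4xf32>
///    %b1 = vector.extract %b[1] : vector<2x8xf32>
///    %c1 = vector.outerproduct %a1, %b1, %c0 : vector<4xf32>, vector<8xf32>
/// ```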
1603 ///
1604 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1605 /// otherwise supports any layout permutation of the matrix-multiply.
1606 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1607 vector::ContractionOp op, PatternRewriter &rewriter) const {
1608 // TODO: implement masks
1609 if (llvm::size(op.getMasks()) != 0)
1610 return failure();
1611
1612 if (vectorTransformOptions.vectorContractLowering !=
1613 vector::VectorContractLowering::OuterProduct)
1614 return failure();
1615
1616 if (failed(filter(op)))
1617 return failure();
1618
1619 UnrolledOuterProductGenerator e(rewriter, op);
1620 FailureOr<Value> matmatRes = e.matmat();
1621 if (succeeded(matmatRes)) {
1622 rewriter.replaceOp(op, *matmatRes);
1623 return success();
1624 }
1625 FailureOr<Value> matvecRes = e.matvec();
1626 if (succeeded(matvecRes)) {
1627 rewriter.replaceOp(op, *matvecRes);
1628 return success();
1629 }
1630 FailureOr<Value> tmatvecRes = e.tmatvec();
1631 if (succeeded(tmatvecRes)) {
1632 rewriter.replaceOp(op, *tmatvecRes);
1633 return success();
1634 }
1635
1636 return failure();
1637 }
1638
1639 LogicalResult
1640 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1641 PatternRewriter &rewriter) const {
1642 // TODO: implement masks
1643 if (llvm::size(op.getMasks()) != 0)
1644 return failure();
1645
1646 if (failed(filter(op)))
1647 return failure();
1648
1649 if (vectorTransformOptions.vectorContractLowering !=
1650 vector::VectorContractLowering::Dot)
1651 return failure();
1652
1653 auto iteratorTypes = op.getIteratorTypes().getValue();
1654 static constexpr std::array<int64_t, 2> perm = {1, 0};
1655 Location loc = op.getLoc();
1656 Value lhs = op.getLhs(), rhs = op.getRhs();
1657
1658 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1659 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1660 AffineExpr m, n, k;
1661 bindDims(rewriter.getContext(), m, n, k);
1662 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1663 //
1664 // In the following we wish to make the reduction dimension innermost so we
1665 // can load vectors and just fmul + reduce into a scalar.
1666 //
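// Schematically (an illustrative sketch, not exact emitted IR, with types
// elided), each result element (r, c) then becomes an extract/multiply/reduce
// sequence along the reduction dimension:
//   %a = vector.extract %lhs[r]       // row r of the permuted lhs
//   %b = vector.extract %rhs[c]       // row c of the permuted rhs
//   %m = arith.mulf %a, %b            // or arith.muli for integer elements
//   %s = vector.reduction <add>, %m   // scalar dot product
//   %res = vector.insert %s, %res [r, c]
//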
1667 if (isParallelIterator(iteratorTypes[0]) &&
1668 isParallelIterator(iteratorTypes[1]) &&
1669 isReductionIterator(iteratorTypes[2])) {
1670 //
1671 // Two outer parallel, one inner reduction (matmat flavor).
1672 //
1673 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1674 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1675 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1676 // No need to permute anything.
1677 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1678 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1679 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1680 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1681 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1682 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1683 // Transposed output: swap lhs and rhs, and transpose the new lhs.
1684 Value tmp = lhs;
1685 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1686 rhs = tmp;
1687 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1688 std::swap(lhs, rhs);
1689 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1690 Value tmp = lhs;
1691 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1692 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1693 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1694 Value tmp = rhs;
1695 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1696 lhs = tmp;
1697 } else {
1698 return failure();
1699 }
1700 } else if (isParallelIterator(iteratorTypes[0]) &&
1701 isReductionIterator(iteratorTypes[1])) {
1702 //
1703 // One outer parallel, one inner reduction (matvec flavor)
1704 //
1705 if (maps == infer({{m, n}, {n}, {m}})) {
1706 // No need to permute anything.
1707 } else if (maps == infer({{n, m}, {n}, {m}})) {
1708 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1709 } else if (maps == infer({{n}, {m, n}, {m}})) {
1710 std::swap(lhs, rhs);
1711 } else if (maps == infer({{n}, {n, m}, {m}})) {
1712 std::swap(lhs, rhs);
1713 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1714 } else {
1715 return failure();
1716 }
1717 } else {
1718 return failure();
1719 }
1720
1721 VectorType dstType = op.getResultType().cast<VectorType>();
1722 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1723 "Expected dst type of rank 1 or 2");
1724
1725 unsigned rank = dstType.getRank();
1726 unsigned dstRows = dstType.getShape()[0];
1727 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1728
1729 // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1730 Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1731 rewriter.getZeroAttr(dstType));
1732 bool isInt = dstType.getElementType().isa<IntegerType>();
1733 for (unsigned r = 0; r < dstRows; ++r) {
1734 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1735 for (unsigned c = 0; c < dstColumns; ++c) {
1736 Value b = rank == 1
1737 ? rhs
1738 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1739 Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1740 Value reduced = rewriter.create<vector::ReductionOp>(
1741 op.getLoc(), vector::CombiningKind::ADD, m);
1742
1743 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1744 : SmallVector<int64_t, 2>{r, c};
1745 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1746 }
1747 }
1748 if (auto acc = op.getAcc())
1749 res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1750 rewriter.replaceOp(op, res);
1751 return success();
1752 }
1753
1754 /// Progressive lowering of ContractionOp.
1755 /// One:
1756 /// %x = vector.contract with at least one free/batch dimension
1757 /// is replaced by:
1758 /// %a = vector.contract with one less free/batch dimension
1759 /// %b = vector.contract with one less free/batch dimension
1760 /// ..
1761 /// %x = combine %a %b ..
1762 /// until a pure contraction is reached (no free/batch dimensions),
1763 /// which is replaced by a dot-product.
1764 ///
1765 /// This only kicks in when either VectorTransformsOptions is set
1766 /// to DOT or when other contraction patterns fail.
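///
/// For example (an illustrative sketch), a contraction with one batch
/// dimension of size 2 is rewritten into two lower-dimensional contractions,
/// one per slice of that dimension, whose results are re-assembled into the
/// original result vector:
/// ```
///    %a = vector.contract(lhs slice 0, rhs slice 0, acc slice 0)
///    %b = vector.contract(lhs slice 1, rhs slice 1, acc slice 1)
///    %x = insert %a and %b back into the full result
/// ```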
1767 //
1768 // TODO: break down into transpose/reshape/cast ops
1769 // when they become available to avoid code dup
1770 // TODO: investigate lowering order impact on performance
1771 LogicalResult
1772 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1773 PatternRewriter &rewriter) const {
1774 // TODO: implement masks.
1775 if (llvm::size(op.getMasks()) != 0)
1776 return failure();
1777
1778 if (failed(filter(op)))
1779 return failure();
1780
1781 // TODO: support mixed mode contract lowering.
1782 if (op.getLhsType().getElementType() !=
1783 getElementTypeOrSelf(op.getAccType()) ||
1784 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1785 return failure();
1786
1787 // TODO: implement benefits, cost models.
1788 MLIRContext *ctx = op.getContext();
1789 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1790 if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1791 return success();
1792 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1793 if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1794 return success();
1795 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1796 if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1797 return success();
1798 ContractOpToElementwise pat4(vectorTransformOptions, ctx);
1799 if (succeeded(pat4.matchAndRewrite(op, rewriter)))
1800 return success();
1801
1802 // Find first batch dimension in LHS/RHS, and lower when found.
1803 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1804 if (!batchDimMap.empty()) {
1805 int64_t lhsIndex = batchDimMap[0].first;
1806 int64_t rhsIndex = batchDimMap[0].second;
1807 auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
1808 if (failed(newOp))
1809 return failure();
1810 rewriter.replaceOp(op, newOp.value());
1811 return success();
1812 }
1813
1814 // Collect contracting dimensions.
1815 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1816 op.getContractingDimMap();
1817 DenseSet<int64_t> lhsContractingDimSet;
1818 DenseSet<int64_t> rhsContractingDimSet;
1819 for (auto &dimPair : contractingDimMap) {
1820 lhsContractingDimSet.insert(dimPair.first);
1821 rhsContractingDimSet.insert(dimPair.second);
1822 }
1823
1824 // Find first free dimension in LHS, and lower when found.
1825 VectorType lhsType = op.getLhsType();
1826 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1827 if (lhsContractingDimSet.count(lhsIndex) == 0) {
1828 auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
1829 if (failed(newOp))
1830 return failure();
1831 rewriter.replaceOp(op, newOp.value());
1832 return success();
1833 }
1834 }
1835
1836 // Find first free dimension in RHS, and lower when found.
1837 VectorType rhsType = op.getRhsType();
1838 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1839 if (rhsContractingDimSet.count(rhsIndex) == 0) {
1840 auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
1841 if (failed(newOp))
1842 return failure();
1843 rewriter.replaceOp(op, newOp.value());
1844 return success();
1845 }
1846 }
1847
1848 // Lower the first remaining reduction dimension.
1849 if (!contractingDimMap.empty()) {
1850 auto newOp = lowerReduction(op, rewriter);
1851 if (failed(newOp))
1852 return failure();
1853 rewriter.replaceOp(op, newOp.value());
1854 return success();
1855 }
1856
1857 return failure();
1858 }
1859
1860 // Lower one parallel dimension.
1861 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1862 // TODO: consider reusing existing contract unrolling
1863 FailureOr<Value>
1864 ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
1865 int64_t rhsIndex,
1866 PatternRewriter &rewriter) const {
1867 VectorType lhsType = op.getLhsType();
1868 VectorType rhsType = op.getRhsType();
1869 VectorType resType = op.getResultType().cast<VectorType>();
1870 // Find the iterator type index and result index.
1871 SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1872 int64_t iterIndex = -1;
1873 int64_t dimSize = -1;
1874 if (lhsIndex >= 0) {
1875 iterIndex = iMap[0].getDimPosition(lhsIndex);
1876 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1877 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1878 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1879 << " to map to the same dimension";
1880 });
1881 dimSize = lhsType.getDimSize(lhsIndex);
1882 } else if (rhsIndex >= 0) {
1883 iterIndex = iMap[1].getDimPosition(rhsIndex);
1884 dimSize = rhsType.getDimSize(rhsIndex);
1885 }
1886 if (iterIndex < 0)
1887 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1888 diag << "expected either lhsIndex=" << lhsIndex
1889 << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1890 });
1891 // value_or(-1) means that we tolerate a dimension not appearing
1892 // in the result map. That can't happen for actual parallel iterators, but
1893 // the caller ContractionOpLowering::matchAndRewrite is currently calling
1894 // lowerParallel also for the case of unit-size reduction dims appearing only
1895 // on one of LHS or RHS, not both. At the moment, such cases are created by
1896 // CastAwayContractionLeadingOneDim, so we need to either support that or
1897 // modify that pattern.
1898 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1899 if (resIndex == -1 && dimSize != 1)
1900 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1901 diag << "expected the dimension for iterIndex=" << iterIndex
1902 << " to either appear in the result map, or to be a unit dimension";
1903 });
1904 // Construct new iterator types and affine map array attribute.
1905 std::array<AffineMap, 3> lowIndexingMaps = {
1906 adjustMap(iMap[0], iterIndex, rewriter),
1907 adjustMap(iMap[1], iterIndex, rewriter),
1908 adjustMap(iMap[2], iterIndex, rewriter)};
1909 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1910 auto lowIter =
1911 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1912 // Unroll into a series of lower dimensional vector.contract ops.
1913 Location loc = op.getLoc();
1914 Value result = rewriter.create<arith::ConstantOp>(
1915 loc, resType, rewriter.getZeroAttr(resType));
1916 for (int64_t d = 0; d < dimSize; ++d) {
1917 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1918 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1919 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1920 Value lowContract = rewriter.create<vector::ContractionOp>(
1921 loc, lhs, rhs, acc, lowAffine, lowIter);
1922 result =
1923 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1924 }
1925 return result;
1926 }
1927
1928 // Lower one reduction dimension.
1929 FailureOr<Value>
1930 ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1931 PatternRewriter &rewriter) const {
1932 auto loc = op.getLoc();
1933 VectorType lhsType = op.getLhsType();
1934 VectorType rhsType = op.getRhsType();
1935 Type resType = op.getResultType();
1936 if (resType.isa<VectorType>())
1937 return rewriter.notifyMatchFailure(op,
1938 "did not expect a VectorType result");
1939 bool isInt = resType.isa<IntegerType>();
1940 // Use iterator index 0.
1941 int64_t iterIndex = 0;
1942 SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1943 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1944 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1945 if (!lookupLhs.has_value())
1946 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1947 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1948 });
1949 if (!lookupRhs.has_value())
1950 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1951 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1952 });
1953 int64_t lhsIndex = lookupLhs.value();
1954 int64_t rhsIndex = lookupRhs.value();
1955 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1956 if (dimSize != rhsType.getDimSize(rhsIndex))
1957 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1958 diag << "expect LHS dimension " << lhsIndex
1959 << " to have the same size as RHS dimension " << rhsIndex;
1960 });
1961 // Base case.
1962 if (lhsType.getRank() == 1) {
1963 if (rhsType.getRank() != 1)
1964 return rewriter.notifyMatchFailure(
1965 op, "When LHS has rank 1, expected also RHS to have rank 1");
1966 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1967 auto kind = vector::CombiningKind::ADD;
1968 if (auto acc = op.getAcc())
1969 return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1970 .getResult();
1971 return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
1972 }
1973 // Construct new iterator types and affine map array attribute.
1974 std::array<AffineMap, 3> lowIndexingMaps = {
1975 adjustMap(iMap[0], iterIndex, rewriter),
1976 adjustMap(iMap[1], iterIndex, rewriter),
1977 adjustMap(iMap[2], iterIndex, rewriter)};
1978 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1979 auto lowIter =
1980 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1981 // Unroll into a series of lower dimensional vector.contract ops.
1982 // By feeding the initial accumulator into the first contraction,
1983 // and the result of each contraction into the next, eventually
1984 // the sum of all reductions is computed.
1985 Value result = op.getAcc();
1986 for (int64_t d = 0; d < dimSize; ++d) {
1987 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1988 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1989 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1990 lowAffine, lowIter);
1991 }
1992 return result;
1993 }
1994
1995 } // namespace mlir
1996
1997 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1998 OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1999 ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2000 OpBuilder::InsertionGuard guard(builder);
2001 builder.setInsertionPointAfter(op);
2002 Location loc = op->getLoc();
2003 if (op->getNumResults() != 1)
2004 return {};
2005 Value result = op->getResult(0);
2006 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2007 if (!type || map.getNumResults() != multiplicity.size())
2008 return {};
2009 // For each dimension being distributed, check that its size is a multiple of
2010 // the multiplicity. To handle more sizes we would need to support masking.
2011 unsigned multiplictyCount = 0;
2012 for (auto exp : map.getResults()) {
2013 auto affinExp = exp.dyn_cast<AffineDimExpr>();
2014 if (!affinExp || affinExp.getPosition() >= type.getRank() ||
2015 type.getDimSize(affinExp.getPosition()) %
2016 multiplicity[multiplictyCount++] !=
2017 0)
2018 return {};
2019 }
2020 DistributeOps ops;
2021 ops.extract =
2022 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2023 ops.insert =
2024 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2025 return ops;
2026 }
2027
2028 /// Progressive lowering of transfer_read. This pattern supports lowering of
2029 /// `vector.transfer_read` to a combination of `vector.load` and
2030 /// `vector.broadcast` if all of the following hold:
2031 /// - Stride of most minor memref dimension must be 1.
2032 /// - Out-of-bounds masking is not required.
2033 /// - If the memref's element type is a vector type then it coincides with the
2034 /// result type.
2035 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
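///
/// For example (an illustrative sketch), a 1-D in-bounds, unmasked read such
/// as:
/// ```
///    %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///        : memref<8xf32>, vector<4xf32>
/// ```
/// would become:
/// ```
///    %v = vector.load %mem[%i] : memref<8xf32>, vector<4xf32>
/// ```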
2036 struct TransferReadToVectorLoadLowering
2037 : public OpRewritePattern<vector::TransferReadOp> {
2038 TransferReadToVectorLoadLowering(MLIRContext *context,
2039 llvm::Optional<unsigned> maxRank)
2040 : OpRewritePattern<vector::TransferReadOp>(context),
2041 maxTransferRank(maxRank) {}
2042
2043 LogicalResult matchAndRewrite(vector::TransferReadOp read,
2044 PatternRewriter &rewriter) const override {
2045 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
2046 return failure();
2047
2048 SmallVector<unsigned, 4> broadcastedDims;
2049 // Permutations are handled by VectorToSCF or
2050 // populateVectorTransferPermutationMapLoweringPatterns.
2051 // We let the 0-d corner case pass-through as it is supported.
2052 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
2053 &broadcastedDims))
2054 return failure();
2055
2056 auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
2057 if (!memRefType)
2058 return failure();
2059
2060 // Non-unit strides are handled by VectorToSCF.
2061 if (!vector::isLastMemrefDimUnitStride(memRefType))
2062 return failure();
2063
2064 // If there is broadcasting involved then we first load the unbroadcasted
2065 // vector, and then broadcast it with `vector.broadcast`.
2066 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
2067 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
2068 vectorShape.end());
2069 for (unsigned i : broadcastedDims)
2070 unbroadcastedVectorShape[i] = 1;
2071 VectorType unbroadcastedVectorType = VectorType::get(
2072 unbroadcastedVectorShape, read.getVectorType().getElementType());
2073
2074 // `vector.load` supports vector types as memref's elements only when the
2075 // resulting vector type is the same as the element type.
2076 auto memrefElTy = memRefType.getElementType();
2077 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
2078 return failure();
2079
2080 // Otherwise, element types of the memref and the vector must match.
2081 if (!memrefElTy.isa<VectorType>() &&
2082 memrefElTy != read.getVectorType().getElementType())
2083 return failure();
2084
2085 // Out-of-bounds dims are handled by MaterializeTransferMask.
2086 if (read.hasOutOfBoundsDim())
2087 return failure();
2088
2089 // Create vector load op.
2090 Operation *loadOp;
2091 if (read.getMask()) {
2092 Value fill = rewriter.create<vector::SplatOp>(
2093 read.getLoc(), unbroadcastedVectorType, read.getPadding());
2094 loadOp = rewriter.create<vector::MaskedLoadOp>(
2095 read.getLoc(), unbroadcastedVectorType, read.getSource(),
2096 read.getIndices(), read.getMask(), fill);
2097 } else {
2098 loadOp = rewriter.create<vector::LoadOp>(
2099 read.getLoc(), unbroadcastedVectorType, read.getSource(),
2100 read.getIndices());
2101 }
2102
2103 // Insert a broadcasting op if required.
2104 if (!broadcastedDims.empty()) {
2105 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2106 read, read.getVectorType(), loadOp->getResult(0));
2107 } else {
2108 rewriter.replaceOp(read, loadOp->getResult(0));
2109 }
2110
2111 return success();
2112 }
2113
2114 llvm::Optional<unsigned> maxTransferRank;
2115 };
2116
2117 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
2118 // TODO: we shouldn't cross the vector/scalar domains just for this
2119 // but atm we lack the infra to avoid it. Possible solutions include:
2120 // - go directly to LLVM + bitcast
2121 // - introduce a bitcast op and likely a new pointer dialect
2122 // - let memref.load/store additionally support the 0-d vector case
2123 // There are still deeper data layout issues lingering even in this
2124 // trivial case (for architectures for which this matters).
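//
// Illustrative sketch (assuming a single-element vector result):
//   %v = vector.load %m[%c0] : memref<4xf32>, vector<1xf32>
// becomes:
//   %s = memref.load %m[%c0] : memref<4xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>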
2125 struct VectorLoadToMemrefLoadLowering
2126 : public OpRewritePattern<vector::LoadOp> {
2127 using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
2128
2129 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
2130 PatternRewriter &rewriter) const override {
2131 auto vecType = loadOp.getVectorType();
2132 if (vecType.getNumElements() != 1)
2133 return failure();
2134 auto memrefLoad = rewriter.create<memref::LoadOp>(
2135 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
2136 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
2137 memrefLoad);
2138 return success();
2139 }
2140 };
2141
2142 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
2143 struct VectorStoreToMemrefStoreLowering
2144 : public OpRewritePattern<vector::StoreOp> {
2145 using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
2146
2147 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
2148 PatternRewriter &rewriter) const override {
2149 auto vecType = storeOp.getVectorType();
2150 if (vecType.getNumElements() != 1)
2151 return failure();
2152 Value extracted;
2153 if (vecType.getRank() == 0) {
2154 // TODO: Unify once ExtractOp supports 0-d vectors.
2155 extracted = rewriter.create<vector::ExtractElementOp>(
2156 storeOp.getLoc(), storeOp.getValueToStore());
2157 } else {
2158 SmallVector<int64_t> indices(vecType.getRank(), 0);
2159 extracted = rewriter.create<vector::ExtractOp>(
2160 storeOp.getLoc(), storeOp.getValueToStore(), indices);
2161 }
2162
2163 rewriter.replaceOpWithNewOp<memref::StoreOp>(
2164 storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
2165 return success();
2166 }
2167 };
2168
2169 /// Progressive lowering of transfer_write. This pattern supports lowering of
2170 /// `vector.transfer_write` to `vector.store` if all of the following hold:
2171 /// - Stride of most minor memref dimension must be 1.
2172 /// - Out-of-bounds masking is not required.
2173 /// - If the memref's element type is a vector type then it coincides with the
2174 /// type of the written value.
2175 /// - The permutation map is the minor identity map (neither permutation nor
2176 /// broadcasting is allowed).
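///
/// For example (an illustrative sketch), a 1-D in-bounds, unmasked write such
/// as:
/// ```
///    vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///        : vector<4xf32>, memref<8xf32>
/// ```
/// would become:
/// ```
///    vector.store %v, %mem[%i] : memref<8xf32>, vector<4xf32>
/// ```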
2177 struct TransferWriteToVectorStoreLowering
2178 : public OpRewritePattern<vector::TransferWriteOp> {
2179 TransferWriteToVectorStoreLowering(MLIRContext *context,
2180 llvm::Optional<unsigned> maxRank)
2181 : OpRewritePattern<vector::TransferWriteOp>(context),
2182 maxTransferRank(maxRank) {}
2183
2184 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2185 PatternRewriter &rewriter) const override {
2186 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
2187 return failure();
2188
2189 // Permutations are handled by VectorToSCF or
2190 // populateVectorTransferPermutationMapLoweringPatterns.
2191 if ( // pass-through for the 0-d corner case.
2192 !write.getPermutationMap().isMinorIdentity())
2193 return failure();
2194
2195 auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
2196 if (!memRefType)
2197 return failure();
2198
2199 // Non-unit strides are handled by VectorToSCF.
2200 if (!vector::isLastMemrefDimUnitStride(memRefType))
2201 return failure();
2202
2203 // `vector.store` supports vector types as memref's elements only when the
2204 // type of the vector value being written is the same as the element type.
2205 auto memrefElTy = memRefType.getElementType();
2206 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
2207 return failure();
2208
2209 // Otherwise, element types of the memref and the vector must match.
2210 if (!memrefElTy.isa<VectorType>() &&
2211 memrefElTy != write.getVectorType().getElementType())
2212 return failure();
2213
2214 // Out-of-bounds dims are handled by MaterializeTransferMask.
2215 if (write.hasOutOfBoundsDim())
2216 return failure();
2217 if (write.getMask()) {
2218 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
2219 write, write.getSource(), write.getIndices(), write.getMask(),
2220 write.getVector());
2221 } else {
2222 rewriter.replaceOpWithNewOp<vector::StoreOp>(
2223 write, write.getVector(), write.getSource(), write.getIndices());
2224 }
2225 return success();
2226 }
2227
2228 llvm::Optional<unsigned> maxTransferRank;
2229 };
2230
2231 // Returns the values in `arrayAttr` as an integer vector.
2232 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
2233 return llvm::to_vector<4>(
2234 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
2235 [](IntegerAttr attr) { return attr.getInt(); }));
2236 }
2237
2238 // Shuffles vector.bitcast op after vector.extract op.
2239 //
2240 // This transforms IR like:
2241 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
2242 // %1 = vector.extract %0[3] : vector<8xf16>
2243 // Into:
2244 // %0 = vector.extract %src[1] : vector<4xf32>
2245 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
2246 // %2 = vector.extract %1[1] : vector<2xf16>
2247 struct BubbleDownVectorBitCastForExtract
2248 : public OpRewritePattern<vector::ExtractOp> {
2249 using OpRewritePattern::OpRewritePattern;
2250
2251 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2252 PatternRewriter &rewriter) const override {
2253 // Only support extracting scalars for now.
2254 if (extractOp.getVectorType().getRank() != 1)
2255 return failure();
2256
2257 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2258 if (!castOp)
2259 return failure();
2260
2261 VectorType castSrcType = castOp.getSourceVectorType();
2262 VectorType castDstType = castOp.getResultVectorType();
2263 assert(castSrcType.getRank() == castDstType.getRank());
2264
2265 // Fail to match if we only have one element in the cast op source.
2266 // This is to avoid an infinite loop, given that this pattern can generate
2267 // such cases.
2268 if (castSrcType.getNumElements() == 1)
2269 return failure();
2270
2271 // Only support casting to a larger number of elements for now.
2272 // E.g., vector<4xf32> -> vector<8xf16>.
2273 if (castSrcType.getNumElements() > castDstType.getNumElements())
2274 return failure();
2275
2276 unsigned expandRatio =
2277 castDstType.getNumElements() / castSrcType.getNumElements();
2278
2279 auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
2280 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
2281 };
2282
2283 uint64_t index = getFirstIntValue(extractOp.getPosition());
2284
2285 // Get the single scalar (as a vector) in the source value that packs the
2286 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
2287 VectorType oneScalarType =
2288 VectorType::get({1}, castSrcType.getElementType());
2289 Value packedValue = rewriter.create<vector::ExtractOp>(
2290 extractOp.getLoc(), oneScalarType, castOp.getSource(),
2291 rewriter.getI64ArrayAttr(index / expandRatio));
2292
2293 // Cast it to a vector with the desired scalar's type.
2294 // E.g. f32 -> vector<2xf16>
2295 VectorType packedType =
2296 VectorType::get({expandRatio}, castDstType.getElementType());
2297 Value castedValue = rewriter.create<vector::BitCastOp>(
2298 extractOp.getLoc(), packedType, packedValue);
2299
2300 // Finally extract the desired scalar.
2301 rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2302 extractOp, extractOp.getType(), castedValue,
2303 rewriter.getI64ArrayAttr(index % expandRatio));
2304
2305 return success();
2306 }
2307 };
2308
2309 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2310 //
2311 // This transforms IR like:
2312 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2313 // %0 = vector.extract_strided_slice %cast {
2314 // offsets = [4], sizes = [4], strides = [1]
2315 // } : vector<8xf16> to vector<4xf16>
2316 // Into:
2317 // %0 = vector.extract_strided_slice %src {
2318 // offsets = [2], sizes = [2], strides = [1]
2319 // } : vector<4xf32> to vector<2xf32>
2320 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2321 struct BubbleDownBitCastForStridedSliceExtract
2322 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2323 using OpRewritePattern::OpRewritePattern;
2324
2325 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2326 PatternRewriter &rewriter) const override {
2327 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2328 if (!castOp)
2329 return failure();
2330
2331 VectorType castSrcType = castOp.getSourceVectorType();
2332 VectorType castDstType = castOp.getResultVectorType();
2333 assert(castSrcType.getRank() == castDstType.getRank());
2334
2335 int64_t castSrcLastDim = castSrcType.getShape().back();
2336 int64_t castDstLastDim = castDstType.getShape().back();
2337 // Require casting to more elements for now; other cases to be implemented.
2338 if (castSrcLastDim > castDstLastDim)
2339 return failure();
2340
2341 // Only accept all-one strides for now.
2342 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2343 [](const APInt &val) { return !val.isOneValue(); }))
2344 return failure();
2345
2346 unsigned rank = extractOp.getVectorType().getRank();
2347 assert(castDstLastDim % castSrcLastDim == 0);
2348 int64_t expandRatio = castDstLastDim / castSrcLastDim;
2349
2350 // If we have fewer offsets than the rank, then implicitly we are selecting
2351 // the full range for the last bitcasted dimension; other dimensions aren't
2352 // affected. Otherwise, we need to scale down the last dimension's offset
2353 // given we are extracting from fewer elements now.
2354 ArrayAttr newOffsets = extractOp.getOffsets();
2355 if (newOffsets.size() == rank) {
2356 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2357 if (offsets.back() % expandRatio != 0)
2358 return failure();
2359 offsets.back() = offsets.back() / expandRatio;
2360 newOffsets = rewriter.getI64ArrayAttr(offsets);
2361 }
2362
2363 // Similarly for sizes.
2364 ArrayAttr newSizes = extractOp.getSizes();
2365 if (newSizes.size() == rank) {
2366 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2367 if (sizes.back() % expandRatio != 0)
2368 return failure();
2369 sizes.back() = sizes.back() / expandRatio;
2370 newSizes = rewriter.getI64ArrayAttr(sizes);
2371 }
2372
2373 SmallVector<int64_t, 4> dims =
2374 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2375 dims.back() = dims.back() / expandRatio;
2376 VectorType newExtractType =
2377 VectorType::get(dims, castSrcType.getElementType());
2378
2379 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2380 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2381 newSizes, extractOp.getStrides());
2382
2383 rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2384 extractOp, extractOp.getType(), newExtractOp);
2385
2386 return success();
2387 }
2388 };
2389
2390 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2391 //
2392 // This transforms IR like:
2393 // %0 = vector.insert_strided_slice %src, %dst {
2394 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2395 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2396 // Into:
2397 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2398 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2399 // %2 = vector.insert_strided_slice %src, %dst {
2400 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2401 struct BubbleUpBitCastForStridedSliceInsert
2402 : public OpRewritePattern<vector::BitCastOp> {
2403 using OpRewritePattern::OpRewritePattern;
2404 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2405 PatternRewriter &rewriter) const override {
2406 VectorType castSrcType = bitcastOp.getSourceVectorType();
2407 VectorType castDstType = bitcastOp.getResultVectorType();
2408 assert(castSrcType.getRank() == castDstType.getRank());
2409
2410 int64_t castSrcLastDim = castSrcType.getShape().back();
2411 int64_t castDstLastDim = castDstType.getShape().back();
2412 // Require casting to fewer elements for now; other cases to be implemented.
2413 if (castSrcLastDim < castDstLastDim)
2414 return failure();
2415
2416 assert(castSrcLastDim % castDstLastDim == 0);
2417 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2418
2419 auto insertOp =
2420 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2421 if (!insertOp)
2422 return failure();
2423
2424 // Only accept all-one strides for now.
2425 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2426 [](const APInt &val) { return !val.isOneValue(); }))
2427 return failure();
2428
2429 unsigned rank = insertOp.getSourceVectorType().getRank();
2430 // Require insert op to have the same rank for the source and destination
2431 // vector; other cases to be implemented.
2432 if (rank != insertOp.getDestVectorType().getRank())
2433 return failure();
2434
2435 ArrayAttr newOffsets = insertOp.getOffsets();
2436 assert(newOffsets.size() == rank);
2437 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2438 if (offsets.back() % shrinkRatio != 0)
2439 return failure();
2440 offsets.back() = offsets.back() / shrinkRatio;
2441 newOffsets = rewriter.getI64ArrayAttr(offsets);
2442
2443 SmallVector<int64_t, 4> srcDims =
2444 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2445 srcDims.back() = srcDims.back() / shrinkRatio;
2446 VectorType newCastSrcType =
2447 VectorType::get(srcDims, castDstType.getElementType());
2448
2449 auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2450 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2451
2452 SmallVector<int64_t, 4> dstDims =
2453 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2454 dstDims.back() = dstDims.back() / shrinkRatio;
2455 VectorType newCastDstType =
2456 VectorType::get(dstDims, castDstType.getElementType());
2457
2458 auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2459 bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2460
2461 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2462 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2463 insertOp.getStrides());
2464
2465 return success();
2466 }
2467 };
2468
2469 // Helper that returns a vector comparison that constructs a mask:
2470 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2471 //
2472 // If `dim == 0` then the result will be a 0-D vector.
2473 //
2474 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2475 // much more compact, IR for this operation, but LLVM eventually
2476 // generates more elaborate instructions for this intrinsic since it
2477 // is very conservative on the boundary conditions.
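//
// For example (a purely numeric illustration): with dim = 4, offset o = 2 and
// bound b = 5, this computes [0,1,2,3] + [2,2,2,2] = [2,3,4,5] < [5,5,5,5],
// i.e. the mask [1,1,1,0].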
2478 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2479 bool force32BitVectorIndices, int64_t dim,
2480 Value b, Value *off = nullptr) {
2481 auto loc = op->getLoc();
2482 // If we can assume all indices fit in 32-bit, we perform the vector
2483 // comparison in 32-bit to get a higher degree of SIMD parallelism.
2484 // Otherwise we perform the vector comparison using 64-bit indices.
2485 Type idxType =
2486 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2487 DenseIntElementsAttr indicesAttr;
2488 if (dim == 0 && force32BitVectorIndices) {
2489 indicesAttr = DenseIntElementsAttr::get(
2490 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2491 } else if (dim == 0) {
2492 indicesAttr = DenseIntElementsAttr::get(
2493 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2494 } else if (force32BitVectorIndices) {
2495 indicesAttr = rewriter.getI32VectorAttr(
2496 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2497 } else {
2498 indicesAttr = rewriter.getI64VectorAttr(
2499 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2500 }
2501 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2502 // Add in an offset if requested.
2503 if (off) {
2504 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2505 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2506 indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2507 }
2508 // Construct the vector comparison.
2509 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2510 Value bounds =
2511 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2512 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2513 bounds);
2514 }
2515
2516 template <typename ConcreteOp>
2517 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2518 public:
2519 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2520 : mlir::OpRewritePattern<ConcreteOp>(context),
2521 force32BitVectorIndices(enableIndexOpt) {}
2522
2523 LogicalResult matchAndRewrite(ConcreteOp xferOp,
2524 PatternRewriter &rewriter) const override {
2525 if (!xferOp.hasOutOfBoundsDim())
2526 return failure();
2527
2528 if (xferOp.getVectorType().getRank() > 1 ||
2529 llvm::size(xferOp.getIndices()) == 0)
2530 return failure();
2531
2532 Location loc = xferOp->getLoc();
2533 VectorType vtp = xferOp.getVectorType();
2534
2535 // Create the in-bounds mask with all elements between [0 .. dim - offset)
2536 // set and [dim - offset .. vector_length) unset.
2537 //
2538 // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2539 // dimensions here.
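//
// For example (an illustrative numeric sketch): for a vector<8xi1> mask, an
// offset of 5 into a dimension of size 11 yields create_mask(11 - 5 = 6),
// i.e. [1,1,1,1,1,1,0,0].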
2540 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2541 Value off = xferOp.getIndices()[lastIndex];
2542 Value dim =
2543 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2544 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2545 Value mask = rewriter.create<vector::CreateMaskOp>(
2546 loc,
2547 VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2548 vtp.getNumScalableDims()),
2549 b);
2550 if (xferOp.getMask()) {
2551 // Intersect the in-bounds mask with the mask specified as an op parameter.
2552 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2553 }
2554
2555 rewriter.updateRootInPlace(xferOp, [&]() {
2556 xferOp.getMaskMutable().assign(mask);
2557 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2558 });
2559
2560 return success();
2561 }
2562
2563 private:
2564 const bool force32BitVectorIndices;
2565 };
2566
2567 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2568 class VectorCreateMaskOpConversion
2569 : public OpRewritePattern<vector::CreateMaskOp> {
2570 public:
2571 explicit VectorCreateMaskOpConversion(MLIRContext *context,
2572 bool enableIndexOpt)
2573 : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2574 force32BitVectorIndices(enableIndexOpt) {}
2575
2576 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2577 PatternRewriter &rewriter) const override {
2578 auto dstType = op.getType();
2579 if (dstType.cast<VectorType>().isScalable())
2580 return failure();
2581 int64_t rank = dstType.getRank();
2582 if (rank > 1)
2583 return failure();
2584 rewriter.replaceOp(
2585 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2586 rank == 0 ? 0 : dstType.getDimSize(0),
2587 op.getOperand(0)));
2588 return success();
2589 }
2590
2591 private:
2592 const bool force32BitVectorIndices;
2593 };
2594
2595 // Drop the innermost contiguous unit dimensions from the transfer_read source.
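//
// For example (an illustrative sketch, assuming static shapes and an identity
// layout), a read of vector<4x1xf32> from memref<8x4x1xf32> becomes a
// rank-reduced subview, a lower-rank read, and a shape_cast:
//   %sv = memref.subview %src[0, 0, 0] [8, 4, 1] [1, 1, 1]
//       : memref<8x4x1xf32> to memref<8x4xf32>
//   %r  = vector.transfer_read %sv[%i, %j], %pad : memref<8x4xf32>, vector<4xf32>
//   %v  = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>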
2596 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2597 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2598
2599 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2600 PatternRewriter &rewriter) const override {
2601 // TODO: support 0-d corner case.
2602 if (readOp.getTransferRank() == 0)
2603 return failure();
2604
2605 // TODO: support mask.
2606 if (readOp.getMask())
2607 return failure();
2608
2609 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2610 if (!srcType || !srcType.hasStaticShape())
2611 return failure();
2612
2613 if (!readOp.getPermutationMap().isMinorIdentity())
2614 return failure();
2615
2616 auto targetType = readOp.getVectorType();
2617 if (targetType.getRank() <= 1)
2618 return failure();
2619
2620 SmallVector<int64_t> srcStrides;
2621 int64_t srcOffset;
2622 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2623 return failure();
2624
2625 size_t dimsToDrop = 0;
2626 for (size_t i = 1; i < srcStrides.size(); ++i) {
2627 int dim = srcType.getRank() - i - 1;
2628 if (srcStrides[dim] == 1) {
2629 dimsToDrop++;
2630 } else {
2631 break;
2632 }
2633 }
2634 if (dimsToDrop == 0)
2635 return failure();
2636
2637 auto resultTargetVecType =
2638 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2639 targetType.getElementType());
2640
2641 MemRefType resultMemrefType;
2642 if (srcType.getLayout().getAffineMap().isIdentity()) {
2643 resultMemrefType = MemRefType::get(
2644 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2645 {}, srcType.getMemorySpaceAsInt());
2646 } else {
2647 AffineMap map = srcType.getLayout().getAffineMap();
2648 int numSymbols = map.getNumSymbols();
2649 for (size_t i = 0; i < dimsToDrop; ++i) {
2650 int dim = srcType.getRank() - i - 1;
2651 map = map.replace(rewriter.getAffineDimExpr(dim),
2652 rewriter.getAffineConstantExpr(0),
2653 map.getNumDims() - 1, numSymbols);
2654 }
2655 resultMemrefType = MemRefType::get(
2656 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2657 map, srcType.getMemorySpaceAsInt());
2658 }
2659
2660 auto loc = readOp.getLoc();
2661 SmallVector<int64_t> offsets(srcType.getRank(), 0);
2662 SmallVector<int64_t> strides(srcType.getRank(), 1);
2663
2664 ArrayAttr inBoundsAttr =
2665 readOp.getInBounds()
2666 ? rewriter.getArrayAttr(
2667 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2668 : ArrayAttr();
2669 Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2670 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2671 strides);
2672 auto permMap = getTransferMinorIdentityMap(
2673 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2674 Value result = rewriter.create<vector::TransferReadOp>(
2675 loc, resultTargetVecType, rankedReducedView,
2676 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2677 readOp.getPadding(),
2678 // TODO: support mask.
2679 /*mask=*/Value(), inBoundsAttr);
2680 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2681 result);
2682 return success();
2683 }
2684 };
2685
2686 namespace {
2687
2688 /// This function checks to see if the vector combining kind
2689 /// is consistent with the integer or float element type.
2690 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2691 using vector::CombiningKind;
2692 enum class KindType { FLOAT, INT, INVALID };
2693 KindType type{KindType::INVALID};
2694 switch (kind) {
2695 case CombiningKind::MINF:
2696 case CombiningKind::MAXF:
2697 type = KindType::FLOAT;
2698 break;
2699 case CombiningKind::MINUI:
2700 case CombiningKind::MINSI:
2701 case CombiningKind::MAXUI:
2702 case CombiningKind::MAXSI:
2703 case CombiningKind::AND:
2704 case CombiningKind::OR:
2705 case CombiningKind::XOR:
2706 type = KindType::INT;
2707 break;
2708 case CombiningKind::ADD:
2709 case CombiningKind::MUL:
2710 type = isInt ? KindType::INT : KindType::FLOAT;
2711 break;
2712 }
2713 bool isValidIntKind = (type == KindType::INT) && isInt;
2714 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2715 return (isValidIntKind || isValidFloatKind);
2716 }
2717
2718 /// This function constructs the appropriate integer or float
2719 /// operation given the vector combining kind and operands. The
2720 /// supported int operations are : add, mul, min (signed/unsigned),
2721 /// max(signed/unsigned), and, or, xor. The supported float
2722 /// operations are : add, mul, min and max.
2723 static Value genOperator(Location loc, Value x, Value y,
2724 vector::CombiningKind kind,
2725 PatternRewriter &rewriter) {
2726 using vector::CombiningKind;
2727
2728 auto elType = x.getType().cast<VectorType>().getElementType();
2729 bool isInt = elType.isIntOrIndex();
2730
2731 Value combinedResult{nullptr};
2732 switch (kind) {
2733 case CombiningKind::ADD:
2734 if (isInt)
2735 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2736 else
2737 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2738 break;
2739 case CombiningKind::MUL:
2740 if (isInt)
2741 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2742 else
2743 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2744 break;
2745 case CombiningKind::MINUI:
2746 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2747 break;
2748 case CombiningKind::MINSI:
2749 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2750 break;
2751 case CombiningKind::MAXUI:
2752 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2753 break;
2754 case CombiningKind::MAXSI:
2755 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2756 break;
2757 case CombiningKind::AND:
2758 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2759 break;
2760 case CombiningKind::OR:
2761 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2762 break;
2763 case CombiningKind::XOR:
2764 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2765 break;
2766 case CombiningKind::MINF:
2767 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2768 break;
2769 case CombiningKind::MAXF:
2770 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2771 break;
2772 }
2773 return combinedResult;
2774 }
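
// For example, with f32 vector operands and kind = CombiningKind::MAXF this
// emits `arith.maxf %x, %y`; with i32 operands and kind = CombiningKind::ADD
// it emits `arith.addi %x, %y`. Mismatched kind/element-type combinations are
// expected to be rejected beforehand via isValidKind (as ScanToArithOps does
// below).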

/// Convert vector.scan op into arith ops and
/// vector.insert_strided_slice/extract_strided_slice.
///
/// Ex:
/// ```
///   %0:2 = vector.scan <mul>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.muli %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.muli %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices);
}

void mlir::vector::populateShapeCastFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext());
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}

void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BroadcastOpLowering>(patterns.getContext());
}

void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext());
}

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext());
}

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<OuterProductOpLowering>(patterns.getContext());
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(options,
                                                      patterns.getContext());
}

void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext());
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext());
}

void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ScanToArithOps>(patterns.getContext());
}
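
// Usage sketch (illustrative only; the pass boilerplate below is a
// hypothetical example, not part of this file): a pass typically collects the
// rewrites registered by one or more of the populate functions above into a
// RewritePatternSet and applies them with the greedy driver.
//
//   void SomeVectorLoweringPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorScanLoweringPatterns(patterns);
//     vector::populateVectorShapeCastLoweringPatterns(patterns);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }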