//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape`, return the element
/// offsets for the slice `index`.
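/// For example (illustrative shapes), unrolling from vector<4x4xf32> to
/// vector<2x2xf32> yields a 2x2 grid of slices, and slice index 3 maps to
/// element offsets [2, 2].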
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}

/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
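/// For example (illustrative shapes), decomposing vector<4x8xf32> into
/// vector<2x4xf32> slices with loopOrder = [1, 0] makes dimension 0 the inner
/// loop, so linear indices 0..3 map to vector offsets (0,0), (1,0), (0,1),
/// (1,1), i.e. element offsets (0,0), (2,0), (0,4), (2,4).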
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() == targetShape.size());
    assert(loopOrder.size() == targetShape.size());

    // Compute the number of slices along each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the element offsets for linear slice `index`, following the
  /// traversal order given by `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
} // namespace

/// Compute the indices of the slice with the given `elementOffsets` for a
/// transfer op.
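/// Dimensions that the permutation map sends to a constant 0 (broadcast
/// dimensions) keep their original index; every other index gets the matching
/// element offset added through an affine.apply.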
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native shape "
         "callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

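/// Return the dimension traversal order for unrolling `op`: identity order by
/// default, or the order supplied by the options' traversalOrderCallback.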
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

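/// Unrolls a vector.transfer_read into transfer_reads of the target shape and
/// reassembles the full vector with vector.insert_strided_slice ops; masked
/// reads and the 0-d corner case are left untouched. Roughly, for an
/// illustrative native shape of 2x2, a read producing vector<2x4xf32> becomes:
/// ```
/// %0 = vector.transfer_read %A[%c0, %c0], %cf0 :
///   memref<2x4xf32>, vector<2x2xf32>
/// %1 = vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
///   : vector<2x2xf32> into vector<2x4xf32>
/// %i = affine.apply affine_map<(d0) -> (d0 + 2)>(%c0)
/// %2 = vector.transfer_read %A[%c0, %i], %cf0 :
///   memref<2x4xf32>, vector<2x2xf32>
/// %3 = vector.insert_strided_slice %2, %1 {offsets = [0, 2], strides = [1, 1]}
///   : vector<2x2xf32> into vector<2x4xf32>
/// ```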
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

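/// Unrolls a vector.transfer_write into transfer_writes of the target shape,
/// extracting each slice with vector.extract_strided_slice. For transfers on
/// tensors, each new write takes the previous write's result as its
/// destination so the unrolled writes stay chained in SSA form.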
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

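/// DenseMapInfo implementation for the SmallVector<int64_t> offset keys used
/// by the accumulator caches below.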
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

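/// Unrolls a vector.contract into contractions of the target shape. Partial
/// accumulator slices are cached by destination offsets so that the reduction
/// dimension can keep updating them; they are reassembled into a full vector
/// only after all slices have been emitted.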
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);
      // If there is a mask associated with the lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);
      // If there is a mask associated with the rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

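/// Unrolls a vector.multi_reduction into reductions of the target shape.
/// Slices reducing into the same destination offsets are chained through the
/// accumulator cache and reassembled at the end.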
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // The number of slices is the product of the shape ratio of 'originalSize'
    // and 'targetShape'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<Value> operands;
      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

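/// Unrolls any elementwise-mappable operation with a single vector result,
/// applying the operation slice by slice and stitching the partial results
/// back together.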
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts a TransferRead op used by an ExtractMap op into a TransferRead of
/// the smaller, distributed dimension.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0 :
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};

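/// Converts a TransferWrite op whose vector comes from an InsertMap op into a
/// TransferWrite of the smaller distributed vector, offsetting the write
/// indices with the distribution ids, mirroring TransferReadExtractPattern.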
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

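/// Unrolls a vector.reduction into reductions over slices of the target
/// length, combining the partial results with the reduction's arithmetic
/// operation.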
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

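/// Unrolls a vector.transpose into transposes of the target shape; the source
/// offsets and shape of each slice are the destination offsets and shape
/// permuted by the transposition.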
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (const auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}