1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Interfaces/VectorInterfaces.h"
18 #include "mlir/Support/MathExtras.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include <numeric>
22 
23 #define DEBUG_TYPE "vector-unrolling"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 
28 /// During unrolling from `originalShape` to `targetShape` return the offset for
29 /// the slice `index`.
getVectorOffset(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,int64_t index)30 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
31                                                ArrayRef<int64_t> targetShape,
32                                                int64_t index) {
33   SmallVector<int64_t, 4> dstSliceStrides =
34       computeStrides(originalShape, targetShape);
35   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
36   SmallVector<int64_t, 4> elementOffsets =
37       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
38   return elementOffsets;
39 }
40 
41 /// A functor that accomplishes the same thing as `getVectorOffset` but allows
42 /// for reordering the traversal of the dimensions. The order of traversal is
43 /// given in "for loop order" (outer to inner).
44 namespace {
45 class DecomposeShapeIterator {
46 private:
47   SmallVector<int64_t, 4> vectorShape;
48   SmallVector<int64_t> loopOrder;
49   SmallVector<int64_t> sliceStrides;
50   int64_t maxIndexVal{1};
51 
52 public:
DecomposeShapeIterator(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,ArrayRef<int64_t> loopOrder)53   DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
54                          ArrayRef<int64_t> targetShape,
55                          ArrayRef<int64_t> loopOrder)
56       : vectorShape(targetShape.begin(), targetShape.end()),
57         loopOrder(loopOrder.begin(), loopOrder.end()),
58         sliceStrides(originalShape.size()) {
59     assert(originalShape.size() == targetShape.size());
60     assert(loopOrder.size() == targetShape.size());
61 
62     // Compute the count for each dimension.
63     SmallVector<int64_t> sliceDimCounts(originalShape.size());
64     for (unsigned r = 0; r < originalShape.size(); ++r) {
65       sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
66       maxIndexVal *= sliceDimCounts[r];
67     }
68 
69     // Reversing "loop order" gives dimensions from fastest varying to slowest
70     // varying (smallest stride to largest stride).
71     int64_t accum = 1;
72     for (auto idx : llvm::reverse(loopOrder)) {
73       sliceStrides[idx] = accum;
74       accum *= sliceDimCounts[idx];
75     }
76   }
77 
78   // Turn the linear index into a d-tuple based on units of vectors of size
79   // `vectorShape`. The linear index is assumed to represent traversal of the
80   // dimensions based on `order`.
delinearize(int64_t index) const81   SmallVector<int64_t> delinearize(int64_t index) const {
82     // Traverse in for loop order (largest stride to smallest stride).
83     SmallVector<int64_t> vectorOffsets(sliceStrides.size());
84     for (auto idx : loopOrder) {
85       vectorOffsets[idx] = index / sliceStrides[idx];
86       index %= sliceStrides[idx];
87     }
88     return vectorOffsets;
89   }
90 
maxIndex() const91   int64_t maxIndex() const { return maxIndexVal; }
92 
93   /// Return the offset within d-tuple based on the ordering given by
94   /// `loopOrder`.
getVectorOffset(int64_t index) const95   SmallVector<int64_t> getVectorOffset(int64_t index) const {
96     SmallVector<int64_t> vectorOffsets = delinearize(index);
97     SmallVector<int64_t> elementOffsets =
98         computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
99     return elementOffsets;
100   }
101 };
102 } // namespace
103 
104 /// Compute the indices of the slice `index` for a tranfer op.
sliceTransferIndices(ArrayRef<int64_t> elementOffsets,ArrayRef<Value> indices,AffineMap permutationMap,Location loc,OpBuilder & builder)105 static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
106                                                ArrayRef<Value> indices,
107                                                AffineMap permutationMap,
108                                                Location loc,
109                                                OpBuilder &builder) {
110   MLIRContext *ctx = builder.getContext();
111   auto isBroadcast = [](AffineExpr expr) {
112     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
113       return constExpr.getValue() == 0;
114     return false;
115   };
116   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
117   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
118   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
119     if (isBroadcast(dim.value()))
120       continue;
121     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
122     auto expr = getAffineDimExpr(0, builder.getContext()) +
123                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
124     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
125     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
126   }
127   return slicedIndices;
128 }
129 
130 // Clones `op` into a new operations that takes `operands` and returns
131 // `resultTypes`.
cloneOpWithOperandsAndTypes(OpBuilder & builder,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)132 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
133                                               Operation *op,
134                                               ArrayRef<Value> operands,
135                                               ArrayRef<Type> resultTypes) {
136   return builder.create(loc, op->getName().getIdentifier(), operands,
137                         resultTypes, op->getAttrs());
138 }
139 
140 /// Return the target shape for unrolling for the given `op`. Return llvm::None
141 /// if the op shouldn't be or cannot be unrolled.
142 static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions & options,Operation * op)143 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
144   if (options.filterConstraint && failed(options.filterConstraint(op)))
145     return llvm::None;
146   assert(options.nativeShape &&
147          "vector unrolling expects the native shape or native"
148          "shape call back function to be set");
149   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
150   if (!unrollableVectorOp)
151     return llvm::None;
152   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
153   if (!maybeUnrollShape)
154     return llvm::None;
155   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
156   if (!targetShape)
157     return llvm::None;
158   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
159   if (!maybeShapeRatio ||
160       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
161     return llvm::None;
162   return targetShape;
163 }
164 
165 static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops,Operation * op,const vector::UnrollVectorOptions & options)166 getUnrollOrder(unsigned numLoops, Operation *op,
167                const vector::UnrollVectorOptions &options) {
168   SmallVector<int64_t> loopOrder =
169       llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
170   if (options.traversalOrderCallback != nullptr) {
171     Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
172     if (order) {
173       loopOrder = std::move(*order);
174     }
175   }
176   return loopOrder;
177 }
178 
179 namespace {
180 
181 struct UnrollTransferReadPattern
182     : public OpRewritePattern<vector::TransferReadOp> {
UnrollTransferReadPattern__anon77a0e1ec0411::UnrollTransferReadPattern183   UnrollTransferReadPattern(MLIRContext *context,
184                             const vector::UnrollVectorOptions &options)
185       : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
186         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferReadPattern187   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
188                                 PatternRewriter &rewriter) const override {
189     // TODO: support 0-d corner case.
190     if (readOp.getTransferRank() == 0)
191       return failure();
192     if (readOp.getMask())
193       return failure();
194     auto targetShape = getTargetShape(options, readOp);
195     if (!targetShape)
196       return failure();
197     auto sourceVectorType = readOp.getVectorType();
198     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
199     Location loc = readOp.getLoc();
200     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
201 
202     // Prepare the result vector;
203     Value result = rewriter.create<arith::ConstantOp>(
204         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
205     auto targetType =
206         VectorType::get(*targetShape, sourceVectorType.getElementType());
207     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
208                                           readOp.getIndices().end());
209 
210     SmallVector<int64_t> loopOrder =
211         getUnrollOrder(originalSize.size(), readOp, options);
212     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
213                                           loopOrder);
214     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
215       SmallVector<int64_t, 4> elementOffsets =
216           indexToOffsets.getVectorOffset(i);
217       SmallVector<Value, 4> indices =
218           sliceTransferIndices(elementOffsets, originalIndices,
219                                readOp.getPermutationMap(), loc, rewriter);
220       auto slicedRead = rewriter.create<vector::TransferReadOp>(
221           loc, targetType, readOp.getSource(), indices,
222           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
223           readOp.getInBoundsAttr());
224 
225       result = rewriter.create<vector::InsertStridedSliceOp>(
226           loc, slicedRead, result, elementOffsets, strides);
227     }
228     rewriter.replaceOp(readOp, result);
229     return success();
230   }
231 
232 private:
233   vector::UnrollVectorOptions options;
234 };
235 
236 struct UnrollTransferWritePattern
237     : public OpRewritePattern<vector::TransferWriteOp> {
UnrollTransferWritePattern__anon77a0e1ec0411::UnrollTransferWritePattern238   UnrollTransferWritePattern(MLIRContext *context,
239                              const vector::UnrollVectorOptions &options)
240       : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
241         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferWritePattern242   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
243                                 PatternRewriter &rewriter) const override {
244     // TODO: support 0-d corner case.
245     if (writeOp.getTransferRank() == 0)
246       return failure();
247 
248     if (writeOp.getMask())
249       return failure();
250     auto targetShape = getTargetShape(options, writeOp);
251     if (!targetShape)
252       return failure();
253     auto sourceVectorType = writeOp.getVectorType();
254     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
255     Location loc = writeOp.getLoc();
256     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
257     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
258                                           writeOp.getIndices().end());
259 
260     SmallVector<int64_t> loopOrder =
261         getUnrollOrder(originalSize.size(), writeOp, options);
262     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
263                                           loopOrder);
264     Value resultTensor;
265     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
266       SmallVector<int64_t, 4> elementOffsets =
267           indexToOffsets.getVectorOffset(i);
268       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
269           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
270       SmallVector<Value, 4> indices =
271           sliceTransferIndices(elementOffsets, originalIndices,
272                                writeOp.getPermutationMap(), loc, rewriter);
273       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
274           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
275           indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
276       // For the tensor case update the destination for the next transfer write.
277       if (!slicedWrite->getResults().empty())
278         resultTensor = slicedWrite->getResult(0);
279     }
280     if (resultTensor)
281       rewriter.replaceOp(writeOp, resultTensor);
282     else
283       rewriter.eraseOp(writeOp);
284     return success();
285   }
286 
287 private:
288   vector::UnrollVectorOptions options;
289 };
290 
291 struct OffsetMapInfo {
getEmptyKey__anon77a0e1ec0411::OffsetMapInfo292   static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
293 
getTombstoneKey__anon77a0e1ec0411::OffsetMapInfo294   static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
295 
getHashValue__anon77a0e1ec0411::OffsetMapInfo296   static unsigned getHashValue(const SmallVector<int64_t> &v) {
297     return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
298   }
299 
isEqual__anon77a0e1ec0411::OffsetMapInfo300   static bool isEqual(const SmallVector<int64_t> &lhs,
301                       const SmallVector<int64_t> &rhs) {
302     return lhs == rhs;
303   }
304 };
305 
306 struct UnrollContractionPattern
307     : public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern__anon77a0e1ec0411::UnrollContractionPattern308   UnrollContractionPattern(MLIRContext *context,
309                            const vector::UnrollVectorOptions &options)
310       : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
311         options(options) {}
312 
matchAndRewrite__anon77a0e1ec0411::UnrollContractionPattern313   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
314                                 PatternRewriter &rewriter) const override {
315     auto targetShape = getTargetShape(options, contractOp);
316     if (!targetShape)
317       return failure();
318     auto dstVecType = contractOp.getResultType().cast<VectorType>();
319     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
320 
321     Location loc = contractOp.getLoc();
322     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323     AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
324     llvm::MapVector<
325         SmallVector<int64_t>, Value,
326         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
327         accCache;
328 
329     SmallVector<int64_t> loopOrder = getUnrollOrder(
330         contractOp.getIteratorTypes().size(), contractOp, options);
331     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
332                                           loopOrder);
333     const int64_t sliceCount = indexToOffsets.maxIndex();
334     for (int64_t i = 0; i < sliceCount; i++) {
335       SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
336       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
337 
338       // Helper to coompute the new shape of each operand and extract the slice.
339       auto extractOperand = [&](unsigned index, Value operand,
340                                 AffineMap permutationMap,
341                                 ArrayRef<int64_t> operandOffets) {
342         SmallVector<int64_t> operandShape = applyPermutationMap(
343             permutationMap, ArrayRef<int64_t>(*targetShape));
344         SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
345         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
346             loc, operand, operandOffets, operandShape, operandStrides);
347       };
348 
349       // Extract the new lhs operand.
350       AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
351       SmallVector<int64_t> lhsOffets =
352           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
353       extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
354       // If there is a mask associated to lhs, extract it as well.
355       if (slicesOperands.size() > 3)
356         extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
357                        lhsOffets);
358 
359       // Extract the new rhs operand.
360       AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
361       SmallVector<int64_t> rhsOffets =
362           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
363       extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
364       // If there is a mask associated to rhs, extract it as well.
365       if (slicesOperands.size() > 4)
366         extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
367                        rhsOffets);
368 
369       AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
370       SmallVector<int64_t> accOffets =
371           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
372       // If a version of the accumulator has already been computed, use it
373       // otherwise extract the first version from the original operand.
374       auto accIt = accCache.find(accOffets);
375       if (accIt != accCache.end())
376         slicesOperands[2] = accIt->second;
377       else
378         extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
379 
380       SmallVector<int64_t> dstShape =
381           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
382       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
383       Operation *newOp = cloneOpWithOperandsAndTypes(
384           rewriter, loc, contractOp, slicesOperands, targetType);
385 
386       SmallVector<int64_t> dstOffets =
387           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
388       // Save the accumulated value untill all the loops are unrolled since
389       // reduction loop keep updating the accumulator.
390       accCache[dstOffets] = newOp->getResult(0);
391     }
392     // Assemble back the accumulator into a single vector.
393     Value result = rewriter.create<arith::ConstantOp>(
394         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
395     for (const auto &it : accCache) {
396       SmallVector<int64_t> dstStrides(it.first.size(), 1);
397       result = rewriter.create<vector::InsertStridedSliceOp>(
398           loc, it.second, result, it.first, dstStrides);
399     }
400     rewriter.replaceOp(contractOp, result);
401     return success();
402   }
403 
404 private:
405   vector::UnrollVectorOptions options;
406 };
407 
408 struct UnrollMultiReductionPattern
409     : public OpRewritePattern<vector::MultiDimReductionOp> {
UnrollMultiReductionPattern__anon77a0e1ec0411::UnrollMultiReductionPattern410   UnrollMultiReductionPattern(MLIRContext *context,
411                               const vector::UnrollVectorOptions &options)
412       : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
413         options(options) {}
414 
matchAndRewrite__anon77a0e1ec0411::UnrollMultiReductionPattern415   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
416                                 PatternRewriter &rewriter) const override {
417     Optional<SmallVector<int64_t, 4>> targetShape =
418         getTargetShape(options, reductionOp);
419     if (!targetShape)
420       return failure();
421     SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
422     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
423     llvm::MapVector<
424         SmallVector<int64_t>, Value,
425         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
426         accCache;
427     // Compute shape ratio of 'shape' and 'sizes'.
428     int64_t sliceCount = computeMaxLinearIndex(ratio);
429     Location loc = reductionOp.getLoc();
430     for (int64_t i = 0; i < sliceCount; i++) {
431       SmallVector<int64_t, 4> offsets =
432           getVectorOffset(originalSize, *targetShape, i);
433 
434       SmallVector<Value> operands;
435       SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
436       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
437           loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
438       operands.push_back(slicedOperand);
439       SmallVector<int64_t> dstShape;
440       SmallVector<int64_t> destOffset;
441       for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
442         if (!reductionOp.isReducedDim(i)) {
443           destOffset.push_back(offsets[i]);
444           dstShape.push_back((*targetShape)[i]);
445         }
446       }
447       Value acc;
448       SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
449       // If a version of the accumulator has already been computed, use it
450       // otherwise extract the first version from the original operand.
451       auto accIt = accCache.find(destOffset);
452       if (accIt != accCache.end())
453         acc = accIt->second;
454       else
455         acc = rewriter.create<vector::ExtractStridedSliceOp>(
456             loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
457       operands.push_back(acc);
458       auto targetType = VectorType::get(
459           dstShape, reductionOp.getSourceVectorType().getElementType());
460       Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
461                                                      operands, targetType);
462       Value result = newOp->getResult(0);
463       accCache[destOffset] = result;
464     }
465     // Assemble back the accumulator into a single vector.
466     Value result = rewriter.create<arith::ConstantOp>(
467         loc, reductionOp.getDestType(),
468         rewriter.getZeroAttr(reductionOp.getDestType()));
469     for (const auto &it : accCache) {
470       SmallVector<int64_t> dstStrides(it.first.size(), 1);
471       result = rewriter.create<vector::InsertStridedSliceOp>(
472           loc, it.second, result, it.first, dstStrides);
473     }
474     rewriter.replaceOp(reductionOp, result);
475     return success();
476   }
477 
478 private:
479   vector::UnrollVectorOptions options;
480 };
481 
482 struct UnrollElementwisePattern : public RewritePattern {
UnrollElementwisePattern__anon77a0e1ec0411::UnrollElementwisePattern483   UnrollElementwisePattern(MLIRContext *context,
484                            const vector::UnrollVectorOptions &options)
485       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
486         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollElementwisePattern487   LogicalResult matchAndRewrite(Operation *op,
488                                 PatternRewriter &rewriter) const override {
489     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
490       return failure();
491     auto targetShape = getTargetShape(options, op);
492     if (!targetShape)
493       return failure();
494     auto dstVecType = op->getResult(0).getType().cast<VectorType>();
495     SmallVector<int64_t, 4> originalSize =
496         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
497     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
498     int64_t sliceCount = computeMaxLinearIndex(ratio);
499     Location loc = op->getLoc();
500     // Prepare the result vector.
501     Value result = rewriter.create<arith::ConstantOp>(
502         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
503     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
504     VectorType newVecType =
505         VectorType::get(*targetShape, dstVecType.getElementType());
506     for (int64_t i = 0; i < sliceCount; i++) {
507       SmallVector<int64_t, 4> offsets =
508           getVectorOffset(originalSize, *targetShape, i);
509       SmallVector<Value, 4> extractOperands;
510       for (OpOperand &operand : op->getOpOperands()) {
511         auto vecType = operand.get().getType().template dyn_cast<VectorType>();
512         if (!vecType) {
513           extractOperands.push_back(operand.get());
514           continue;
515         }
516         extractOperands.push_back(
517             rewriter.create<vector::ExtractStridedSliceOp>(
518                 loc, operand.get(), offsets, *targetShape, strides));
519       }
520       Operation *newOp = cloneOpWithOperandsAndTypes(
521           rewriter, loc, op, extractOperands, newVecType);
522       result = rewriter.create<vector::InsertStridedSliceOp>(
523           loc, newOp->getResult(0), result, offsets, strides);
524     }
525     rewriter.replaceOp(op, result);
526     return success();
527   }
528 
529 private:
530   vector::UnrollVectorOptions options;
531 };
532 
533 /// Canonicalize an extract_map using the result of a pointwise operation.
534 /// Transforms:
535 /// %v = arith.addf %a, %b : vector32xf32>
536 /// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
537 /// to:
538 /// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
539 /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
540 /// %dv = arith.addf %da, %db : vector<1xf32>
541 struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
542   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::PointwiseExtractPattern543   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
544                                 PatternRewriter &rewriter) const override {
545     Operation *definedOp = extract.getVector().getDefiningOp();
546     if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
547         definedOp->getNumResults() != 1)
548       return failure();
549     Location loc = extract.getLoc();
550     SmallVector<Value, 4> extractOperands;
551     for (OpOperand &operand : definedOp->getOpOperands()) {
552       auto vecType = operand.get().getType().template dyn_cast<VectorType>();
553       if (!vecType) {
554         extractOperands.push_back(operand.get());
555         continue;
556       }
557       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
558           loc,
559           VectorType::get(extract.getResultType().getShape(),
560                           vecType.getElementType()),
561           operand.get(), extract.getIds()));
562     }
563     Operation *newOp = cloneOpWithOperandsAndTypes(
564         rewriter, loc, definedOp, extractOperands, extract.getResultType());
565     rewriter.replaceOp(extract, newOp->getResult(0));
566     return success();
567   }
568 };
569 
570 /// Canonicalize an extract_map using the result of a contract operation.
571 /// This propagate the extract_map to operands.
572 struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
573   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::ContractExtractPattern574   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
575                                 PatternRewriter &rewriter) const override {
576     Operation *definedOp = extract.getVector().getDefiningOp();
577     auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
578     if (!contract)
579       return failure();
580     Location loc = contract.getLoc();
581     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
582     AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
583     // Create a map of the dimensions distributed based on the acc affine map.
584     // Only parallel dimensions are being distributed, reduction dimensions are
585     // untouched.
586     DenseMap<int64_t, int64_t> map;
587     for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
588       map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
589     SmallVector<Value, 4> extractOperands;
590     for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
591       // For each operands calculate the new vector type after distribution.
592       Value operand = contract->getOperand(it.index());
593       auto vecType = operand.getType().cast<VectorType>();
594       SmallVector<int64_t> operandShape(vecType.getShape().begin(),
595                                         vecType.getShape().end());
596       for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
597         unsigned dim = it.value().getDimPosition(i);
598         auto distributedDim = map.find(dim);
599         // If the dimension is not in the map it means it is a reduction and
600         // doesn't get distributed.
601         if (distributedDim == map.end())
602           continue;
603         operandShape[i] = distributedDim->second;
604       }
605       VectorType newVecType =
606           VectorType::get(operandShape, vecType.getElementType());
607       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
608           loc, newVecType, operand, extract.getIds()));
609     }
610     Operation *newOp =
611         cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
612                                     extract.getResult().getType());
613     rewriter.replaceOp(extract, newOp->getResult(0));
614     return success();
615   }
616 };
617 
618 /// Converts TransferRead op used by ExtractMap op into a smaller dimension
619 /// TransferRead.
620 /// Example:
621 /// ```
622 /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
623 ///   memref<64x64x64xf32>, vector<64x4x32xf32>
624 /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
625 /// ```
626 /// to:
627 /// ```
628 /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
629 /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
630 ///   memref<64x64x64xf32>, vector<2x4x1xf32>
631 /// ```
632 struct TransferReadExtractPattern
633     : public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPattern__anon77a0e1ec0411::TransferReadExtractPattern634   TransferReadExtractPattern(MLIRContext *context)
635       : OpRewritePattern<vector::TransferReadOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferReadExtractPattern636   LogicalResult matchAndRewrite(vector::TransferReadOp read,
637                                 PatternRewriter &rewriter) const override {
638     // TODO: support 0-d corner case.
639     if (read.getTransferRank() == 0)
640       return failure();
641 
642     if (!read.getResult().hasOneUse())
643       return failure();
644     auto extract =
645         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
646     if (!extract)
647       return failure();
648     if (read.getMask())
649       return failure();
650 
651     SmallVector<Value, 4> indices(read.getIndices().begin(),
652                                   read.getIndices().end());
653     AffineMap indexMap = extract.map().compose(read.getPermutationMap());
654     unsigned idCount = 0;
655     ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
656     for (auto it :
657          llvm::zip(indexMap.getResults(), extract.map().getResults())) {
658       AffineExpr d0, d1;
659       bindDims(read.getContext(), d0, d1);
660       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
661       if (!indexExpr)
662         continue;
663       unsigned indexPos = indexExpr.getPosition();
664       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
665       auto scale = getAffineConstantExpr(
666           extract.getResultType().getDimSize(vectorPos), read.getContext());
667       indices[indexPos] = makeComposedAffineApply(
668           rewriter, read.getLoc(), d0 + scale * d1,
669           {indices[indexPos], extract.getIds()[idCount++]});
670     }
671     Value newRead = lb.create<vector::TransferReadOp>(
672         extract.getType(), read.getSource(), indices,
673         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
674         read.getInBoundsAttr());
675     Value dest = lb.create<arith::ConstantOp>(
676         read.getType(), rewriter.getZeroAttr(read.getType()));
677     newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
678     rewriter.replaceOp(read, newRead);
679     return success();
680   }
681 };
682 
683 struct TransferWriteInsertPattern
684     : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteInsertPattern__anon77a0e1ec0411::TransferWriteInsertPattern685   TransferWriteInsertPattern(MLIRContext *context)
686       : OpRewritePattern<vector::TransferWriteOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferWriteInsertPattern687   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
688                                 PatternRewriter &rewriter) const override {
689     // TODO: support 0-d corner case.
690     if (write.getTransferRank() == 0)
691       return failure();
692 
693     auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
694     if (!insert)
695       return failure();
696     if (write.getMask())
697       return failure();
698     SmallVector<Value, 4> indices(write.getIndices().begin(),
699                                   write.getIndices().end());
700     AffineMap indexMap = insert.map().compose(write.getPermutationMap());
701     unsigned idCount = 0;
702     Location loc = write.getLoc();
703     for (auto it :
704          llvm::zip(indexMap.getResults(), insert.map().getResults())) {
705       AffineExpr d0, d1;
706       bindDims(write.getContext(), d0, d1);
707       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
708       if (!indexExpr)
709         continue;
710       unsigned indexPos = indexExpr.getPosition();
711       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
712       auto scale = getAffineConstantExpr(
713           insert.getSourceVectorType().getDimSize(vectorPos),
714           write.getContext());
715       indices[indexPos] = makeComposedAffineApply(
716           rewriter, loc, d0 + scale * d1,
717           {indices[indexPos], insert.getIds()[idCount++]});
718     }
719     rewriter.create<vector::TransferWriteOp>(
720         loc, insert.getVector(), write.getSource(), indices,
721         write.getPermutationMapAttr(), write.getInBoundsAttr());
722     rewriter.eraseOp(write);
723     return success();
724   }
725 };
726 
727 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
UnrollReductionPattern__anon77a0e1ec0411::UnrollReductionPattern728   UnrollReductionPattern(MLIRContext *context,
729                          const vector::UnrollVectorOptions &options)
730       : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
731         options(options) {}
732 
matchAndRewrite__anon77a0e1ec0411::UnrollReductionPattern733   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
734                                 PatternRewriter &rewriter) const override {
735     Optional<SmallVector<int64_t, 4>> targetShape =
736         getTargetShape(options, reductionOp);
737     if (!targetShape)
738       return failure();
739     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
740     int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
741 
742     // Create unrolled vector reduction.
743     Location loc = reductionOp.getLoc();
744     Value accumulator = nullptr;
745     for (int64_t i = 0; i < ratio; ++i) {
746       SmallVector<int64_t> offsets =
747           getVectorOffset(originalSize, *targetShape, i);
748       SmallVector<int64_t> strides(offsets.size(), 1);
749       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
750           loc, reductionOp.getVector(), offsets, *targetShape, strides);
751       Operation *newOp = cloneOpWithOperandsAndTypes(
752           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
753       Value result = newOp->getResult(0);
754 
755       if (!accumulator) {
756         // This is the first reduction.
757         accumulator = result;
758       } else {
759         // On subsequent reduction, combine with the accumulator.
760         accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
761                                          accumulator, result);
762       }
763     }
764 
765     rewriter.replaceOp(reductionOp, accumulator);
766     return success();
767   }
768 
769 private:
770   const vector::UnrollVectorOptions options;
771 };
772 
773 struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern__anon77a0e1ec0411::UnrollTranposePattern774   UnrollTranposePattern(MLIRContext *context,
775                         const vector::UnrollVectorOptions &options)
776       : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
777         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTranposePattern778   LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
779                                 PatternRewriter &rewriter) const override {
780     if (tranposeOp.getResultType().getRank() == 0)
781       return failure();
782     auto targetShape = getTargetShape(options, tranposeOp);
783     if (!targetShape)
784       return failure();
785     auto originalVectorType = tranposeOp.getResultType();
786     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
787     Location loc = tranposeOp.getLoc();
788     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
789     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
790     int64_t sliceCount = computeMaxLinearIndex(ratio);
791     // Prepare the result vector;
792     Value result = rewriter.create<arith::ConstantOp>(
793         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
794     SmallVector<int64_t> permutation;
795     tranposeOp.getTransp(permutation);
796     for (int64_t i = 0; i < sliceCount; i++) {
797       SmallVector<int64_t, 4> elementOffsets =
798           getVectorOffset(originalSize, *targetShape, i);
799       SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
800       SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
801       // Compute the source offsets and shape.
802       for (auto &indices : llvm::enumerate(permutation)) {
803         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
804         permutedShape[indices.value()] = (*targetShape)[indices.index()];
805       }
806       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
807           loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
808       Value tranposedSlice =
809           rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
810       result = rewriter.create<vector::InsertStridedSliceOp>(
811           loc, tranposedSlice, result, elementOffsets, strides);
812     }
813     rewriter.replaceOp(tranposeOp, result);
814     return success();
815   }
816 
817 private:
818   vector::UnrollVectorOptions options;
819 };
820 
821 } // namespace
822 
populateVectorUnrollPatterns(RewritePatternSet & patterns,const UnrollVectorOptions & options)823 void mlir::vector::populateVectorUnrollPatterns(
824     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
825   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
826                UnrollContractionPattern, UnrollElementwisePattern,
827                UnrollReductionPattern, UnrollMultiReductionPattern,
828                UnrollTranposePattern>(patterns.getContext(), options);
829 }
830 
populatePropagateVectorDistributionPatterns(RewritePatternSet & patterns)831 void mlir::vector::populatePropagateVectorDistributionPatterns(
832     RewritePatternSet &patterns) {
833   patterns.add<PointwiseExtractPattern, ContractExtractPattern,
834                TransferReadExtractPattern, TransferWriteInsertPattern>(
835       patterns.getContext());
836 }
837