1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Interfaces/VectorInterfaces.h"
18 #include "mlir/Support/MathExtras.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Debug.h"
22 #include <numeric>
23 
24 #define DEBUG_TYPE "vector-unrolling"
25 
26 using namespace mlir;
27 using namespace mlir::vector;
28 
29 /// During unrolling from `originalShape` to `targetShape` return the offset for
30 /// the slice `index`.
31 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
32                                                ArrayRef<int64_t> targetShape,
33                                                int64_t index) {
34   SmallVector<int64_t, 4> dstSliceStrides =
35       computeStrides(originalShape, targetShape);
36   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
37   SmallVector<int64_t, 4> elementOffsets =
38       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
39   return elementOffsets;
40 }
41 
42 /// A functor that accomplishes the same thing as `getVectorOffset` but allows
43 /// for reordering the traversal of the dimensions. The order of traversal is
44 /// given in "for loop order" (outer to inner).
45 namespace {
46 class DecomposeShapeIterator {
47 private:
48   SmallVector<int64_t, 4> vectorShape;
49   SmallVector<int64_t> loopOrder;
50   SmallVector<int64_t> sliceStrides;
51   int64_t maxIndexVal{1};
52 
53 public:
54   DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
55                          ArrayRef<int64_t> targetShape,
56                          ArrayRef<int64_t> loopOrder)
57       : vectorShape(targetShape.begin(), targetShape.end()),
58         loopOrder(loopOrder.begin(), loopOrder.end()),
59         sliceStrides(originalShape.size()) {
60     // Compute the count for each dimension.
61     SmallVector<int64_t> sliceDimCounts(originalShape.size());
62     for (unsigned r = 0; r < originalShape.size(); ++r) {
63       sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
64       maxIndexVal *= sliceDimCounts[r];
65     }
66 
67     // Reversing "loop order" gives dimensions from fastest varying to slowest
68     // varying (smallest stride to largest stride).
69     int64_t accum = 1;
70     for (auto idx : llvm::reverse(loopOrder)) {
71       sliceStrides[idx] = accum;
72       accum *= sliceDimCounts[idx];
73     }
74   }
75 
76   // Turn the linear index into a d-tuple based on units of vectors of size
77   // `vectorShape`. The linear index is assumed to represent traversal of the
78   // dimensions based on `order`.
79   SmallVector<int64_t> delinearize(int64_t index) const {
80     // Traverse in for loop order (largest stride to smallest stride).
81     SmallVector<int64_t, 4> vectorOffsets(sliceStrides.size());
82     for (auto idx : loopOrder) {
83       vectorOffsets[idx] = index / sliceStrides[idx];
84       index %= sliceStrides[idx];
85     }
86     return vectorOffsets;
87   }
88 
89   int64_t maxIndex() const { return maxIndexVal; }
90 
91   /// Return the offset within d-tuple based on the ordering given by
92   /// `loopOrder`.
93   SmallVector<int64_t> getVectorOffset(int64_t index) const {
94     SmallVector<int64_t> vectorOffsets = delinearize(index);
95     SmallVector<int64_t> elementOffsets =
96         computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
97     return elementOffsets;
98   }
99 };
100 } // namespace
101 
102 /// Compute the indices of the slice `index` for a tranfer op.
103 static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
104                                                ArrayRef<Value> indices,
105                                                AffineMap permutationMap,
106                                                Location loc,
107                                                OpBuilder &builder) {
108   MLIRContext *ctx = builder.getContext();
109   auto isBroadcast = [](AffineExpr expr) {
110     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
111       return constExpr.getValue() == 0;
112     return false;
113   };
114   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
115   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
116   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
117     if (isBroadcast(dim.value()))
118       continue;
119     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
120     auto expr = getAffineDimExpr(0, builder.getContext()) +
121                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
122     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
123     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
124   }
125   return slicedIndices;
126 }
127 
128 // Clones `op` into a new operations that takes `operands` and returns
129 // `resultTypes`.
130 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
131                                               Operation *op,
132                                               ArrayRef<Value> operands,
133                                               ArrayRef<Type> resultTypes) {
134   return builder.create(loc, op->getName().getIdentifier(), operands,
135                         resultTypes, op->getAttrs());
136 }
137 
138 /// Return the target shape for unrolling for the given `op`. Return llvm::None
139 /// if the op shouldn't be or cannot be unrolled.
140 static Optional<SmallVector<int64_t, 4>>
141 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
142   if (options.filterConstraint && failed(options.filterConstraint(op)))
143     return llvm::None;
144   assert(options.nativeShape &&
145          "vector unrolling expects the native shape or native"
146          "shape call back function to be set");
147   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
148   if (!unrollableVectorOp)
149     return llvm::None;
150   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
151   if (!maybeUnrollShape)
152     return llvm::None;
153   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
154   if (!targetShape)
155     return llvm::None;
156   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
157   if (!maybeShapeRatio ||
158       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
159     return llvm::None;
160   return targetShape;
161 }
162 
163 static SmallVector<int64_t>
164 getUnrollOrder(unsigned numLoops, Operation *op,
165                const vector::UnrollVectorOptions &options) {
166   SmallVector<int64_t> loopOrder =
167       llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
168   if (options.traversalOrderCallback != nullptr) {
169     Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
170     if (order.hasValue()) {
171       loopOrder = std::move(*order);
172     }
173   }
174   return loopOrder;
175 }
176 
177 namespace {
178 
/// Unrolls a vector.transfer_read whose vector type is larger than the native
/// target shape into multiple smaller transfer_reads of slices, reassembling
/// the full result with vector.insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported by this pattern.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    // `ratio` has one entry per vector dimension: the slice count along it.
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Prepare the result vector (zero-filled, slices are inserted into it).
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());

    // Traverse the slices in the order requested by the options (if any).
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(ratio.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      // Shift the original transfer indices by the slice's element offsets
      // (broadcast dimensions keep their original index).
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      // `getMask()` is known null here (checked above); passed for the
      // builder signature.
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      // Insert the slice into the aggregate result at its element offsets.
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
234 
235 struct UnrollTransferWritePattern
236     : public OpRewritePattern<vector::TransferWriteOp> {
237   UnrollTransferWritePattern(MLIRContext *context,
238                              const vector::UnrollVectorOptions &options)
239       : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
240         options(options) {}
241   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
242                                 PatternRewriter &rewriter) const override {
243     // TODO: support 0-d corner case.
244     if (writeOp.getTransferRank() == 0)
245       return failure();
246 
247     if (writeOp.getMask())
248       return failure();
249     auto targetShape = getTargetShape(options, writeOp);
250     if (!targetShape)
251       return failure();
252     auto sourceVectorType = writeOp.getVectorType();
253     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
254     Location loc = writeOp.getLoc();
255     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
256     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
257                                           writeOp.getIndices().end());
258 
259     SmallVector<int64_t> loopOrder =
260         getUnrollOrder(originalIndices.size(), writeOp, options);
261     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
262                                           loopOrder);
263     Value resultTensor;
264     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
265       SmallVector<int64_t, 4> elementOffsets =
266           indexToOffsets.getVectorOffset(i);
267       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
268           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
269       SmallVector<Value, 4> indices =
270           sliceTransferIndices(elementOffsets, originalIndices,
271                                writeOp.getPermutationMap(), loc, rewriter);
272       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
273           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
274           indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
275       // For the tensor case update the destination for the next transfer write.
276       if (!slicedWrite->getResults().empty())
277         resultTensor = slicedWrite->getResult(0);
278     }
279     if (resultTensor)
280       rewriter.replaceOp(writeOp, resultTensor);
281     else
282       rewriter.eraseOp(writeOp);
283     return success();
284   }
285 
286 private:
287   vector::UnrollVectorOptions options;
288 };
289 
290 struct OffsetMapInfo {
291   static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
292 
293   static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
294 
295   static unsigned getHashValue(const SmallVector<int64_t> &v) {
296     return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
297   }
298 
299   static bool isEqual(const SmallVector<int64_t> &lhs,
300                       const SmallVector<int64_t> &rhs) {
301     return lhs == rhs;
302   }
303 };
304 
305 struct UnrollContractionPattern
306     : public OpRewritePattern<vector::ContractionOp> {
307   UnrollContractionPattern(MLIRContext *context,
308                            const vector::UnrollVectorOptions &options)
309       : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
310         options(options) {}
311 
312   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
313                                 PatternRewriter &rewriter) const override {
314     auto targetShape = getTargetShape(options, contractOp);
315     if (!targetShape)
316       return failure();
317     auto dstVecType = contractOp.getResultType().cast<VectorType>();
318     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
319     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
320 
321     Location loc = contractOp.getLoc();
322     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
324     llvm::MapVector<
325         SmallVector<int64_t>, Value,
326         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
327         accCache;
328 
329     SmallVector<int64_t> loopOrder = getUnrollOrder(
330         contractOp.getIndexingMaps().size(), contractOp, options);
331     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
332                                           loopOrder);
333     const int64_t sliceCount = indexToOffsets.maxIndex();
334     for (int64_t i = 0; i < sliceCount; i++) {
335       SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
336       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
337 
338       // Helper to coompute the new shape of each operand and extract the slice.
339       auto extractOperand = [&](unsigned index, Value operand,
340                                 AffineMap permutationMap,
341                                 ArrayRef<int64_t> operandOffets) {
342         SmallVector<int64_t> operandShape = applyPermutationMap(
343             permutationMap, ArrayRef<int64_t>(*targetShape));
344         SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
345         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
346             loc, operand, operandOffets, operandShape, operandStrides);
347       };
348 
349       // Extract the new lhs operand.
350       AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
351       SmallVector<int64_t> lhsOffets =
352           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
353       extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
354       // If there is a mask associated to lhs, extract it as well.
355       if (slicesOperands.size() > 3)
356         extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
357                        lhsOffets);
358 
359       // Extract the new rhs operand.
360       AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
361       SmallVector<int64_t> rhsOffets =
362           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
363       extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
364       // If there is a mask associated to rhs, extract it as well.
365       if (slicesOperands.size() > 4)
366         extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
367                        rhsOffets);
368 
369       AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
370       SmallVector<int64_t> accOffets =
371           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
372       // If a version of the accumulator has already been computed, use it
373       // otherwise extract the first version from the original operand.
374       auto accIt = accCache.find(accOffets);
375       if (accIt != accCache.end())
376         slicesOperands[2] = accIt->second;
377       else
378         extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
379 
380       SmallVector<int64_t> dstShape =
381           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
382       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
383       Operation *newOp = cloneOpWithOperandsAndTypes(
384           rewriter, loc, contractOp, slicesOperands, targetType);
385 
386       SmallVector<int64_t> dstOffets =
387           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
388       // Save the accumulated value untill all the loops are unrolled since
389       // reduction loop keep updating the accumulator.
390       accCache[dstOffets] = newOp->getResult(0);
391     }
392     // Assemble back the accumulator into a single vector.
393     Value result = rewriter.create<arith::ConstantOp>(
394         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
395     for (const auto &it : accCache) {
396       SmallVector<int64_t> dstStrides(it.first.size(), 1);
397       result = rewriter.create<vector::InsertStridedSliceOp>(
398           loc, it.second, result, it.first, dstStrides);
399     }
400     rewriter.replaceOp(contractOp, result);
401     return success();
402   }
403 
404 private:
405   vector::UnrollVectorOptions options;
406 };
407 
/// Unrolls a vector.multi_reduction into reductions over slices of the native
/// target shape. Partial results with the same destination offsets are
/// combined with the op's reduction kind, then reassembled at the end.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Map from destination offsets to the partially accumulated value for
    // that slice of the result.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // Compute shape ratio of 'shape' and 'sizes'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    // Note: this pattern uses the default traversal order (free function
    // `getVectorOffset`), not the options' traversal-order callback.
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      // Extract the slice of the source to be reduced.
      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);

      // The destination keeps only the non-reduced dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value result = newOp->getResult(0);
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                    result, accIt->second);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
475 
/// Unrolls an elementwise-mappable, single-result op whose vector result is
/// larger than the native target shape into copies of the op applied to
/// extracted slices, reassembled with vector.insert_strided_slice ops.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    // Note: uses the default traversal order (free `getVectorOffset`), not
    // the options' traversal-order callback.
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      // Extract the matching slice of each vector operand; non-vector
      // (scalar) operands are passed through unchanged.
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      // Clone the op on the slices and insert its result into the aggregate.
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
526 
/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    // Only applies when the extracted value is produced by a single-result
    // elementwise op.
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      // Scalar operands pass through unchanged; vector operands get their own
      // extract_map using the same distribution ids.
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    // Clone the elementwise op onto the distributed operands and replace the
    // extract with its (smaller) result.
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
563 
/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands of the contract.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched. Key: iterator-dim position; value: distributed dim size.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      // Extract the distributed slice of this operand with the same ids.
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    // Clone the contract onto the distributed operands and replace the
    // extract with the smaller contraction's result.
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
611 
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Only rewrite when the read's single use is an extract_map.
    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    // Compose the distribution map with the read's permutation map to find
    // which transfer index each distributed vector dimension feeds.
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      // Skip results that are not plain dims (e.g. constants/broadcasts).
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // New index = old index + id * distributedDimSize.
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    // Emit the smaller read, then re-wrap it in an insert_map into a zero
    // vector so the read's full-size result type is preserved; the consuming
    // extract_map is expected to cancel with the insert_map.
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};
676 
/// Rewrite a vector.transfer_write whose vector operand is produced by a
/// vector.insert_map: write just the distributed slice (the insert_map's
/// source vector) with the transfer indices offset by id * sliceSize,
/// instead of writing the full vector.
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    // Masked transfers are not handled by this rewrite.
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    // Compose the insert_map's map with the write's permutation map so each
    // distributed vector dimension is matched to the transfer index it drives.
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      // Skip results that are not plain dimension expressions; they do not
      // correspond to a distributed index.
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // Per-id slice size along this vector dimension.
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      // newIndex = oldIndex + sliceSize * id.
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    // Emit the slice-sized write, then erase the original full-vector write.
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};
720 
721 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
722   UnrollReductionPattern(MLIRContext *context,
723                          const vector::UnrollVectorOptions &options)
724       : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
725         options(options) {}
726 
727   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
728                                 PatternRewriter &rewriter) const override {
729     Optional<SmallVector<int64_t, 4>> targetShape =
730         getTargetShape(options, reductionOp);
731     if (!targetShape)
732       return failure();
733     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
734     int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
735 
736     // Create unrolled vector reduction.
737     Location loc = reductionOp.getLoc();
738     Value accumulator = nullptr;
739     for (int64_t i = 0; i < ratio; ++i) {
740       SmallVector<int64_t> offsets =
741           getVectorOffset(originalSize, *targetShape, i);
742       SmallVector<int64_t> strides(offsets.size(), 1);
743       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
744           loc, reductionOp.getVector(), offsets, *targetShape, strides);
745       Operation *newOp = cloneOpWithOperandsAndTypes(
746           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
747       Value result = newOp->getResult(0);
748 
749       if (!accumulator) {
750         // This is the first reduction.
751         accumulator = result;
752       } else {
753         // On subsequent reduction, combine with the accumulator.
754         accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
755                                          accumulator, result);
756       }
757     }
758 
759     rewriter.replaceOp(reductionOp, accumulator);
760     return success();
761   }
762 
763 private:
764   const vector::UnrollVectorOptions options;
765 };
766 
767 struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
768   UnrollTranposePattern(MLIRContext *context,
769                         const vector::UnrollVectorOptions &options)
770       : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
771         options(options) {}
772   LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
773                                 PatternRewriter &rewriter) const override {
774     if (tranposeOp.getResultType().getRank() == 0)
775       return failure();
776     auto targetShape = getTargetShape(options, tranposeOp);
777     if (!targetShape)
778       return failure();
779     auto originalVectorType = tranposeOp.getResultType();
780     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
781     Location loc = tranposeOp.getLoc();
782     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
783     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
784     int64_t sliceCount = computeMaxLinearIndex(ratio);
785     // Prepare the result vector;
786     Value result = rewriter.create<arith::ConstantOp>(
787         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
788     SmallVector<int64_t> permutation;
789     tranposeOp.getTransp(permutation);
790     for (int64_t i = 0; i < sliceCount; i++) {
791       SmallVector<int64_t, 4> elementOffsets =
792           getVectorOffset(originalSize, *targetShape, i);
793       SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
794       SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
795       // Compute the source offsets and shape.
796       for (auto &indices : llvm::enumerate(permutation)) {
797         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
798         permutedShape[indices.value()] = (*targetShape)[indices.index()];
799       }
800       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
801           loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
802       Value tranposedSlice =
803           rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
804       result = rewriter.create<vector::InsertStridedSliceOp>(
805           loc, tranposedSlice, result, elementOffsets, strides);
806     }
807     rewriter.replaceOp(tranposeOp, result);
808     return success();
809   }
810 
811 private:
812   vector::UnrollVectorOptions options;
813 };
814 
815 } // namespace
816 
/// Populate `patterns` with all vector unrolling patterns defined in this
/// file, each driven by the same `options` (target shape, filter, traversal
/// order).
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTranposePattern>(patterns.getContext(), options);
}
824 
/// Populate `patterns` with the distribution-propagation patterns that move
/// vector.extract_map / vector.insert_map across elementwise ops,
/// contractions, and transfer reads/writes.
void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}
831