1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Interfaces/VectorInterfaces.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "vector-unrolling"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 /// During unrolling from `originalShape` to `targetShape` return the offset for
27 /// the slice `index`.
28 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
29                                                ArrayRef<int64_t> targetShape,
30                                                int64_t index) {
31   SmallVector<int64_t, 4> dstSliceStrides =
32       computeStrides(originalShape, targetShape);
33   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
34   SmallVector<int64_t, 4> elementOffsets =
35       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
36   return elementOffsets;
37 }
38 
39 /// Compute the indices of the slice `index` for a tranfer op.
40 static SmallVector<Value>
41 sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
42                      ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
43                      AffineMap permutationMap, Location loc,
44                      OpBuilder &builder) {
45   MLIRContext *ctx = builder.getContext();
46   auto isBroadcast = [](AffineExpr expr) {
47     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
48       return constExpr.getValue() == 0;
49     return false;
50   };
51   SmallVector<int64_t, 4> elementOffsets =
52       getVectorOffset(originalShape, targetShape, index);
53   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
54   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
55   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
56     if (isBroadcast(dim.value()))
57       continue;
58     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
59     auto expr = getAffineDimExpr(0, builder.getContext()) +
60                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
61     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
62     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
63   }
64   return slicedIndices;
65 }
66 
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`. The operation name and attributes are carried over; note
// that regions of `op` are not cloned (the OperationState is built without
// any regions).
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
  return builder.createOperation(res);
}
76 
77 /// Return the target shape for unrolling for the given `op`. Return llvm::None
78 /// if the op shouldn't be or cannot be unrolled.
79 static Optional<SmallVector<int64_t, 4>>
80 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
81   if (options.filterConstraint && failed(options.filterConstraint(op)))
82     return llvm::None;
83   assert(options.nativeShape &&
84          "vector unrolling expects the native shape or native"
85          "shape call back function to be set");
86   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
87   if (!unrollableVectorOp)
88     return llvm::None;
89   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
90   if (!maybeUnrollShape)
91     return llvm::None;
92   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
93   if (!targetShape)
94     return llvm::None;
95   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
96   if (!maybeShapeRatio ||
97       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
98     return llvm::None;
99   return targetShape;
100 }
101 
102 namespace {
103 
/// Unrolls a vector.transfer_read into multiple transfer_reads of the native
/// target shape, reassembling the full result with insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported.
    if (readOp.mask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Total number of slices to generate (product of the ratio dims).
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector: start from a zero-filled constant and insert
    // each unrolled read into it.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
                                          readOp.indices().end());
    for (int64_t i = 0; i < sliceCount; i++) {
      // Shift the original transfer indices by the offset of slice `i`.
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               readOp.permutation_map(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.source(), indices,
          readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
          readOp.in_boundsAttr());

      // Insert the sliced read back at its element offsets in the result.
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
155 
/// Unrolls a vector.transfer_write into multiple transfer_writes of the native
/// target shape, extracting each slice of the source vector with
/// extract_strided_slice.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // Masked transfers are not supported.
    if (writeOp.mask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Total number of slices to generate (product of the ratio dims).
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
                                          writeOp.indices().end());
    // Latest produced tensor when writing into a tensor; stays null when the
    // destination produces no result (e.g. memref).
    Value resultTensor;
    for (int64_t i = 0; i < sliceCount; i++) {
      // Extract the slice of the source vector written by this unrolled op.
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.vector(), elementOffsets, *targetShape, strides);

      // Shift the original transfer indices by the offset of slice `i`.
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               writeOp.permutation_map(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
          indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    // Tensor writes chain into a final result; memref writes are side-effect
    // only and the original op is simply erased.
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
209 
/// DenseMapInfo for SmallVector<int64_t> offset keys; used below to key
/// accumulator caches by element offsets.
struct OffsetMapInfo {
  // Sentinel key representing an empty bucket.
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  // Sentinel key representing an erased bucket.
  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
224 
/// Unrolls a vector.contract into contractions of the native target shape,
/// caching partial accumulators per destination offset and reassembling the
/// final result at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Total number of slices to generate (product of the ratio dims).
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    // Maps accumulator offsets to the latest partial accumulator value.
    // MapVector keeps insertion order so the final reassembly below is
    // deterministic.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
      // If there is a mask associated to lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
      // If there is a mask associated to rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it,
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
322 
/// Unrolls a vector.multi_reduction into reductions of the native target
/// shape, combining partial results that land on the same destination offsets
/// with makeArithReduction.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Maps destination offsets to the current partial reduction value;
    // MapVector keeps insertion order for the deterministic reassembly below.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // Total number of slices to generate (product of the ratio dims).
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);

      // The destination keeps only the non-reduced dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value result = newOp->getResult(0);
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator. Slices reducing to the
      // same destination offsets are combined pairwise.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        result = makeArithReduction(rewriter, loc, reductionOp.kind(), result,
                                    accIt->second);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
390 
/// Unrolls any single-result operation with elementwise-mappable traits into
/// operations of the native target vector shape.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Match any op that is elementwise-mappable and produces one result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Total number of slices to generate (product of the ratio dims).
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      // Slice every vector operand; non-vector operands are forwarded as-is.
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      // Insert the unrolled result back at its offsets in the full vector.
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
441 
/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    // Only fire when the extracted vector comes from a single-result
    // elementwise-mappable op.
    Operation *definedOp = extract.vector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    // Extract a matching slice of every vector operand; non-vector operands
    // are forwarded unchanged.
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.ids()));
    }
    // Re-create the pointwise op on the extracted operands.
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
478 
/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.vector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      // Distribute the operand itself with the same ids.
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.ids()));
    }
    // Rebuild the contraction on the distributed operands.
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
526 
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Only apply when the extract_map is the single user of the read.
    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    // Masked transfers are not supported.
    if (read.mask())
      return failure();

    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
    // Compose the extract_map's map with the read's permutation map to relate
    // each distributed vector dimension to the transfer index it feeds.
    AffineMap indexMap = extract.map().compose(read.permutation_map());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // index = index + id * distributedDimSize, folded into a single
      // affine.apply.
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.ids()[idCount++]});
    }
    // Create the distributed read, then wrap it in an insert_map into a
    // zero-filled full-size vector so the replacement value has the original
    // read's type.
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.source(), indices, read.permutation_mapAttr(),
        read.padding(), read.mask(), read.in_boundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};
589 
/// Converts a TransferWrite op whose source comes from an InsertMap into a
/// smaller dimension TransferWrite that writes the distributed vector at
/// indices scaled by the distribution ids.
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    // Masked transfers are not supported.
    if (write.mask())
      return failure();
    SmallVector<Value, 4> indices(write.indices().begin(),
                                  write.indices().end());
    // Compose the insert_map's map with the write's permutation map to relate
    // each distributed vector dimension to the transfer index it feeds.
    AffineMap indexMap = insert.map().compose(write.permutation_map());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // index = index + id * distributedDimSize, folded into a single
      // affine.apply.
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], insert.ids()[idCount++]});
    }
    // Write the distributed (pre-insert) vector directly and drop the
    // original full-size write.
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.vector(), write.source(), indices,
        write.permutation_mapAttr(), write.in_boundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};
633 
/// Unrolls a vector.reduction into reductions over slices of the native
/// target shape, combining the partial results with makeArithReduction.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Only the leading (single, for this op) dimension of the ratio is used.
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.vector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
679 
680 } // namespace
681 
/// Registers all vector unrolling patterns, each constructed with the same
/// `options` and a benefit of 1.
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern>(
      patterns.getContext(), options);
}
689 
/// Registers the patterns that propagate vector distribution
/// (extract_map/insert_map) through pointwise ops, contractions, and
/// transfer reads/writes.
void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}
696