1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Interfaces/VectorInterfaces.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "vector-unrolling"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 /// During unrolling from `originalShape` to `targetShape` return the offset for
27 /// the slice `index`.
28 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
29                                                ArrayRef<int64_t> targetShape,
30                                                int64_t index) {
31   SmallVector<int64_t, 4> dstSliceStrides =
32       computeStrides(originalShape, targetShape);
33   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
34   SmallVector<int64_t, 4> elementOffsets =
35       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
36   return elementOffsets;
37 }
38 
39 /// Compute the indices of the slice `index` for a tranfer op.
40 static SmallVector<Value>
41 sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
42                      ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
43                      AffineMap permutationMap, Location loc,
44                      OpBuilder &builder) {
45   MLIRContext *ctx = builder.getContext();
46   auto isBroadcast = [](AffineExpr expr) {
47     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
48       return constExpr.getValue() == 0;
49     return false;
50   };
51   SmallVector<int64_t, 4> elementOffsets =
52       getVectorOffset(originalShape, targetShape, index);
53   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
54   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
55   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
56     if (isBroadcast(dim.value()))
57       continue;
58     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
59     auto expr = getAffineDimExpr(0, builder.getContext()) +
60                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
61     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
62     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
63   }
64   return slicedIndices;
65 }
66 
67 // Clones `op` into a new operations that takes `operands` and returns
68 // `resultTypes`.
69 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
70                                               Operation *op,
71                                               ArrayRef<Value> operands,
72                                               ArrayRef<Type> resultTypes) {
73   return builder.create(loc, op->getName().getIdentifier(), operands,
74                         resultTypes, op->getAttrs());
75 }
76 
77 /// Return the target shape for unrolling for the given `op`. Return llvm::None
78 /// if the op shouldn't be or cannot be unrolled.
79 static Optional<SmallVector<int64_t, 4>>
80 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
81   if (options.filterConstraint && failed(options.filterConstraint(op)))
82     return llvm::None;
83   assert(options.nativeShape &&
84          "vector unrolling expects the native shape or native"
85          "shape call back function to be set");
86   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
87   if (!unrollableVectorOp)
88     return llvm::None;
89   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
90   if (!maybeUnrollShape)
91     return llvm::None;
92   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
93   if (!targetShape)
94     return llvm::None;
95   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
96   if (!maybeShapeRatio ||
97       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
98     return llvm::None;
99   return targetShape;
100 }
101 
102 namespace {
103 
/// Unrolls a vector.transfer_read into multiple reads of the native vector
/// size, reassembling the full result with insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported by this unrolling.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of native-sized slices.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector: start from zeroes and insert each slice.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());
    for (int64_t i = 0; i < sliceCount; i++) {
      // Shift the original transfer indices by the offsets of slice `i`.
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      // Insert the slice into the aggregate result at its element offsets.
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
155 
/// Unrolls a vector.transfer_write into multiple writes of the native vector
/// size, each writing an extracted slice of the original vector.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // Masked transfers are not supported by this unrolling.
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of native-sized slices.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());
    // Non-null only for transfers on tensors: chains the tensor result of
    // each sliced write into the next one.
    Value resultTensor;
    for (int64_t i = 0; i < sliceCount; i++) {
      // Extract the slice of the source vector to write.
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);

      // Shift the original transfer indices by the offsets of slice `i`.
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
209 
210 struct OffsetMapInfo {
211   static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
212 
213   static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
214 
215   static unsigned getHashValue(const SmallVector<int64_t> &v) {
216     return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
217   }
218 
219   static bool isEqual(const SmallVector<int64_t> &lhs,
220                       const SmallVector<int64_t> &rhs) {
221     return lhs == rhs;
222   }
223 };
224 
/// Unrolls a vector.contract into contractions over native-sized operand
/// slices, accumulating the partial results per destination offset.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Compute the total number of native-sized slices.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    // Maps destination offsets to the latest partial accumulator for that
    // destination tile; slices along reduction dimensions share an entry.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
      // If there is a mask associated to lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
      // If there is a mask associated to rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
324 
325 struct UnrollMultiReductionPattern
326     : public OpRewritePattern<vector::MultiDimReductionOp> {
327   UnrollMultiReductionPattern(MLIRContext *context,
328                               const vector::UnrollVectorOptions &options)
329       : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
330         options(options) {}
331 
332   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
333                                 PatternRewriter &rewriter) const override {
334     Optional<SmallVector<int64_t, 4>> targetShape =
335         getTargetShape(options, reductionOp);
336     if (!targetShape)
337       return failure();
338     SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
339     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
340     llvm::MapVector<
341         SmallVector<int64_t>, Value,
342         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
343         accCache;
344     // Compute shape ratio of 'shape' and 'sizes'.
345     int64_t sliceCount = computeMaxLinearIndex(ratio);
346     Location loc = reductionOp.getLoc();
347     for (int64_t i = 0; i < sliceCount; i++) {
348       SmallVector<int64_t, 4> offsets =
349           getVectorOffset(originalSize, *targetShape, i);
350 
351       SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
352       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
353           loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
354 
355       SmallVector<int64_t> dstShape;
356       SmallVector<int64_t> destOffset;
357       for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
358         if (!reductionOp.isReducedDim(i)) {
359           destOffset.push_back(offsets[i]);
360           dstShape.push_back((*targetShape)[i]);
361         }
362       }
363       auto targetType = VectorType::get(
364           dstShape, reductionOp.getSourceVectorType().getElementType());
365       Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
366                                                      slicedOperand, targetType);
367       Value result = newOp->getResult(0);
368       // Save the accumulated value until all the loops are unrolled since
369       // reduction loop keeps updating the accumulator.
370       auto accIt = accCache.find(destOffset);
371       if (accIt != accCache.end())
372         result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
373                                     result, accIt->second);
374       accCache[destOffset] = result;
375     }
376     // Assemble back the accumulator into a single vector.
377     Value result = rewriter.create<arith::ConstantOp>(
378         loc, reductionOp.getDestType(),
379         rewriter.getZeroAttr(reductionOp.getDestType()));
380     for (const auto &it : accCache) {
381       SmallVector<int64_t> dstStrides(it.first.size(), 1);
382       result = rewriter.create<vector::InsertStridedSliceOp>(
383           loc, it.second, result, it.first, dstStrides);
384     }
385     rewriter.replaceOp(reductionOp, result);
386     return success();
387   }
388 
389 private:
390   vector::UnrollVectorOptions options;
391 };
392 
/// Unrolls any elementwise-mappable operation with a single vector result into
/// applications of the op on native-sized slices of its operands.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Total number of native-sized slices to produce.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        // Non-vector (scalar) operands are passed through unchanged.
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      // Apply the elementwise op to the slices, then insert the partial
      // result into the aggregate at the same offsets.
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
443 
/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      // Non-vector (scalar) operands are passed through unchanged.
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      // Extract the matching sub-vector of each vector operand, keeping the
      // operand's own element type.
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    // Re-create the pointwise op on the extracted operands.
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
480 
/// Canonicalize an extract_map using the result of a contract operation.
/// This propagate the extract_map to operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operands calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    // Re-create the contraction on the distributed operands and replace the
    // extract_map with its (already distributed-sized) result.
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
528 
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Only fire when the read result feeds a single extract_map.
    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    // Compose the extract map with the read's permutation map to relate
    // distributed vector dimensions back to memref index positions.
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // new_index = old_index + id * distributedDimSize.
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    // Emit the smaller read, then insert it back into a zero vector so the
    // original (full-size) read value can be replaced type-correctly.
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};
593 
/// Converts a TransferWrite fed by an InsertMap op into a smaller dimension
/// TransferWrite of the inserted sub-vector, shifting the indices by the
/// distribution ids (mirror of TransferReadExtractPattern).
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    // Compose the insert map with the write's permutation map to relate
    // distributed vector dimensions back to memref index positions.
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // new_index = old_index + id * distributedDimSize.
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    // Write only the sub-vector that was inserted, at the shifted indices.
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};
637 
638 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
639   UnrollReductionPattern(MLIRContext *context,
640                          const vector::UnrollVectorOptions &options)
641       : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
642         options(options) {}
643 
644   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
645                                 PatternRewriter &rewriter) const override {
646     Optional<SmallVector<int64_t, 4>> targetShape =
647         getTargetShape(options, reductionOp);
648     if (!targetShape)
649       return failure();
650     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
651     int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
652 
653     // Create unrolled vector reduction.
654     Location loc = reductionOp.getLoc();
655     Value accumulator = nullptr;
656     for (int64_t i = 0; i < ratio; ++i) {
657       SmallVector<int64_t> offsets =
658           getVectorOffset(originalSize, *targetShape, i);
659       SmallVector<int64_t> strides(offsets.size(), 1);
660       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
661           loc, reductionOp.getVector(), offsets, *targetShape, strides);
662       Operation *newOp = cloneOpWithOperandsAndTypes(
663           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
664       Value result = newOp->getResult(0);
665 
666       if (!accumulator) {
667         // This is the first reduction.
668         accumulator = result;
669       } else {
670         // On subsequent reduction, combine with the accumulator.
671         accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
672                                          accumulator, result);
673       }
674     }
675 
676     rewriter.replaceOp(reductionOp, accumulator);
677     return success();
678   }
679 
680 private:
681   const vector::UnrollVectorOptions options;
682 };
683 
684 struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
685   UnrollTranposePattern(MLIRContext *context,
686                         const vector::UnrollVectorOptions &options)
687       : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
688         options(options) {}
689   LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
690                                 PatternRewriter &rewriter) const override {
691     if (tranposeOp.getResultType().getRank() == 0)
692       return failure();
693     auto targetShape = getTargetShape(options, tranposeOp);
694     if (!targetShape)
695       return failure();
696     auto originalVectorType = tranposeOp.getResultType();
697     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
698     Location loc = tranposeOp.getLoc();
699     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
700     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
701     int64_t sliceCount = computeMaxLinearIndex(ratio);
702     // Prepare the result vector;
703     Value result = rewriter.create<arith::ConstantOp>(
704         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
705     SmallVector<int64_t> permutation;
706     tranposeOp.getTransp(permutation);
707     for (int64_t i = 0; i < sliceCount; i++) {
708       SmallVector<int64_t, 4> elementOffsets =
709           getVectorOffset(originalSize, *targetShape, i);
710       SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
711       SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
712       // Compute the source offsets and shape.
713       for (auto &indices : llvm::enumerate(permutation)) {
714         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
715         permutedShape[indices.value()] = (*targetShape)[indices.index()];
716       }
717       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
718           loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
719       Value tranposedSlice =
720           rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
721       result = rewriter.create<vector::InsertStridedSliceOp>(
722           loc, tranposedSlice, result, elementOffsets, strides);
723     }
724     rewriter.replaceOp(tranposeOp, result);
725     return success();
726   }
727 
728 private:
729   vector::UnrollVectorOptions options;
730 };
731 
732 } // namespace
733 
734 void mlir::vector::populateVectorUnrollPatterns(
735     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
736   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
737                UnrollContractionPattern, UnrollElementwisePattern,
738                UnrollReductionPattern, UnrollMultiReductionPattern,
739                UnrollTranposePattern>(patterns.getContext(), options);
740 }
741 
742 void mlir::vector::populatePropagateVectorDistributionPatterns(
743     RewritePatternSet &patterns) {
744   patterns.add<PointwiseExtractPattern, ContractExtractPattern,
745                TransferReadExtractPattern, TransferWriteInsertPattern>(
746       patterns.getContext());
747 }
748