1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/Interfaces/VectorInterfaces.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "vector-unrolling"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 /// During unrolling from `originalShape` to `targetShape` return the offset for
27 /// the slice `index`.
28 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
29                                                ArrayRef<int64_t> targetShape,
30                                                int64_t index) {
31   SmallVector<int64_t, 4> dstSliceStrides =
32       computeStrides(originalShape, targetShape);
33   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
34   SmallVector<int64_t, 4> elementOffsets =
35       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
36   return elementOffsets;
37 }
38 
39 /// Compute the indices of the slice `index` for a tranfer op.
40 static SmallVector<Value>
41 sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
42                      ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
43                      AffineMap permutationMap, Location loc,
44                      OpBuilder &builder) {
45   MLIRContext *ctx = builder.getContext();
46   auto isBroadcast = [](AffineExpr expr) {
47     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
48       return constExpr.getValue() == 0;
49     return false;
50   };
51   SmallVector<int64_t, 4> elementOffsets =
52       getVectorOffset(originalShape, targetShape, index);
53   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
54   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
55   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
56     if (isBroadcast(dim.value()))
57       continue;
58     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
59     auto expr = getAffineDimExpr(0, builder.getContext()) +
60                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
61     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
62     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
63   }
64   return slicedIndices;
65 }
66 
67 // Clones `op` into a new operations that takes `operands` and returns
68 // `resultTypes`.
69 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
70                                               Operation *op,
71                                               ArrayRef<Value> operands,
72                                               ArrayRef<Type> resultTypes) {
73   OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
74   return builder.createOperation(res);
75 }
76 
77 /// Return the target shape for unrolling for the given `op`. Return llvm::None
78 /// if the op shouldn't be or cannot be unrolled.
79 static Optional<SmallVector<int64_t, 4>>
80 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
81   if (options.filterConstraint && failed(options.filterConstraint(op)))
82     return llvm::None;
83   assert(options.nativeShape &&
84          "vector unrolling expects the native shape or native"
85          "shape call back function to be set");
86   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
87   if (!unrollableVectorOp)
88     return llvm::None;
89   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
90   if (!maybeUnrollShape)
91     return llvm::None;
92   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
93   if (!targetShape)
94     return llvm::None;
95   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
96   if (!maybeShapeRatio ||
97       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
98     return llvm::None;
99   return targetShape;
100 }
101 
102 namespace {
103 
104 struct UnrollTransferReadPattern
105     : public OpRewritePattern<vector::TransferReadOp> {
106   UnrollTransferReadPattern(MLIRContext *context,
107                             const vector::UnrollVectorOptions &options)
108       : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
109         options(options) {}
110   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
111                                 PatternRewriter &rewriter) const override {
112     // TODO: support 0-d corner case.
113     if (readOp.getTransferRank() == 0)
114       return failure();
115     if (readOp.mask())
116       return failure();
117     auto targetShape = getTargetShape(options, readOp);
118     if (!targetShape)
119       return failure();
120     auto sourceVectorType = readOp.getVectorType();
121     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
122     Location loc = readOp.getLoc();
123     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
124     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
125     // Compute shape ratio of 'shape' and 'sizes'.
126     int64_t sliceCount = computeMaxLinearIndex(ratio);
127     // Prepare the result vector;
128     Value result = rewriter.create<arith::ConstantOp>(
129         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
130     auto targetType =
131         VectorType::get(*targetShape, sourceVectorType.getElementType());
132     SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
133                                           readOp.indices().end());
134     for (int64_t i = 0; i < sliceCount; i++) {
135       SmallVector<Value, 4> indices =
136           sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
137                                readOp.permutation_map(), loc, rewriter);
138       auto slicedRead = rewriter.create<vector::TransferReadOp>(
139           loc, targetType, readOp.source(), indices,
140           readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
141           readOp.in_boundsAttr());
142 
143       SmallVector<int64_t, 4> elementOffsets =
144           getVectorOffset(originalSize, *targetShape, i);
145       result = rewriter.create<vector::InsertStridedSliceOp>(
146           loc, slicedRead, result, elementOffsets, strides);
147     }
148     rewriter.replaceOp(readOp, result);
149     return success();
150   }
151 
152 private:
153   vector::UnrollVectorOptions options;
154 };
155 
156 struct UnrollTransferWritePattern
157     : public OpRewritePattern<vector::TransferWriteOp> {
158   UnrollTransferWritePattern(MLIRContext *context,
159                              const vector::UnrollVectorOptions &options)
160       : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
161         options(options) {}
162   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
163                                 PatternRewriter &rewriter) const override {
164     // TODO: support 0-d corner case.
165     if (writeOp.getTransferRank() == 0)
166       return failure();
167 
168     if (writeOp.mask())
169       return failure();
170     auto targetShape = getTargetShape(options, writeOp);
171     if (!targetShape)
172       return failure();
173     auto sourceVectorType = writeOp.getVectorType();
174     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
175     Location loc = writeOp.getLoc();
176     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
177     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
178     // Compute shape ratio of 'shape' and 'sizes'.
179     int64_t sliceCount = computeMaxLinearIndex(ratio);
180     SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
181                                           writeOp.indices().end());
182     Value resultTensor;
183     for (int64_t i = 0; i < sliceCount; i++) {
184       SmallVector<int64_t, 4> elementOffsets =
185           getVectorOffset(originalSize, *targetShape, i);
186       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
187           loc, writeOp.vector(), elementOffsets, *targetShape, strides);
188 
189       SmallVector<Value, 4> indices =
190           sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
191                                writeOp.permutation_map(), loc, rewriter);
192       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
193           loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
194           indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
195       // For the tensor case update the destination for the next transfer write.
196       if (!slicedWrite->getResults().empty())
197         resultTensor = slicedWrite->getResult(0);
198     }
199     if (resultTensor)
200       rewriter.replaceOp(writeOp, resultTensor);
201     else
202       rewriter.eraseOp(writeOp);
203     return success();
204   }
205 
206 private:
207   vector::UnrollVectorOptions options;
208 };
209 
210 struct UnrollContractionPattern
211     : public OpRewritePattern<vector::ContractionOp> {
212   struct OffsetMapInfo {
213     static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
214 
215     static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
216 
217     static unsigned getHashValue(const SmallVector<int64_t> &v) {
218       return static_cast<unsigned>(
219           llvm::hash_combine_range(v.begin(), v.end()));
220     }
221 
222     static bool isEqual(const SmallVector<int64_t> &lhs,
223                         const SmallVector<int64_t> &rhs) {
224       return lhs == rhs;
225     }
226   };
227   UnrollContractionPattern(MLIRContext *context,
228                            const vector::UnrollVectorOptions &options)
229       : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
230         options(options) {}
231 
232   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
233                                 PatternRewriter &rewriter) const override {
234     auto targetShape = getTargetShape(options, contractOp);
235     if (!targetShape)
236       return failure();
237     auto dstVecType = contractOp.getResultType().cast<VectorType>();
238     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
239     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
240 
241     // Compute shape ratio of 'shape' and 'sizes'.
242     int64_t sliceCount = computeMaxLinearIndex(ratio);
243     Location loc = contractOp.getLoc();
244     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
245     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
246     llvm::MapVector<
247         SmallVector<int64_t>, Value,
248         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
249         accCache;
250     for (int64_t i = 0; i < sliceCount; i++) {
251       SmallVector<int64_t, 4> offsets =
252           getVectorOffset(originalSize, *targetShape, i);
253       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
254 
255       // Helper to coompute the new shape of each operand and extract the slice.
256       auto extractOperand = [&](unsigned index, Value operand,
257                                 AffineMap permutationMap,
258                                 ArrayRef<int64_t> operandOffets) {
259         SmallVector<int64_t> operandShape = applyPermutationMap(
260             permutationMap, ArrayRef<int64_t>(*targetShape));
261         SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
262         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
263             loc, operand, operandOffets, operandShape, operandStrides);
264       };
265 
266       // Extract the new lhs operand.
267       AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
268       SmallVector<int64_t> lhsOffets =
269           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
270       extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets);
271       // If there is a mask associated to lhs, extract it as well.
272       if (slicesOperands.size() > 3)
273         extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets);
274 
275       // Extract the new rhs operand.
276       AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
277       SmallVector<int64_t> rhsOffets =
278           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
279       extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets);
280       // If there is a mask associated to rhs, extract it as well.
281       if (slicesOperands.size() > 4)
282         extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets);
283 
284       AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
285       SmallVector<int64_t> accOffets =
286           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
287       // If a version of the accumulator has already been computed, use it
288       // otherwise extract the first version from the original operand.
289       auto accIt = accCache.find(accOffets);
290       if (accIt != accCache.end())
291         slicesOperands[2] = accIt->second;
292       else
293         extractOperand(2, contractOp.acc(), accPermutationMap, accOffets);
294 
295       SmallVector<int64_t> dstShape =
296           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
297       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
298       Operation *newOp = cloneOpWithOperandsAndTypes(
299           rewriter, loc, contractOp, slicesOperands, targetType);
300 
301       SmallVector<int64_t> dstOffets =
302           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
303       // Save the accumulated value untill all the loops are unrolled since
304       // reduction loop keep updating the accumulator.
305       accCache[dstOffets] = newOp->getResult(0);
306     }
307     // Assemble back the accumulator into a single vector.
308     Value result = rewriter.create<arith::ConstantOp>(
309         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
310     for (const auto &it : accCache) {
311       SmallVector<int64_t> dstStrides(it.first.size(), 1);
312       result = rewriter.create<vector::InsertStridedSliceOp>(
313           loc, it.second, result, it.first, dstStrides);
314     }
315     rewriter.replaceOp(contractOp, result);
316     return success();
317   }
318 
319 private:
320   vector::UnrollVectorOptions options;
321 };
322 
323 struct UnrollElementwisePattern : public RewritePattern {
324   UnrollElementwisePattern(MLIRContext *context,
325                            const vector::UnrollVectorOptions &options)
326       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
327         options(options) {}
328   LogicalResult matchAndRewrite(Operation *op,
329                                 PatternRewriter &rewriter) const override {
330     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
331       return failure();
332     auto targetShape = getTargetShape(options, op);
333     if (!targetShape)
334       return failure();
335     auto dstVecType = op->getResult(0).getType().cast<VectorType>();
336     SmallVector<int64_t, 4> originalSize =
337         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
338     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
339     int64_t sliceCount = computeMaxLinearIndex(ratio);
340     Location loc = op->getLoc();
341     // Prepare the result vector.
342     Value result = rewriter.create<arith::ConstantOp>(
343         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
344     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
345     VectorType newVecType =
346         VectorType::get(*targetShape, dstVecType.getElementType());
347     for (int64_t i = 0; i < sliceCount; i++) {
348       SmallVector<int64_t, 4> offsets =
349           getVectorOffset(originalSize, *targetShape, i);
350       SmallVector<Value, 4> extractOperands;
351       for (OpOperand &operand : op->getOpOperands()) {
352         auto vecType = operand.get().getType().template dyn_cast<VectorType>();
353         if (!vecType) {
354           extractOperands.push_back(operand.get());
355           continue;
356         }
357         extractOperands.push_back(
358             rewriter.create<vector::ExtractStridedSliceOp>(
359                 loc, operand.get(), offsets, *targetShape, strides));
360       }
361       Operation *newOp = cloneOpWithOperandsAndTypes(
362           rewriter, loc, op, extractOperands, newVecType);
363       result = rewriter.create<vector::InsertStridedSliceOp>(
364           loc, newOp->getResult(0), result, offsets, strides);
365     }
366     rewriter.replaceOp(op, result);
367     return success();
368   }
369 
370 private:
371   vector::UnrollVectorOptions options;
372 };
373 
374 /// Canonicalize an extract_map using the result of a pointwise operation.
375 /// Transforms:
376 /// %v = arith.addf %a, %b : vector32xf32>
377 /// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
378 /// to:
379 /// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
380 /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
381 /// %dv = arith.addf %da, %db : vector<1xf32>
382 struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
383   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
384   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
385                                 PatternRewriter &rewriter) const override {
386     Operation *definedOp = extract.vector().getDefiningOp();
387     if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
388         definedOp->getNumResults() != 1)
389       return failure();
390     Location loc = extract.getLoc();
391     SmallVector<Value, 4> extractOperands;
392     for (OpOperand &operand : definedOp->getOpOperands()) {
393       auto vecType = operand.get().getType().template dyn_cast<VectorType>();
394       if (!vecType) {
395         extractOperands.push_back(operand.get());
396         continue;
397       }
398       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
399           loc,
400           VectorType::get(extract.getResultType().getShape(),
401                           vecType.getElementType()),
402           operand.get(), extract.ids()));
403     }
404     Operation *newOp = cloneOpWithOperandsAndTypes(
405         rewriter, loc, definedOp, extractOperands, extract.getResultType());
406     rewriter.replaceOp(extract, newOp->getResult(0));
407     return success();
408   }
409 };
410 
411 /// Canonicalize an extract_map using the result of a contract operation.
412 /// This propagate the extract_map to operands.
413 struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
414   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
415   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
416                                 PatternRewriter &rewriter) const override {
417     Operation *definedOp = extract.vector().getDefiningOp();
418     auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
419     if (!contract)
420       return failure();
421     Location loc = contract.getLoc();
422     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
423     AffineMap affineMap = contract.getIndexingMaps()[accIndex];
424     // Create a map of the dimensions distributed based on the acc affine map.
425     // Only parallel dimensions are being distributed, reduction dimensions are
426     // untouched.
427     DenseMap<int64_t, int64_t> map;
428     for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
429       map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
430     SmallVector<Value, 4> extractOperands;
431     for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
432       // For each operands calculate the new vector type after distribution.
433       Value operand = contract->getOperand(it.index());
434       auto vecType = operand.getType().cast<VectorType>();
435       SmallVector<int64_t> operandShape(vecType.getShape().begin(),
436                                         vecType.getShape().end());
437       for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
438         unsigned dim = it.value().getDimPosition(i);
439         auto distributedDim = map.find(dim);
440         // If the dimension is not in the map it means it is a reduction and
441         // doesn't get distributed.
442         if (distributedDim == map.end())
443           continue;
444         operandShape[i] = distributedDim->second;
445       }
446       VectorType newVecType =
447           VectorType::get(operandShape, vecType.getElementType());
448       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
449           loc, newVecType, operand, extract.ids()));
450     }
451     Operation *newOp =
452         cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
453                                     extract.getResult().getType());
454     rewriter.replaceOp(extract, newOp->getResult(0));
455     return success();
456   }
457 };
458 
459 /// Converts TransferRead op used by ExtractMap op into a smaller dimension
460 /// TransferRead.
461 /// Example:
462 /// ```
463 /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
464 ///   memref<64x64x64xf32>, vector<64x4x32xf32>
465 /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
466 /// ```
467 /// to:
468 /// ```
469 /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
470 /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
471 ///   memref<64x64x64xf32>, vector<2x4x1xf32>
472 /// ```
473 struct TransferReadExtractPattern
474     : public OpRewritePattern<vector::TransferReadOp> {
475   TransferReadExtractPattern(MLIRContext *context)
476       : OpRewritePattern<vector::TransferReadOp>(context) {}
477   LogicalResult matchAndRewrite(vector::TransferReadOp read,
478                                 PatternRewriter &rewriter) const override {
479     // TODO: support 0-d corner case.
480     if (read.getTransferRank() == 0)
481       return failure();
482 
483     if (!read.getResult().hasOneUse())
484       return failure();
485     auto extract =
486         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
487     if (!extract)
488       return failure();
489     if (read.mask())
490       return failure();
491 
492     SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
493     AffineMap indexMap = extract.map().compose(read.permutation_map());
494     unsigned idCount = 0;
495     ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
496     for (auto it :
497          llvm::zip(indexMap.getResults(), extract.map().getResults())) {
498       AffineExpr d0, d1;
499       bindDims(read.getContext(), d0, d1);
500       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
501       if (!indexExpr)
502         continue;
503       unsigned indexPos = indexExpr.getPosition();
504       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
505       auto scale = getAffineConstantExpr(
506           extract.getResultType().getDimSize(vectorPos), read.getContext());
507       indices[indexPos] = makeComposedAffineApply(
508           rewriter, read.getLoc(), d0 + scale * d1,
509           {indices[indexPos], extract.ids()[idCount++]});
510     }
511     Value newRead = lb.create<vector::TransferReadOp>(
512         extract.getType(), read.source(), indices, read.permutation_mapAttr(),
513         read.padding(), read.mask(), read.in_boundsAttr());
514     Value dest = lb.create<arith::ConstantOp>(
515         read.getType(), rewriter.getZeroAttr(read.getType()));
516     newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
517     rewriter.replaceOp(read, newRead);
518     return success();
519   }
520 };
521 
522 struct TransferWriteInsertPattern
523     : public OpRewritePattern<vector::TransferWriteOp> {
524   TransferWriteInsertPattern(MLIRContext *context)
525       : OpRewritePattern<vector::TransferWriteOp>(context) {}
526   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
527                                 PatternRewriter &rewriter) const override {
528     // TODO: support 0-d corner case.
529     if (write.getTransferRank() == 0)
530       return failure();
531 
532     auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
533     if (!insert)
534       return failure();
535     if (write.mask())
536       return failure();
537     SmallVector<Value, 4> indices(write.indices().begin(),
538                                   write.indices().end());
539     AffineMap indexMap = insert.map().compose(write.permutation_map());
540     unsigned idCount = 0;
541     Location loc = write.getLoc();
542     for (auto it :
543          llvm::zip(indexMap.getResults(), insert.map().getResults())) {
544       AffineExpr d0, d1;
545       bindDims(write.getContext(), d0, d1);
546       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
547       if (!indexExpr)
548         continue;
549       unsigned indexPos = indexExpr.getPosition();
550       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
551       auto scale = getAffineConstantExpr(
552           insert.getSourceVectorType().getDimSize(vectorPos),
553           write.getContext());
554       indices[indexPos] =
555           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
556                                   {indices[indexPos], insert.ids()[idCount++]});
557     }
558     rewriter.create<vector::TransferWriteOp>(
559         loc, insert.vector(), write.source(), indices,
560         write.permutation_mapAttr(), write.in_boundsAttr());
561     rewriter.eraseOp(write);
562     return success();
563   }
564 };
565 
566 } // namespace
567 
568 void mlir::vector::populateVectorUnrollPatterns(
569     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
570   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
571                UnrollContractionPattern, UnrollElementwisePattern>(
572       patterns.getContext(), options);
573 }
574 
575 void mlir::vector::populatePropagateVectorDistributionPatterns(
576     RewritePatternSet &patterns) {
577   patterns.add<PointwiseExtractPattern, ContractExtractPattern,
578                TransferReadExtractPattern, TransferWriteInsertPattern>(
579       patterns.getContext());
580 }
581