//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  explicit TransferOptimization(Operation *op)
      : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one of them must be post-dominated by
/// all the others since they all post-dominate the same transfer_write. We
/// only consider the candidate post-dominated by all the others, as it is the
/// first transfer_write executed after the potentially dead transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
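///
/// Illustrative sketch (hypothetical IR, attributes elided): the first
/// transfer_write below is dead because the second one post-dominates it,
/// accesses the same indices with the same vector type, and dominates the
/// only read reachable from the first write.
///
///   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>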
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if this candidate can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse into
    // the region to determine more precisely whether the read is reachable.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one of them must be dominated by all
/// the others since they all dominate the same transfer_read. We only consider
/// the candidate dominated by all the others, as it is the last transfer_write
/// executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
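///
/// Illustrative sketch (hypothetical IR, attributes elided except in_bounds):
/// the transfer_read below reads back exactly what the dominating
/// transfer_write stored, so its result can be replaced by %v and the read
/// erased.
///
///   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>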
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastWrite = nullptr;
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore it.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastWrite == nullptr || dominators.dominates(lastWrite, write))
          lastWrite = write;
        else
          assert(dominators.dominates(write, lastWrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastWrite == nullptr)
    return;

  Region *topRegion = lastWrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse into
    // the region to determine more precisely whether the read is reachable.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastWrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastWrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastWrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      0, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
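///
/// A sketch of the rewrite, assuming hypothetical shapes (other attributes
/// elided):
///
///   %r = vector.transfer_read %m[%c0, %c0, %c0], %pad {in_bounds = [true]}
///       : memref<1x1x4xf32>, vector<4xf32>
/// is rewritten to roughly:
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 4] [1, 1, 1]
///       : memref<1x1x4xf32> to memref<4xf32>
///   %r = vector.transfer_read %sv[%c0], %pad
///       : memref<4xf32>, vector<4xf32>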
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
/// unit dims, by inserting a memref.subview dropping those unit dims.
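///
/// A sketch of the rewrite, assuming hypothetical shapes (other attributes
/// elided):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<1x1x4xf32>
/// is rewritten to roughly:
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 4] [1, 1, 1]
///       : memref<1x1x4xf32> to memref<4xf32>
///   vector.transfer_write %v, %sv[%c0] : vector<4xf32>, memref<4xf32>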
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
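/// For example (hypothetical shapes), a memref<2x3x4xf32> input has no unit
/// dims to drop and is collapsed as:
///
///   memref.collapse_shape %m [[0, 1, 2]]
///       : memref<2x3x4xf32> into memref<24xf32>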
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
  Value rankReducedInput =
      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
  ShapedType rankReducedInputType =
      rankReducedInput.getType().cast<ShapedType>();
  if (rankReducedInputType.getRank() == 1)
    return rankReducedInput;
  ReassociationIndices indices;
  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
    indices.push_back(i);
  return rewriter.create<memref::CollapseShapeOp>(
      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
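///
/// A sketch of the rewrite, assuming hypothetical shapes (other attributes
/// elided):
///
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// is rewritten to roughly:
///   %flat = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %r1d = vector.transfer_read %flat[%c0], %pad
///       : memref<32xf32>, vector<32xf32>
///   %r = vector.shape_cast %r1d : vector<32xf32> to vector<4x8xf32>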
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // Contiguity can only be checked on memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value read1d = rewriter.create<vector::TransferReadOp>(
        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), read1d);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
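///
/// A sketch of the rewrite, assuming hypothetical shapes (other attributes
/// elided):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
/// is rewritten to roughly:
///   %flat = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v1d = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %v1d, %flat[%c0] : vector<32xf32>, memref<32xf32>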
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // Contiguity can only be checked on memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value vector1d =
        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
                                             ValueRange{c0}, identityMap1D);
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}
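
// A minimal usage sketch for the populate* entry points in this file, using
// the greedy pattern rewrite driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h" (`funcOp` is a hypothetical
// Operation* the patterns are applied to):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));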

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}