//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

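/// Drives dead transfer_write elimination and transfer_write to transfer_read
/// forwarding using dominance and post-dominance information. Ops proven dead
/// are collected in `opToErase` and erased by `removeDeadOp`.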
class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
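///
/// Illustrative example (hypothetical IR, not taken from the tests): the
/// first write below is dead because the second write overwrites exactly the
/// same elements before any read can observe them:
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad
///       : memref<4x8xf32>, vector<4x8xf32>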
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.source().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
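///
/// Illustrative example (hypothetical IR, not taken from the tests): the
/// transfer_read below can reuse %v directly since the dominating write
/// stores the same vector type at the same indices and no other write to
/// %buf can intervene:
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// After forwarding, all uses of %r are replaced by %v and the transfer_read
/// is erased.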
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.source().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.vector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      0, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
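///
/// Illustrative example (hypothetical shapes): for an input of type
/// memref<1x4x1x8xf32> this builds
///   memref.subview %input[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>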
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
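///
/// Illustrative example (hypothetical IR, not taken from the tests):
///   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<1x4x8xf32>, vector<4x8xf32>
/// is rewritten to roughly:
///   %sv = memref.subview %src[0, 0, 0] [1, 4, 8] [1, 1, 1]
///       : memref<1x4x8xf32> to memref<4x8xf32>
///   %v = vector.transfer_read %sv[%c0, %c0], %pad
///       : memref<4x8xf32>, vector<4x8xf32>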
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.permutation_map().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
/// unit dims, by inserting a memref.subview dropping those unit dims.
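///
/// Illustrative example (hypothetical IR, not taken from the tests):
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, memref<1x4x8xf32>
/// is rewritten to roughly:
///   %sv = memref.subview %dst[0, 0, 0] [1, 4, 8] [1, 1, 1]
///       : memref<1x4x8xf32> to memref<4x8xf32>
///   vector.transfer_write %v, %sv[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>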
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.permutation_map().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
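///
/// Illustrative example (hypothetical shapes): an input of type
/// memref<4x8xf32> is collapsed with
///   memref.collapse_shape %input [[0, 1]] : memref<4x8xf32> into memref<32xf32>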
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
  Value rankReducedInput =
      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
  ShapedType rankReducedInputType =
      rankReducedInput.getType().cast<ShapedType>();
  if (rankReducedInputType.getRank() == 1)
    return rankReducedInput;
  ReassociationIndices indices;
  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
    indices.push_back(i);
  return rewriter.create<memref::CollapseShapeOp>(
      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
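///
/// Illustrative example (hypothetical IR, not taken from the tests):
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<4x8xf32>, vector<4x8xf32>
/// is rewritten to roughly:
///   %src1d = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v1d = vector.transfer_read %src1d[%c0], %pad
///       : memref<32xf32>, vector<32xf32>
///   %v = vector.shape_cast %v1d : vector<32xf32> to vector<4x8xf32>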
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.permutation_map().isMinorIdentity())
      return failure();
    if (transferReadOp.mask())
      return failure();
    if (llvm::any_of(transferReadOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value read1d = rewriter.create<vector::TransferReadOp>(
        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), read1d);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
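///
/// Illustrative example (hypothetical IR, not taken from the tests):
///   vector.transfer_write %v, %dst[%c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, memref<4x8xf32>
/// is rewritten to roughly:
///   %dst1d = memref.collapse_shape %dst [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v1d = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %v1d, %dst1d[%c0] : vector<32xf32>, memref<32xf32>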
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.permutation_map().isMinorIdentity())
      return failure();
    if (transferWriteOp.mask())
      return failure();
    if (llvm::any_of(transferWriteOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value vector1d =
        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
                                             ValueRange{c0}, identityMap1D);
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

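/// Entry point of the transfer op flow optimization: store-to-load forwarding
/// runs first since it can expose additional dead stores, then dead store
/// elimination, over all vector transfer ops on memrefs nested under `rootOp`.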
void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}