//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
13a70aa7bbSRiver Riddle #include "PassDetail.h"
14a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
17a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
18a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
19a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h"
20ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
21a70aa7bbSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
22a70aa7bbSRiver Riddle #include "mlir/IR/Builders.h"
23a70aa7bbSRiver Riddle #include "mlir/Transforms/Passes.h"
24a70aa7bbSRiver Riddle #include "llvm/ADT/DenseMap.h"
25a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
26a70aa7bbSRiver Riddle 
27a70aa7bbSRiver Riddle #define DEBUG_TYPE "affine-pipeline-data-transfer"
28a70aa7bbSRiver Riddle 
29a70aa7bbSRiver Riddle using namespace mlir;
30a70aa7bbSRiver Riddle 
31a70aa7bbSRiver Riddle namespace {
32a70aa7bbSRiver Riddle struct PipelineDataTransfer
33a70aa7bbSRiver Riddle     : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
34a70aa7bbSRiver Riddle   void runOnOperation() override;
35a70aa7bbSRiver Riddle   void runOnAffineForOp(AffineForOp forOp);
36a70aa7bbSRiver Riddle 
37a70aa7bbSRiver Riddle   std::vector<AffineForOp> forOps;
38a70aa7bbSRiver Riddle };
39a70aa7bbSRiver Riddle 
40a70aa7bbSRiver Riddle } // namespace
41a70aa7bbSRiver Riddle 
42a70aa7bbSRiver Riddle /// Creates a pass to pipeline explicit movement of data across levels of the
43a70aa7bbSRiver Riddle /// memory hierarchy.
4458ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createPipelineDataTransferPass()4558ceae95SRiver Riddle mlir::createPipelineDataTransferPass() {
46a70aa7bbSRiver Riddle   return std::make_unique<PipelineDataTransfer>();
47a70aa7bbSRiver Riddle }
48a70aa7bbSRiver Riddle 
49a70aa7bbSRiver Riddle // Returns the position of the tag memref operand given a DMA operation.
50a70aa7bbSRiver Riddle // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
51a70aa7bbSRiver Riddle // added.  TODO
getTagMemRefPos(Operation & dmaOp)52a70aa7bbSRiver Riddle static unsigned getTagMemRefPos(Operation &dmaOp) {
53a70aa7bbSRiver Riddle   assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
54a70aa7bbSRiver Riddle   if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
55a70aa7bbSRiver Riddle     return dmaStartOp.getTagMemRefOperandIndex();
56a70aa7bbSRiver Riddle   }
57a70aa7bbSRiver Riddle   // First operand for a dma finish operation.
58a70aa7bbSRiver Riddle   return 0;
59a70aa7bbSRiver Riddle }
60a70aa7bbSRiver Riddle 
61a70aa7bbSRiver Riddle /// Doubles the buffer of the supplied memref on the specified 'affine.for'
62a70aa7bbSRiver Riddle /// operation by adding a leading dimension of size two to the memref.
63a70aa7bbSRiver Riddle /// Replaces all uses of the old memref by the new one while indexing the newly
64a70aa7bbSRiver Riddle /// added dimension by the loop IV of the specified 'affine.for' operation
65a70aa7bbSRiver Riddle /// modulo 2. Returns false if such a replacement cannot be performed.
doubleBuffer(Value oldMemRef,AffineForOp forOp)66a70aa7bbSRiver Riddle static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
67a70aa7bbSRiver Riddle   auto *forBody = forOp.getBody();
68a70aa7bbSRiver Riddle   OpBuilder bInner(forBody, forBody->begin());
69a70aa7bbSRiver Riddle 
70a70aa7bbSRiver Riddle   // Doubles the shape with a leading dimension extent of 2.
71a70aa7bbSRiver Riddle   auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
72a70aa7bbSRiver Riddle     // Add the leading dimension in the shape for the double buffer.
73a70aa7bbSRiver Riddle     ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
74a70aa7bbSRiver Riddle     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
75a70aa7bbSRiver Riddle     newShape[0] = 2;
76a70aa7bbSRiver Riddle     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
77a70aa7bbSRiver Riddle     return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
78a70aa7bbSRiver Riddle   };
79a70aa7bbSRiver Riddle 
80a70aa7bbSRiver Riddle   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
81a70aa7bbSRiver Riddle   auto newMemRefType = doubleShape(oldMemRefType);
82a70aa7bbSRiver Riddle 
83a70aa7bbSRiver Riddle   // The double buffer is allocated right before 'forOp'.
84a70aa7bbSRiver Riddle   OpBuilder bOuter(forOp);
85a70aa7bbSRiver Riddle   // Put together alloc operands for any dynamic dimensions of the memref.
86a70aa7bbSRiver Riddle   SmallVector<Value, 4> allocOperands;
87a70aa7bbSRiver Riddle   for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
88a70aa7bbSRiver Riddle     if (dim.value() == ShapedType::kDynamicSize)
89a70aa7bbSRiver Riddle       allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
90a70aa7bbSRiver Riddle           forOp.getLoc(), oldMemRef, dim.index()));
91a70aa7bbSRiver Riddle   }
92a70aa7bbSRiver Riddle 
93a70aa7bbSRiver Riddle   // Create and place the alloc right before the 'affine.for' operation.
94a70aa7bbSRiver Riddle   Value newMemRef = bOuter.create<memref::AllocOp>(
95a70aa7bbSRiver Riddle       forOp.getLoc(), newMemRefType, allocOperands);
96a70aa7bbSRiver Riddle 
97a70aa7bbSRiver Riddle   // Create 'iv mod 2' value to index the leading dimension.
98a70aa7bbSRiver Riddle   auto d0 = bInner.getAffineDimExpr(0);
99a70aa7bbSRiver Riddle   int64_t step = forOp.getStep();
100a70aa7bbSRiver Riddle   auto modTwoMap =
101a70aa7bbSRiver Riddle       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
102a70aa7bbSRiver Riddle   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
103a70aa7bbSRiver Riddle                                                  forOp.getInductionVar());
104a70aa7bbSRiver Riddle 
105a70aa7bbSRiver Riddle   // replaceAllMemRefUsesWith will succeed unless the forOp body has
106a70aa7bbSRiver Riddle   // non-dereferencing uses of the memref (dealloc's are fine though).
107a70aa7bbSRiver Riddle   if (failed(replaceAllMemRefUsesWith(
108a70aa7bbSRiver Riddle           oldMemRef, newMemRef,
109a70aa7bbSRiver Riddle           /*extraIndices=*/{ivModTwoOp},
110a70aa7bbSRiver Riddle           /*indexRemap=*/AffineMap(),
111a70aa7bbSRiver Riddle           /*extraOperands=*/{},
112a70aa7bbSRiver Riddle           /*symbolOperands=*/{},
113a70aa7bbSRiver Riddle           /*domOpFilter=*/&*forOp.getBody()->begin()))) {
114a70aa7bbSRiver Riddle     LLVM_DEBUG(
115a70aa7bbSRiver Riddle         forOp.emitError("memref replacement for double buffering failed"));
116a70aa7bbSRiver Riddle     ivModTwoOp.erase();
117a70aa7bbSRiver Riddle     return false;
118a70aa7bbSRiver Riddle   }
119a70aa7bbSRiver Riddle   // Insert the dealloc op right after the for loop.
120a70aa7bbSRiver Riddle   bOuter.setInsertionPointAfter(forOp);
121a70aa7bbSRiver Riddle   bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
122a70aa7bbSRiver Riddle 
123a70aa7bbSRiver Riddle   return true;
124a70aa7bbSRiver Riddle }
125a70aa7bbSRiver Riddle 
126a70aa7bbSRiver Riddle /// Returns success if the IR is in a valid state.
runOnOperation()127a70aa7bbSRiver Riddle void PipelineDataTransfer::runOnOperation() {
128a70aa7bbSRiver Riddle   // Do a post order walk so that inner loop DMAs are processed first. This is
129a70aa7bbSRiver Riddle   // necessary since 'affine.for' operations nested within would otherwise
130a70aa7bbSRiver Riddle   // become invalid (erased) when the outer loop is pipelined (the pipelined one
131a70aa7bbSRiver Riddle   // gets deleted and replaced by a prologue, a new steady-state loop and an
132a70aa7bbSRiver Riddle   // epilogue).
133a70aa7bbSRiver Riddle   forOps.clear();
134a70aa7bbSRiver Riddle   getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
135a70aa7bbSRiver Riddle   for (auto forOp : forOps)
136a70aa7bbSRiver Riddle     runOnAffineForOp(forOp);
137a70aa7bbSRiver Riddle }
138a70aa7bbSRiver Riddle 
139a70aa7bbSRiver Riddle // Check if tags of the dma start op and dma wait op match.
checkTagMatch(AffineDmaStartOp startOp,AffineDmaWaitOp waitOp)140a70aa7bbSRiver Riddle static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
141a70aa7bbSRiver Riddle   if (startOp.getTagMemRef() != waitOp.getTagMemRef())
142a70aa7bbSRiver Riddle     return false;
143a70aa7bbSRiver Riddle   auto startIndices = startOp.getTagIndices();
144a70aa7bbSRiver Riddle   auto waitIndices = waitOp.getTagIndices();
145a70aa7bbSRiver Riddle   // Both of these have the same number of indices since they correspond to the
146a70aa7bbSRiver Riddle   // same tag memref.
147a70aa7bbSRiver Riddle   for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
148a70aa7bbSRiver Riddle             e = startIndices.end();
149a70aa7bbSRiver Riddle        it != e; ++it, ++wIt) {
150a70aa7bbSRiver Riddle     // Keep it simple for now, just checking if indices match.
151a70aa7bbSRiver Riddle     // TODO: this would in general need to check if there is no
152a70aa7bbSRiver Riddle     // intervening write writing to the same tag location, i.e., memory last
153a70aa7bbSRiver Riddle     // write/data flow analysis. This is however sufficient/powerful enough for
154a70aa7bbSRiver Riddle     // now since the DMA generation pass or the input for it will always have
155a70aa7bbSRiver Riddle     // start/wait with matching tags (same SSA operand indices).
156a70aa7bbSRiver Riddle     if (*it != *wIt)
157a70aa7bbSRiver Riddle       return false;
158a70aa7bbSRiver Riddle   }
159a70aa7bbSRiver Riddle   return true;
160a70aa7bbSRiver Riddle }
161a70aa7bbSRiver Riddle 
162a70aa7bbSRiver Riddle // Identify matching DMA start/finish operations to overlap computation with.
findMatchingStartFinishInsts(AffineForOp forOp,SmallVectorImpl<std::pair<Operation *,Operation * >> & startWaitPairs)163a70aa7bbSRiver Riddle static void findMatchingStartFinishInsts(
164a70aa7bbSRiver Riddle     AffineForOp forOp,
165a70aa7bbSRiver Riddle     SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
166a70aa7bbSRiver Riddle 
167a70aa7bbSRiver Riddle   // Collect outgoing DMA operations - needed to check for dependences below.
168a70aa7bbSRiver Riddle   SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
169a70aa7bbSRiver Riddle   for (auto &op : *forOp.getBody()) {
170a70aa7bbSRiver Riddle     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
171a70aa7bbSRiver Riddle     if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
172a70aa7bbSRiver Riddle       outgoingDmaOps.push_back(dmaStartOp);
173a70aa7bbSRiver Riddle   }
174a70aa7bbSRiver Riddle 
175a70aa7bbSRiver Riddle   SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
176a70aa7bbSRiver Riddle   for (auto &op : *forOp.getBody()) {
177a70aa7bbSRiver Riddle     // Collect DMA finish operations.
178a70aa7bbSRiver Riddle     if (isa<AffineDmaWaitOp>(op)) {
179a70aa7bbSRiver Riddle       dmaFinishInsts.push_back(&op);
180a70aa7bbSRiver Riddle       continue;
181a70aa7bbSRiver Riddle     }
182a70aa7bbSRiver Riddle     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
183a70aa7bbSRiver Riddle     if (!dmaStartOp)
184a70aa7bbSRiver Riddle       continue;
185a70aa7bbSRiver Riddle 
186a70aa7bbSRiver Riddle     // Only DMAs incoming into higher memory spaces are pipelined for now.
187a70aa7bbSRiver Riddle     // TODO: handle outgoing DMA pipelining.
188a70aa7bbSRiver Riddle     if (!dmaStartOp.isDestMemorySpaceFaster())
189a70aa7bbSRiver Riddle       continue;
190a70aa7bbSRiver Riddle 
191a70aa7bbSRiver Riddle     // Check for dependence with outgoing DMAs. Doing this conservatively.
192a70aa7bbSRiver Riddle     // TODO: use the dependence analysis to check for
193a70aa7bbSRiver Riddle     // dependences between an incoming and outgoing DMA in the same iteration.
194a70aa7bbSRiver Riddle     auto *it = outgoingDmaOps.begin();
195a70aa7bbSRiver Riddle     for (; it != outgoingDmaOps.end(); ++it) {
196a70aa7bbSRiver Riddle       if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
197a70aa7bbSRiver Riddle         break;
198a70aa7bbSRiver Riddle     }
199a70aa7bbSRiver Riddle     if (it != outgoingDmaOps.end())
200a70aa7bbSRiver Riddle       continue;
201a70aa7bbSRiver Riddle 
202a70aa7bbSRiver Riddle     // We only double buffer if the buffer is not live out of loop.
203a70aa7bbSRiver Riddle     auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
204a70aa7bbSRiver Riddle     bool escapingUses = false;
205a70aa7bbSRiver Riddle     for (auto *user : memref.getUsers()) {
206a70aa7bbSRiver Riddle       // We can double buffer regardless of dealloc's outside the loop.
207a70aa7bbSRiver Riddle       if (isa<memref::DeallocOp>(user))
208a70aa7bbSRiver Riddle         continue;
209a70aa7bbSRiver Riddle       if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
210a70aa7bbSRiver Riddle         LLVM_DEBUG(llvm::dbgs()
211a70aa7bbSRiver Riddle                        << "can't pipeline: buffer is live out of loop\n";);
212a70aa7bbSRiver Riddle         escapingUses = true;
213a70aa7bbSRiver Riddle         break;
214a70aa7bbSRiver Riddle       }
215a70aa7bbSRiver Riddle     }
216a70aa7bbSRiver Riddle     if (!escapingUses)
217a70aa7bbSRiver Riddle       dmaStartInsts.push_back(&op);
218a70aa7bbSRiver Riddle   }
219a70aa7bbSRiver Riddle 
220a70aa7bbSRiver Riddle   // For each start operation, we look for a matching finish operation.
221a70aa7bbSRiver Riddle   for (auto *dmaStartOp : dmaStartInsts) {
222a70aa7bbSRiver Riddle     for (auto *dmaFinishOp : dmaFinishInsts) {
223a70aa7bbSRiver Riddle       if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
224a70aa7bbSRiver Riddle                         cast<AffineDmaWaitOp>(dmaFinishOp))) {
225a70aa7bbSRiver Riddle         startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
226a70aa7bbSRiver Riddle         break;
227a70aa7bbSRiver Riddle       }
228a70aa7bbSRiver Riddle     }
229a70aa7bbSRiver Riddle   }
230a70aa7bbSRiver Riddle }
231a70aa7bbSRiver Riddle 
232a70aa7bbSRiver Riddle /// Overlap DMA transfers with computation in this loop. If successful,
233a70aa7bbSRiver Riddle /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
234a70aa7bbSRiver Riddle /// inserted right before where it was.
runOnAffineForOp(AffineForOp forOp)235a70aa7bbSRiver Riddle void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
236a70aa7bbSRiver Riddle   auto mayBeConstTripCount = getConstantTripCount(forOp);
237*037f0995SKazu Hirata   if (!mayBeConstTripCount) {
238a70aa7bbSRiver Riddle     LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
239a70aa7bbSRiver Riddle     return;
240a70aa7bbSRiver Riddle   }
241a70aa7bbSRiver Riddle 
242a70aa7bbSRiver Riddle   SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
243a70aa7bbSRiver Riddle   findMatchingStartFinishInsts(forOp, startWaitPairs);
244a70aa7bbSRiver Riddle 
245a70aa7bbSRiver Riddle   if (startWaitPairs.empty()) {
246a70aa7bbSRiver Riddle     LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
247a70aa7bbSRiver Riddle     return;
248a70aa7bbSRiver Riddle   }
249a70aa7bbSRiver Riddle 
250a70aa7bbSRiver Riddle   // Double the buffers for the higher memory space memref's.
251a70aa7bbSRiver Riddle   // Identify memref's to replace by scanning through all DMA start
252a70aa7bbSRiver Riddle   // operations. A DMA start operation has two memref's - the one from the
253a70aa7bbSRiver Riddle   // higher level of memory hierarchy is the one to double buffer.
254a70aa7bbSRiver Riddle   // TODO: check whether double-buffering is even necessary.
255a70aa7bbSRiver Riddle   // TODO: make this work with different layouts: assuming here that
256a70aa7bbSRiver Riddle   // the dimension we are adding here for the double buffering is the outermost
257a70aa7bbSRiver Riddle   // dimension.
258a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
259a70aa7bbSRiver Riddle     auto *dmaStartOp = pair.first;
260a70aa7bbSRiver Riddle     Value oldMemRef = dmaStartOp->getOperand(
261a70aa7bbSRiver Riddle         cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
262a70aa7bbSRiver Riddle     if (!doubleBuffer(oldMemRef, forOp)) {
263a70aa7bbSRiver Riddle       // Normally, double buffering should not fail because we already checked
264a70aa7bbSRiver Riddle       // that there are no uses outside.
265a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs()
266a70aa7bbSRiver Riddle                      << "double buffering failed for" << dmaStartOp << "\n";);
267a70aa7bbSRiver Riddle       // IR still valid and semantically correct.
268a70aa7bbSRiver Riddle       return;
269a70aa7bbSRiver Riddle     }
270a70aa7bbSRiver Riddle     // If the old memref has no more uses, remove its 'dead' alloc if it was
271a70aa7bbSRiver Riddle     // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
272a70aa7bbSRiver Riddle     // operation could have been used on it if it was dynamically shaped in
273a70aa7bbSRiver Riddle     // order to create the double buffer above.)
274a70aa7bbSRiver Riddle     // '-canonicalize' does this in a more general way, but we'll anyway do the
275a70aa7bbSRiver Riddle     // simple/common case so that the output / test cases looks clear.
276a70aa7bbSRiver Riddle     if (auto *allocOp = oldMemRef.getDefiningOp()) {
277a70aa7bbSRiver Riddle       if (oldMemRef.use_empty()) {
278a70aa7bbSRiver Riddle         allocOp->erase();
279a70aa7bbSRiver Riddle       } else if (oldMemRef.hasOneUse()) {
280a70aa7bbSRiver Riddle         if (auto dealloc =
281a70aa7bbSRiver Riddle                 dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
282a70aa7bbSRiver Riddle           dealloc.erase();
283a70aa7bbSRiver Riddle           allocOp->erase();
284a70aa7bbSRiver Riddle         }
285a70aa7bbSRiver Riddle       }
286a70aa7bbSRiver Riddle     }
287a70aa7bbSRiver Riddle   }
288a70aa7bbSRiver Riddle 
289a70aa7bbSRiver Riddle   // Double the buffers for tag memrefs.
290a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
291a70aa7bbSRiver Riddle     auto *dmaFinishOp = pair.second;
292a70aa7bbSRiver Riddle     Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
293a70aa7bbSRiver Riddle     if (!doubleBuffer(oldTagMemRef, forOp)) {
294a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
295a70aa7bbSRiver Riddle       return;
296a70aa7bbSRiver Riddle     }
297a70aa7bbSRiver Riddle     // If the old tag has no uses or a single dealloc use, remove it.
298a70aa7bbSRiver Riddle     // (canonicalization handles more complex cases).
299a70aa7bbSRiver Riddle     if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
300a70aa7bbSRiver Riddle       if (oldTagMemRef.use_empty()) {
301a70aa7bbSRiver Riddle         tagAllocOp->erase();
302a70aa7bbSRiver Riddle       } else if (oldTagMemRef.hasOneUse()) {
303a70aa7bbSRiver Riddle         if (auto dealloc =
304a70aa7bbSRiver Riddle                 dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
305a70aa7bbSRiver Riddle           dealloc.erase();
306a70aa7bbSRiver Riddle           tagAllocOp->erase();
307a70aa7bbSRiver Riddle         }
308a70aa7bbSRiver Riddle       }
309a70aa7bbSRiver Riddle     }
310a70aa7bbSRiver Riddle   }
311a70aa7bbSRiver Riddle 
312a70aa7bbSRiver Riddle   // Double buffering would have invalidated all the old DMA start/wait insts.
313a70aa7bbSRiver Riddle   startWaitPairs.clear();
314a70aa7bbSRiver Riddle   findMatchingStartFinishInsts(forOp, startWaitPairs);
315a70aa7bbSRiver Riddle 
316a70aa7bbSRiver Riddle   // Store shift for operation for later lookup for AffineApplyOp's.
317a70aa7bbSRiver Riddle   DenseMap<Operation *, unsigned> instShiftMap;
318a70aa7bbSRiver Riddle   for (auto &pair : startWaitPairs) {
319a70aa7bbSRiver Riddle     auto *dmaStartOp = pair.first;
320a70aa7bbSRiver Riddle     assert(isa<AffineDmaStartOp>(dmaStartOp));
321a70aa7bbSRiver Riddle     instShiftMap[dmaStartOp] = 0;
322a70aa7bbSRiver Riddle     // Set shifts for DMA start op's affine operand computation slices to 0.
323a70aa7bbSRiver Riddle     SmallVector<AffineApplyOp, 4> sliceOps;
324a70aa7bbSRiver Riddle     mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
325a70aa7bbSRiver Riddle     if (!sliceOps.empty()) {
326a70aa7bbSRiver Riddle       for (auto sliceOp : sliceOps) {
327a70aa7bbSRiver Riddle         instShiftMap[sliceOp.getOperation()] = 0;
328a70aa7bbSRiver Riddle       }
329a70aa7bbSRiver Riddle     } else {
330a70aa7bbSRiver Riddle       // If a slice wasn't created, the reachable affine.apply op's from its
331a70aa7bbSRiver Riddle       // operands are the ones that go with it.
332a70aa7bbSRiver Riddle       SmallVector<Operation *, 4> affineApplyInsts;
333a70aa7bbSRiver Riddle       SmallVector<Value, 4> operands(dmaStartOp->getOperands());
334a70aa7bbSRiver Riddle       getReachableAffineApplyOps(operands, affineApplyInsts);
335a70aa7bbSRiver Riddle       for (auto *op : affineApplyInsts) {
336a70aa7bbSRiver Riddle         instShiftMap[op] = 0;
337a70aa7bbSRiver Riddle       }
338a70aa7bbSRiver Riddle     }
339a70aa7bbSRiver Riddle   }
340a70aa7bbSRiver Riddle   // Everything else (including compute ops and dma finish) are shifted by one.
341a70aa7bbSRiver Riddle   for (auto &op : forOp.getBody()->without_terminator())
342a70aa7bbSRiver Riddle     if (instShiftMap.find(&op) == instShiftMap.end())
343a70aa7bbSRiver Riddle       instShiftMap[&op] = 1;
344a70aa7bbSRiver Riddle 
345a70aa7bbSRiver Riddle   // Get shifts stored in map.
346a70aa7bbSRiver Riddle   SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
347a70aa7bbSRiver Riddle   unsigned s = 0;
348a70aa7bbSRiver Riddle   for (auto &op : forOp.getBody()->without_terminator()) {
349a70aa7bbSRiver Riddle     assert(instShiftMap.find(&op) != instShiftMap.end());
350a70aa7bbSRiver Riddle     shifts[s++] = instShiftMap[&op];
351a70aa7bbSRiver Riddle 
352a70aa7bbSRiver Riddle     // Tagging operations with shifts for debugging purposes.
353a70aa7bbSRiver Riddle     LLVM_DEBUG({
354a70aa7bbSRiver Riddle       OpBuilder b(&op);
355a70aa7bbSRiver Riddle       op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
356a70aa7bbSRiver Riddle     });
357a70aa7bbSRiver Riddle   }
358a70aa7bbSRiver Riddle 
359a70aa7bbSRiver Riddle   if (!isOpwiseShiftValid(forOp, shifts)) {
360a70aa7bbSRiver Riddle     // Violates dependences.
361a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
362a70aa7bbSRiver Riddle     return;
363a70aa7bbSRiver Riddle   }
364a70aa7bbSRiver Riddle 
365a70aa7bbSRiver Riddle   if (failed(affineForOpBodySkew(forOp, shifts))) {
366a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
367a70aa7bbSRiver Riddle     return;
368a70aa7bbSRiver Riddle   }
369a70aa7bbSRiver Riddle }
370