1 //===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to pipeline data transfers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/LoopUtils.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/Transforms/Passes.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "affine-pipeline-data-transfer"
28 
29 using namespace mlir;
30 
31 namespace {
32 struct PipelineDataTransfer
33     : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
34   void runOnOperation() override;
35   void runOnAffineForOp(AffineForOp forOp);
36 
37   std::vector<AffineForOp> forOps;
38 };
39 
40 } // namespace
41 
42 /// Creates a pass to pipeline explicit movement of data across levels of the
43 /// memory hierarchy.
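43 /// The pass is registered with mlir-opt as -affine-pipeline-data-transfer; it
43 /// can also be scheduled programmatically, e.g. (illustrative, 'pm' is a
43 /// hypothetical pass manager):
43 ///
43 ///   pm.addNestedPass<FuncOp>(mlir::createPipelineDataTransferPass());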
44 std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
45   return std::make_unique<PipelineDataTransfer>();
46 }
47 
48 // Returns the position of the tag memref operand given a DMA operation.
49 // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
50 // added.  TODO
51 static unsigned getTagMemRefPos(Operation &dmaOp) {
52   assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
53   if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
54     return dmaStartOp.getTagMemRefOperandIndex();
55   }
56   // First operand for a dma finish operation.
57   return 0;
58 }
59 
60 /// Doubles the buffer of the supplied memref on the specified 'affine.for'
61 /// operation by adding a leading dimension of size two to the memref.
62 /// Replaces all uses of the old memref by the new one while indexing the newly
63 /// added dimension by the loop IV of the specified 'affine.for' operation
64 /// modulo 2. Returns false if such a replacement cannot be performed.
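64 ///
64 /// Illustrative sketch (shapes and SSA names are hypothetical): a buffer
64 /// 'memref<256xf32, 1>' accessed in a unit-step loop becomes
64 /// 'memref<2x256xf32, 1>', and every access gains a leading '%iv mod 2'
64 /// subscript:
64 ///
64 ///   %mod = affine.apply affine_map<(d0) -> (d0 mod 2)>(%iv)
64 ///   %v = affine.load %buf_db[%mod, %i] : memref<2x256xf32, 1>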
65 static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
66   auto *forBody = forOp.getBody();
67   OpBuilder bInner(forBody, forBody->begin());
68 
69   // Doubles the shape with a leading dimension extent of 2.
70   auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
71     // Add the leading dimension in the shape for the double buffer.
72     ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
73     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
74     newShape[0] = 2;
75     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
76     return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
77   };
78 
79   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
80   auto newMemRefType = doubleShape(oldMemRefType);
81 
82   // The double buffer is allocated right before 'forOp'.
83   OpBuilder bOuter(forOp);
84   // Put together alloc operands for any dynamic dimensions of the memref.
85   SmallVector<Value, 4> allocOperands;
86   for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
87     if (dim.value() == ShapedType::kDynamicSize)
88       allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
89           forOp.getLoc(), oldMemRef, dim.index()));
90   }
91 
92   // Create and place the alloc right before the 'affine.for' operation.
93   Value newMemRef = bOuter.create<memref::AllocOp>(
94       forOp.getLoc(), newMemRefType, allocOperands);
95 
96   // Create the '(iv floordiv step) mod 2' value to index the leading
96   // dimension.
97   auto d0 = bInner.getAffineDimExpr(0);
98   int64_t step = forOp.getStep();
99   auto modTwoMap =
100       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
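100   // For a unit-step loop this is simply the map (d0) -> (d0 mod 2).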
101   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
102                                                  forOp.getInductionVar());
103 
104   // replaceAllMemRefUsesWith will succeed unless the forOp body has
105   // non-dereferencing uses of the memref (dealloc's are fine though).
106   if (failed(replaceAllMemRefUsesWith(
107           oldMemRef, newMemRef,
108           /*extraIndices=*/{ivModTwoOp},
109           /*indexRemap=*/AffineMap(),
110           /*extraOperands=*/{},
111           /*symbolOperands=*/{},
112           /*domOpFilter=*/&*forOp.getBody()->begin()))) {
113     LLVM_DEBUG(
114         forOp.emitError("memref replacement for double buffering failed"));
115     ivModTwoOp.erase();
116     return false;
117   }
118   // Insert the dealloc op right after the for loop.
119   bOuter.setInsertionPointAfter(forOp);
120   bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
121 
122   return true;
123 }
124 
125 /// Pipelines data transfers in each 'affine.for' op of the function.
126 void PipelineDataTransfer::runOnOperation() {
127   // Do a post order walk so that inner loop DMAs are processed first. This is
128   // necessary since 'affine.for' operations nested within would otherwise
129   // become invalid (erased) when the outer loop is pipelined (the pipelined one
130   // gets deleted and replaced by a prologue, a new steady-state loop and an
131   // epilogue).
132   forOps.clear();
133   getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
134   for (auto forOp : forOps)
135     runOnAffineForOp(forOp);
136 }
137 
138 // Check if tags of the dma start op and dma wait op match.
139 static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
140   if (startOp.getTagMemRef() != waitOp.getTagMemRef())
141     return false;
142   auto startIndices = startOp.getTagIndices();
143   auto waitIndices = waitOp.getTagIndices();
144   // Both of these have the same number of indices since they correspond to the
145   // same tag memref.
146   for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
147             e = startIndices.end();
148        it != e; ++it, ++wIt) {
149     // Keep it simple for now, just checking if indices match.
150     // TODO: this would in general need to check if there is no
151     // intervening write writing to the same tag location, i.e., memory last
152     // write/data flow analysis. This is however sufficient/powerful enough for
153     // now since the DMA generation pass or the input for it will always have
154     // start/wait with matching tags (same SSA operand indices).
155     if (*it != *wIt)
156       return false;
157   }
158   return true;
159 }
160 
161 // Identify matching DMA start/finish operations to overlap computation with.
162 static void findMatchingStartFinishInsts(
163     AffineForOp forOp,
164     SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
165 
166   // Collect outgoing DMA operations - needed to check for dependences below.
167   SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
168   for (auto &op : *forOp.getBody()) {
169     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
170     if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
171       outgoingDmaOps.push_back(dmaStartOp);
172   }
173 
174   SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
175   for (auto &op : *forOp.getBody()) {
176     // Collect DMA finish operations.
177     if (isa<AffineDmaWaitOp>(op)) {
178       dmaFinishInsts.push_back(&op);
179       continue;
180     }
181     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
182     if (!dmaStartOp)
183       continue;
184 
185     // Only DMAs incoming into higher memory spaces are pipelined for now.
186     // TODO: handle outgoing DMA pipelining.
187     if (!dmaStartOp.isDestMemorySpaceFaster())
188       continue;
189 
190     // Check for dependence with outgoing DMAs. Doing this conservatively.
191     // TODO: use the dependence analysis to check for
192     // dependences between an incoming and outgoing DMA in the same iteration.
193     auto *it = outgoingDmaOps.begin();
194     for (; it != outgoingDmaOps.end(); ++it) {
195       if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
196         break;
197     }
198     if (it != outgoingDmaOps.end())
199       continue;
200 
201     // We only double buffer if the buffer is not live out of the loop.
202     auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
203     bool escapingUses = false;
204     for (auto *user : memref.getUsers()) {
205       // We can double buffer regardless of dealloc's outside the loop.
206       if (isa<memref::DeallocOp>(user))
207         continue;
208       if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
209         LLVM_DEBUG(llvm::dbgs()
210                        << "can't pipeline: buffer is live out of loop\n";);
211         escapingUses = true;
212         break;
213       }
214     }
215     if (!escapingUses)
216       dmaStartInsts.push_back(&op);
217   }
218 
219   // For each start operation, we look for a matching finish operation.
220   for (auto *dmaStartOp : dmaStartInsts) {
221     for (auto *dmaFinishOp : dmaFinishInsts) {
222       if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
223                         cast<AffineDmaWaitOp>(dmaFinishOp))) {
224         startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
225         break;
226       }
227     }
228   }
229 }
230 
231 /// Overlap DMA transfers with computation in this loop. If successful,
232 /// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
233 /// are inserted right before where it was.
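233 ///
233 /// Schematically (a hypothetical sketch, not the literal output), a loop of
233 /// the form
233 ///
233 ///   affine.for %i = 0 to %N {
233 ///     dma_start(%i); dma_wait(%i); compute(%i)
233 ///   }
233 ///
233 /// is transformed so that, in the steady-state loop, the DMA for iteration
233 /// %i + 1 is issued while compute(%i) runs on the other half of the doubled
233 /// buffer; the prologue issues the first transfer and the epilogue drains
233 /// the last one.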
234 void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
235   auto mayBeConstTripCount = getConstantTripCount(forOp);
236   if (!mayBeConstTripCount.hasValue()) {
237     LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
238     return;
239   }
240 
241   SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
242   findMatchingStartFinishInsts(forOp, startWaitPairs);
243 
244   if (startWaitPairs.empty()) {
245     LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
246     return;
247   }
248 
249   // Double the buffers for the higher memory space memref's.
250   // Identify memref's to replace by scanning through all DMA start
251   // operations. A DMA start operation has two memref's - the one from the
252   // higher level of memory hierarchy is the one to double buffer.
253   // TODO: check whether double-buffering is even necessary.
254   // TODO: make this work with different layouts: assuming here that
255   // the dimension we are adding here for the double buffering is the outermost
256   // dimension.
257   for (auto &pair : startWaitPairs) {
258     auto *dmaStartOp = pair.first;
259     Value oldMemRef = dmaStartOp->getOperand(
260         cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
261     if (!doubleBuffer(oldMemRef, forOp)) {
262       // Normally, double buffering should not fail because we already checked
263       // that there are no uses outside.
264       LLVM_DEBUG(llvm::dbgs() << "double buffering failed for " << *dmaStartOp
265                               << "\n";);
266       // IR still valid and semantically correct.
267       return;
268     }
269     // If the old memref has no more uses, remove its 'dead' alloc if it was
270     // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
271     // operation could have been used on it if it was dynamically shaped in
272     // order to create the double buffer above.)
273     // '-canonicalize' does this in a more general way, but we'll anyway do the
274     // simple/common case so that the output / test cases look clear.
275     if (auto *allocOp = oldMemRef.getDefiningOp()) {
276       if (oldMemRef.use_empty()) {
277         allocOp->erase();
278       } else if (oldMemRef.hasOneUse()) {
279         if (auto dealloc =
280                 dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
281           dealloc.erase();
282           allocOp->erase();
283         }
284       }
285     }
286   }
287 
288   // Double the buffers for tag memrefs.
289   for (auto &pair : startWaitPairs) {
290     auto *dmaFinishOp = pair.second;
291     Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
292     if (!doubleBuffer(oldTagMemRef, forOp)) {
293       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
294       return;
295     }
296     // If the old tag has no uses or a single dealloc use, remove it.
297     // (canonicalization handles more complex cases).
298     if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
299       if (oldTagMemRef.use_empty()) {
300         tagAllocOp->erase();
301       } else if (oldTagMemRef.hasOneUse()) {
302         if (auto dealloc =
303                 dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
304           dealloc.erase();
305           tagAllocOp->erase();
306         }
307       }
308     }
309   }
310 
311   // Double buffering would have invalidated all the old DMA start/wait insts.
312   startWaitPairs.clear();
313   findMatchingStartFinishInsts(forOp, startWaitPairs);
314 
315   // Store the shift of each operation for later lookup when handling
315   // AffineApplyOp's.
316   DenseMap<Operation *, unsigned> instShiftMap;
317   for (auto &pair : startWaitPairs) {
318     auto *dmaStartOp = pair.first;
319     assert(isa<AffineDmaStartOp>(dmaStartOp));
320     instShiftMap[dmaStartOp] = 0;
321     // Set shifts for DMA start op's affine operand computation slices to 0.
322     SmallVector<AffineApplyOp, 4> sliceOps;
323     mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
324     if (!sliceOps.empty()) {
325       for (auto sliceOp : sliceOps) {
326         instShiftMap[sliceOp.getOperation()] = 0;
327       }
328     } else {
329       // If a slice wasn't created, the reachable affine.apply op's from its
330       // operands are the ones that go with it.
331       SmallVector<Operation *, 4> affineApplyInsts;
332       SmallVector<Value, 4> operands(dmaStartOp->getOperands());
333       getReachableAffineApplyOps(operands, affineApplyInsts);
334       for (auto *op : affineApplyInsts) {
335         instShiftMap[op] = 0;
336       }
337     }
338   }
339   // Everything else (including compute ops and dma finish) is shifted by one.
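339   // With shifts {0, 1}, the body skew below overlaps the DMA start of
339   // iteration %i with the compute (and DMA wait) of iteration %i - 1.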
340   for (auto &op : forOp.getBody()->without_terminator())
341     if (instShiftMap.find(&op) == instShiftMap.end())
342       instShiftMap[&op] = 1;
343 
344   // Get shifts stored in map.
345   SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
346   unsigned s = 0;
347   for (auto &op : forOp.getBody()->without_terminator()) {
348     assert(instShiftMap.find(&op) != instShiftMap.end());
349     shifts[s++] = instShiftMap[&op];
350 
351     // Tagging operations with shifts for debugging purposes.
352     LLVM_DEBUG({
353       OpBuilder b(&op);
354       op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
355     });
356   }
357 
358   if (!isOpwiseShiftValid(forOp, shifts)) {
359     // Violates dependences.
360     LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
361     return;
362   }
363 
364   if (failed(affineForOpBodySkew(forOp, shifts))) {
365     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
366     return;
367   }
368 }
369