1 //===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to pipeline data transfers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/LoopUtils.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/Transforms/Passes.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "affine-pipeline-data-transfer"
28 
29 using namespace mlir;
30 
31 namespace {
32 struct PipelineDataTransfer
33     : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
34   void runOnOperation() override;
35   void runOnAffineForOp(AffineForOp forOp);
36 
37   std::vector<AffineForOp> forOps;
38 };
39 
40 } // namespace
41 
42 /// Creates a pass to pipeline explicit movement of data across levels of the
43 /// memory hierarchy.
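/// This pass is typically exercised via its registered command-line flag,
/// e.g. `mlir-opt -affine-pipeline-data-transfer` (flag name assumed to
/// match the DEBUG_TYPE above).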
44 std::unique_ptr<OperationPass<func::FuncOp>>
45 mlir::createPipelineDataTransferPass() {
46   return std::make_unique<PipelineDataTransfer>();
47 }
48 
49 // Returns the position of the tag memref operand given a DMA operation.
// TODO: Temporary utility; this will be replaced once DmaStart/DmaFinish
// abstract ops are added.
52 static unsigned getTagMemRefPos(Operation &dmaOp) {
53   assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
54   if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
55     return dmaStartOp.getTagMemRefOperandIndex();
56   }
57   // First operand for a dma finish operation.
58   return 0;
59 }
60 
61 /// Doubles the buffer of the supplied memref on the specified 'affine.for'
62 /// operation by adding a leading dimension of size two to the memref.
63 /// Replaces all uses of the old memref by the new one while indexing the newly
64 /// added dimension by the loop IV of the specified 'affine.for' operation
65 /// modulo 2. Returns false if such a replacement cannot be performed.
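///
/// For illustration (a sketch, not drawn from a specific test case): a
/// buffer of type `memref<256xf32, 2>` accessed as `buf[i]` inside the loop
/// becomes a `memref<2x256xf32, 2>` accessed as
/// `buf[(iv floordiv step) mod 2, i]`.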
66 static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
67   auto *forBody = forOp.getBody();
68   OpBuilder bInner(forBody, forBody->begin());
69 
70   // Doubles the shape with a leading dimension extent of 2.
71   auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
72     // Add the leading dimension in the shape for the double buffer.
73     ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
74     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
75     newShape[0] = 2;
76     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
77     return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
78   };
79 
80   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
81   auto newMemRefType = doubleShape(oldMemRefType);
82 
83   // The double buffer is allocated right before 'forOp'.
84   OpBuilder bOuter(forOp);
85   // Put together alloc operands for any dynamic dimensions of the memref.
86   SmallVector<Value, 4> allocOperands;
87   for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
88     if (dim.value() == ShapedType::kDynamicSize)
89       allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
90           forOp.getLoc(), oldMemRef, dim.index()));
91   }
92 
93   // Create and place the alloc right before the 'affine.for' operation.
94   Value newMemRef = bOuter.create<memref::AllocOp>(
95       forOp.getLoc(), newMemRefType, allocOperands);
96 
97   // Create 'iv mod 2' value to index the leading dimension.
98   auto d0 = bInner.getAffineDimExpr(0);
99   int64_t step = forOp.getStep();
100   auto modTwoMap =
101       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
102   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
103                                                  forOp.getInductionVar());
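  // E.g., for a unit-step loop this is simply affine_map<(d0) -> (d0 mod 2)>
  // applied to the loop IV.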
104 
105   // replaceAllMemRefUsesWith will succeed unless the forOp body has
106   // non-dereferencing uses of the memref (dealloc's are fine though).
107   if (failed(replaceAllMemRefUsesWith(
108           oldMemRef, newMemRef,
109           /*extraIndices=*/{ivModTwoOp},
110           /*indexRemap=*/AffineMap(),
111           /*extraOperands=*/{},
112           /*symbolOperands=*/{},
113           /*domOpFilter=*/&*forOp.getBody()->begin()))) {
114     LLVM_DEBUG(
115         forOp.emitError("memref replacement for double buffering failed"));
116     ivModTwoOp.erase();
117     return false;
118   }
119   // Insert the dealloc op right after the for loop.
120   bOuter.setInsertionPointAfter(forOp);
121   bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
122 
123   return true;
124 }
125 
/// Collects all 'affine.for' ops in post order and attempts to pipeline the
/// DMAs in each of them.
127 void PipelineDataTransfer::runOnOperation() {
128   // Do a post order walk so that inner loop DMAs are processed first. This is
129   // necessary since 'affine.for' operations nested within would otherwise
130   // become invalid (erased) when the outer loop is pipelined (the pipelined one
131   // gets deleted and replaced by a prologue, a new steady-state loop and an
132   // epilogue).
133   forOps.clear();
134   getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
135   for (auto forOp : forOps)
136     runOnAffineForOp(forOp);
137 }
138 
// Checks whether the tags of the DMA start op and the DMA wait op match.
140 static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
141   if (startOp.getTagMemRef() != waitOp.getTagMemRef())
142     return false;
143   auto startIndices = startOp.getTagIndices();
144   auto waitIndices = waitOp.getTagIndices();
145   // Both of these have the same number of indices since they correspond to the
146   // same tag memref.
147   for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
148             e = startIndices.end();
149        it != e; ++it, ++wIt) {
150     // Keep it simple for now, just checking if indices match.
151     // TODO: this would in general need to check if there is no
152     // intervening write writing to the same tag location, i.e., memory last
153     // write/data flow analysis. This is however sufficient/powerful enough for
154     // now since the DMA generation pass or the input for it will always have
155     // start/wait with matching tags (same SSA operand indices).
156     if (*it != *wIt)
157       return false;
158   }
159   return true;
160 }
161 
162 // Identify matching DMA start/finish operations to overlap computation with.
163 static void findMatchingStartFinishInsts(
164     AffineForOp forOp,
165     SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
166 
167   // Collect outgoing DMA operations - needed to check for dependences below.
168   SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
169   for (auto &op : *forOp.getBody()) {
170     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
171     if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
172       outgoingDmaOps.push_back(dmaStartOp);
173   }
174 
175   SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
176   for (auto &op : *forOp.getBody()) {
177     // Collect DMA finish operations.
178     if (isa<AffineDmaWaitOp>(op)) {
179       dmaFinishInsts.push_back(&op);
180       continue;
181     }
182     auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
183     if (!dmaStartOp)
184       continue;
185 
186     // Only DMAs incoming into higher memory spaces are pipelined for now.
187     // TODO: handle outgoing DMA pipelining.
188     if (!dmaStartOp.isDestMemorySpaceFaster())
189       continue;
190 
191     // Check for dependence with outgoing DMAs. Doing this conservatively.
192     // TODO: use the dependence analysis to check for
193     // dependences between an incoming and outgoing DMA in the same iteration.
194     auto *it = outgoingDmaOps.begin();
195     for (; it != outgoingDmaOps.end(); ++it) {
196       if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
197         break;
198     }
199     if (it != outgoingDmaOps.end())
200       continue;
201 
    // We only double buffer if the buffer is not live out of the loop.
203     auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
204     bool escapingUses = false;
205     for (auto *user : memref.getUsers()) {
206       // We can double buffer regardless of dealloc's outside the loop.
207       if (isa<memref::DeallocOp>(user))
208         continue;
209       if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
210         LLVM_DEBUG(llvm::dbgs()
211                        << "can't pipeline: buffer is live out of loop\n";);
212         escapingUses = true;
213         break;
214       }
215     }
216     if (!escapingUses)
217       dmaStartInsts.push_back(&op);
218   }
219 
220   // For each start operation, we look for a matching finish operation.
221   for (auto *dmaStartOp : dmaStartInsts) {
222     for (auto *dmaFinishOp : dmaFinishInsts) {
223       if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
224                         cast<AffineDmaWaitOp>(dmaFinishOp))) {
225         startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
226         break;
227       }
228     }
229   }
230 }
231 
/// Overlaps DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
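///
/// Illustrative sketch of the result (a schematic, assuming a single incoming
/// DMA into a double-buffered 'buf' with tag memref 'tag'):
///
///   prologue:   issue the DMA for iteration 0 into buf[0]
///   steady-state loop (per iteration i):
///               issue the DMA for iteration i+1 into buf[(i+1) mod 2], then
///               wait on tag[i mod 2] and compute on buf[i mod 2]
///   epilogue:   wait on the last DMA and compute on its data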
235 void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
236   auto mayBeConstTripCount = getConstantTripCount(forOp);
237   if (!mayBeConstTripCount.hasValue()) {
238     LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
239     return;
240   }
241 
242   SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
243   findMatchingStartFinishInsts(forOp, startWaitPairs);
244 
245   if (startWaitPairs.empty()) {
246     LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
247     return;
248   }
249 
250   // Double the buffers for the higher memory space memref's.
251   // Identify memref's to replace by scanning through all DMA start
252   // operations. A DMA start operation has two memref's - the one from the
253   // higher level of memory hierarchy is the one to double buffer.
254   // TODO: check whether double-buffering is even necessary.
255   // TODO: make this work with different layouts: assuming here that
256   // the dimension we are adding here for the double buffering is the outermost
257   // dimension.
258   for (auto &pair : startWaitPairs) {
259     auto *dmaStartOp = pair.first;
260     Value oldMemRef = dmaStartOp->getOperand(
261         cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
262     if (!doubleBuffer(oldMemRef, forOp)) {
263       // Normally, double buffering should not fail because we already checked
264       // that there are no uses outside.
      LLVM_DEBUG(llvm::dbgs()
                     << "double buffering failed for " << *dmaStartOp << "\n";);
267       // IR still valid and semantically correct.
268       return;
269     }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (Note: DMA buffers are rarely function live-in, but a 'dim'
    // operation could have been applied to it above to allocate the double
    // buffer if it was dynamically shaped.)
    // '-canonicalize' does this in a more general way, but we handle the
    // simple/common case here so that the output / test cases look clear.
276     if (auto *allocOp = oldMemRef.getDefiningOp()) {
277       if (oldMemRef.use_empty()) {
278         allocOp->erase();
279       } else if (oldMemRef.hasOneUse()) {
280         if (auto dealloc =
281                 dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
282           dealloc.erase();
283           allocOp->erase();
284         }
285       }
286     }
287   }
288 
289   // Double the buffers for tag memrefs.
290   for (auto &pair : startWaitPairs) {
291     auto *dmaFinishOp = pair.second;
292     Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
293     if (!doubleBuffer(oldTagMemRef, forOp)) {
294       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
295       return;
296     }
297     // If the old tag has no uses or a single dealloc use, remove it.
298     // (canonicalization handles more complex cases).
299     if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
300       if (oldTagMemRef.use_empty()) {
301         tagAllocOp->erase();
302       } else if (oldTagMemRef.hasOneUse()) {
303         if (auto dealloc =
304                 dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
305           dealloc.erase();
306           tagAllocOp->erase();
307         }
308       }
309     }
310   }
311 
312   // Double buffering would have invalidated all the old DMA start/wait insts.
313   startWaitPairs.clear();
314   findMatchingStartFinishInsts(forOp, startWaitPairs);
315 
  // Store the shift for each operation for later lookup, including for the
  // AffineApplyOp's that feed the DMA start operations.
317   DenseMap<Operation *, unsigned> instShiftMap;
318   for (auto &pair : startWaitPairs) {
319     auto *dmaStartOp = pair.first;
320     assert(isa<AffineDmaStartOp>(dmaStartOp));
321     instShiftMap[dmaStartOp] = 0;
322     // Set shifts for DMA start op's affine operand computation slices to 0.
323     SmallVector<AffineApplyOp, 4> sliceOps;
324     mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
325     if (!sliceOps.empty()) {
326       for (auto sliceOp : sliceOps) {
327         instShiftMap[sliceOp.getOperation()] = 0;
328       }
329     } else {
330       // If a slice wasn't created, the reachable affine.apply op's from its
331       // operands are the ones that go with it.
332       SmallVector<Operation *, 4> affineApplyInsts;
333       SmallVector<Value, 4> operands(dmaStartOp->getOperands());
334       getReachableAffineApplyOps(operands, affineApplyInsts);
335       for (auto *op : affineApplyInsts) {
336         instShiftMap[op] = 0;
337       }
338     }
339   }
  // Everything else (including compute ops and dma finish) is shifted by one.
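  // With these shifts, the body skewing below effectively issues each DMA one
  // iteration ahead of the computation that consumes its data.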
341   for (auto &op : forOp.getBody()->without_terminator())
342     if (instShiftMap.find(&op) == instShiftMap.end())
343       instShiftMap[&op] = 1;
344 
345   // Get shifts stored in map.
346   SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
347   unsigned s = 0;
348   for (auto &op : forOp.getBody()->without_terminator()) {
349     assert(instShiftMap.find(&op) != instShiftMap.end());
350     shifts[s++] = instShiftMap[&op];
351 
    // Tag operations with their shifts for debugging purposes.
353     LLVM_DEBUG({
354       OpBuilder b(&op);
355       op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
356     });
357   }
358 
359   if (!isOpwiseShiftValid(forOp, shifts)) {
360     // Violates dependences.
361     LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
362     return;
363   }
364 
365   if (failed(affineForOpBodySkew(forOp, shifts))) {
366     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
367     return;
368   }
369 }
370