//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;

namespace {
struct PipelineDataTransfer
    : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
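///
/// Illustrative sketch (hypothetical IR, names invented for exposition): for
/// a buffer %buf of type memref<256xf32, 1> in a unit-step loop with IV %i,
/// the rewrite produces roughly
///   %buf_db = memref.alloc() : memref<2x256xf32, 1>
///   %mod = affine.apply affine_map<(d0) -> (d0 mod 2)>(%i)
/// and a use such as 'affine.load %buf[%j]' becomes
///   affine.load %buf_db[%mod, %j] : memref<2x256xf32, 1>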
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamicSize)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStep();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Walks all 'affine.for' ops in the function (innermost loops first) and
/// attempts to pipeline the DMA transfers found in each.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined
  // one gets deleted and replaced by a prologue, a new steady-state loop and
  // an epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to
  // the same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
    // TODO: this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough
    // for now since the DMA generation pass or the input for it will always
    // have start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}
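// As an illustration of the matching above (hypothetical IR, types and some
// operands abbreviated), checkTagMatch pairs up
//   affine.dma_start %src[%i], %dst[%i], %tag[%c0], %num : ...
//   affine.dma_wait %tag[%c0], %num : memref<1xi32>
// since both ops use %tag with the same index operand %c0.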
// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependences with outgoing DMAs conservatively: any outgoing
    // DMA writing to the memref this incoming DMA reads from disqualifies it.
    // TODO: use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n");
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
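///
/// Illustrative sketch (hypothetical IR, names invented): a loop of the form
///   affine.for %i = 0 to %N {
///     dma_start ; dma_wait ; compute(%i)
///   }
/// becomes, after double buffering and body skewing, roughly:
///   dma_start(buf slot 0)                 // prologue: first transfer
///   affine.for %i (one fewer iteration) {
///     dma_start(slot (%i + 1) mod 2)      // fetch the next slice
///     dma_wait ; compute(%i)              // consume the current slice
///   }
///   dma_wait ; compute(last)              // epilogue: drain the last slice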
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no dma start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that
  // the dimension we are adding for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << *dmaStartOp << "\n");
      // The IR is still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll handle the
    // simple/common case here anyway so that the output / test cases look
    // clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n");
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

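  // The shift assignment below realizes the software pipelining: DMA start
  // op's (together with the affine.apply op's computing their operands) keep
  // shift 0, while every other op (compute op's and dma finish) is shifted by
  // one iteration. Skewing the loop body by these shifts then overlaps the
  // transfer for iteration i + 1 with the computation of iteration i.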
  // Store the shift for each operation for later lookup when handling
  // AffineApplyOp's.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    if (instShiftMap.find(&op) == instShiftMap.end())
      instShiftMap[&op] = 1;

  // Get the shifts stored in the map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.find(&op) != instShiftMap.end());
    shifts[s++] = instShiftMap[&op];

    // Tag operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "shifts invalid - unexpected\n");
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n");
    return;
  }
}
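
// Usage note (assuming the standard pass registration for this file): the
// rewrite above is typically exercised with, e.g.:
//   mlir-opt -affine-pipeline-data-transfer input.mlir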