//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;

namespace {
struct PipelineDataTransfer
    : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract ops are
// added (TODO).
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp))
    return dmaStartOp.getTagMemRefOperandIndex();
  // First operand for a DMA finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
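///
/// As an illustrative sketch (hand-written, not actual pass output; memory
/// space 2 standing in for the faster memory is an assumption), a buffer
///   %buf = memref.alloc() : memref<256xf32, 2>
/// used inside 'affine.for %i = 0 to 8' becomes
///   %buf = memref.alloc() : memref<2x256xf32, 2>
/// and an access %buf[%idx] is rewritten to %buf[%i mod 2, %idx]; in general
/// the leading dimension is indexed by (%i floordiv step) mod 2 to handle
/// non-unit loop steps.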
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamicSize)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStep();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Collects all 'affine.for' ops in post order and pipelines the data
/// transfers within each one.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined
  // one gets deleted and replaced by a prologue, a new steady-state loop, and
  // an epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if the tags of the DMA start op and the DMA wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if the indices match.
    // TODO: this would in general need to check that there is no intervening
    // write to the same tag location, i.e., memory last write/data flow
    // analysis. This is however sufficient/powerful enough for now since the
    // DMA generation pass or the input for it will always have start/wait ops
    // with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}
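// For illustration (a hand-written sketch with operand lists abbreviated),
// the following start/wait pair matches under the check above, since both ops
// use the same tag memref %tag with the identical index operand %zero:
//
//   affine.dma_start %src[%i], %dst[%i], %tag[%zero], %num_elts : ...
//   affine.dma_wait %tag[%zero], %num_elts : memref<1xi32>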
// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Conservatively check for dependences with outgoing DMAs.
    // TODO: use the dependence analysis to check for dependences between an
    // incoming and an outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                       << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlaps DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
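///
/// A rough sketch of the rewrite (hand-written and simplified, not verbatim
/// pass output; operand lists are elided): a loop of the form
///
///   affine.for %i = 0 to 8 {
///     affine.dma_start ...   // fill the buffer for iteration %i
///     affine.dma_wait ...
///     "compute"(...)         // use the buffer
///   }
///
/// becomes, with the buffers doubled and the wait/compute shifted by one
/// iteration relative to the start:
///
///   affine.dma_start ...     // prologue: issue the first transfer
///   affine.for %i = 1 to 8 { // steady state
///     affine.dma_start ...   // issue the transfer for iteration %i
///     affine.dma_wait ...    // wait on the transfer for iteration %i - 1
///     "compute"(...)         // compute on iteration %i - 1
///   }
///   affine.dma_wait ...      // epilogue: wait on the last transfer
///   "compute"(...)           // last compute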
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no DMA start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memrefs.
  // Identify memrefs to replace by scanning through all DMA start operations.
  // A DMA start operation has two memrefs - the one from the higher level of
  // the memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: we are assuming here that
  // the dimension we are adding for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for " << *dmaStartOp
                              << "\n";);
      // The IR is still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (Canonicalization handles more complex cases.)
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for later lookup when setting shifts
  // for AffineApplyOps.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for the DMA start op's affine operand computation slices to
    // 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps)
        instShiftMap[sliceOp.getOperation()] = 0;
    } else {
      // If a slice wasn't created, the reachable affine.apply ops from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts)
        instShiftMap[op] = 0;
    }
  }
  // Everything else (including compute ops and DMA finish ops) is shifted by
  // one.
  for (auto &op : forOp.getBody()->without_terminator())
    if (instShiftMap.find(&op) == instShiftMap.end())
      instShiftMap[&op] = 1;
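  // For example (an illustrative sketch), a body of the form
  // [dma_start, dma_wait, compute] yields the shift vector [0, 1, 1]: after
  // the skew below, the dma_start for iteration %i + 1 executes alongside the
  // wait and compute for iteration %i, overlapping the transfer with the
  // computation.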
  // Get the shifts stored in the map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.find(&op) != instShiftMap.end());
    shifts[s++] = instShiftMap[&op];

    // Tag operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}