1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion transformation utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" 16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 18 #include "mlir/Dialect/Affine/Analysis/Utils.h" 19 #include "mlir/Dialect/Affine/IR/AffineOps.h" 20 #include "mlir/Dialect/Affine/LoopUtils.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/BlockAndValueMapping.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/BuiltinOps.h" 26 #include "mlir/IR/Operation.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/raw_ostream.h" 31 32 #define DEBUG_TYPE "loop-fusion-utils" 33 34 using namespace mlir; 35 36 // Gathers all load and store memref accesses in 'opA' into 'values', where 37 // 'values[memref] == true' for each store operation. 38 static void getLoadAndStoreMemRefAccesses(Operation *opA, 39 DenseMap<Value, bool> &values) { 40 opA->walk([&](Operation *op) { 41 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 42 if (values.count(loadOp.getMemRef()) == 0) 43 values[loadOp.getMemRef()] = false; 44 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 45 values[storeOp.getMemRef()] = true; 46 } 47 }); 48 } 49 50 /// Returns true if 'op' is a load or store operation which access a memref 51 /// accessed 'values' and at least one of the access is a store operation. 52 /// Returns false otherwise. 53 static bool isDependentLoadOrStoreOp(Operation *op, 54 DenseMap<Value, bool> &values) { 55 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 56 return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()]; 57 } 58 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 59 return values.count(storeOp.getMemRef()) > 0; 60 } 61 return false; 62 } 63 64 // Returns the first operation in range ('opA', 'opB') which has a data 65 // dependence on 'opA'. Returns 'nullptr' of no dependence exists. 66 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { 67 // Record memref values from all loads/store in loop nest rooted at 'opA'. 68 // Map from memref value to bool which is true if store, false otherwise. 69 DenseMap<Value, bool> values; 70 getLoadAndStoreMemRefAccesses(opA, values); 71 72 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data 73 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref 74 // and at least one of the accesses is a store). 75 Operation *firstDepOp = nullptr; 76 for (Block::iterator it = std::next(Block::iterator(opA)); 77 it != Block::iterator(opB); ++it) { 78 Operation *opX = &(*it); 79 opX->walk([&](Operation *op) { 80 if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) 81 firstDepOp = opX; 82 }); 83 if (firstDepOp) 84 break; 85 } 86 return firstDepOp; 87 } 88 89 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there 90 // exists a data dependence from 'opX' to 'opB'. 91 // Returns 'nullptr' of no dependence exists. 92 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { 93 // Record memref values from all loads/store in loop nest rooted at 'opB'. 94 // Map from memref value to bool which is true if store, false otherwise. 95 DenseMap<Value, bool> values; 96 getLoadAndStoreMemRefAccesses(opB, values); 97 98 // For each 'opX' in block in range ('opA', 'opB') in reverse order, 99 // check if there is a data dependence from 'opX' to 'opB': 100 // *) 'opX' and 'opB' access the same memref and at least one of the accesses 101 // is a store. 102 // *) 'opX' produces an SSA Value which is used by 'opB'. 103 Operation *lastDepOp = nullptr; 104 for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); 105 it != Block::reverse_iterator(opA); ++it) { 106 Operation *opX = &(*it); 107 opX->walk([&](Operation *op) { 108 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) { 109 if (isDependentLoadOrStoreOp(op, values)) { 110 lastDepOp = opX; 111 return WalkResult::interrupt(); 112 } 113 return WalkResult::advance(); 114 } 115 for (auto value : op->getResults()) { 116 for (Operation *user : value.getUsers()) { 117 SmallVector<AffineForOp, 4> loops; 118 // Check if any loop in loop nest surrounding 'user' is 'opB'. 119 getLoopIVs(*user, &loops); 120 if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { 121 lastDepOp = opX; 122 return WalkResult::interrupt(); 123 } 124 } 125 } 126 return WalkResult::advance(); 127 }); 128 if (lastDepOp) 129 break; 130 } 131 return lastDepOp; 132 } 133 134 // Computes and returns an insertion point operation, before which the 135 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving 136 // dependences. Returns nullptr if no such insertion point is found. 137 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, 138 AffineForOp dstForOp) { 139 bool isSrcForOpBeforeDstForOp = 140 srcForOp->isBeforeInBlock(dstForOp.getOperation()); 141 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 142 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 143 144 auto *firstDepOpA = 145 getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 146 auto *lastDepOpB = 147 getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 148 // Block: 149 // ... 150 // |-- opA 151 // | ... 152 // | lastDepOpB --| 153 // | ... | 154 // |-> firstDepOpA | 155 // ... | 156 // opB <--------- 157 // 158 // Valid insertion point range: (lastDepOpB, firstDepOpA) 159 // 160 if (firstDepOpA != nullptr) { 161 if (lastDepOpB != nullptr) { 162 if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) 163 // No valid insertion point exists which preserves dependences. 164 return nullptr; 165 } 166 // Return insertion point in valid range closest to 'opB'. 167 // TODO: Consider other insertion points in valid range. 168 return firstDepOpA; 169 } 170 // No dependences from 'opA' to operation in range ('opA', 'opB'), return 171 // 'opB' insertion point. 172 return forOpB.getOperation(); 173 } 174 175 // Gathers all load and store ops in loop nest rooted at 'forOp' into 176 // 'loadAndStoreOps'. 177 static bool 178 gatherLoadsAndStores(AffineForOp forOp, 179 SmallVectorImpl<Operation *> &loadAndStoreOps) { 180 bool hasIfOp = false; 181 forOp.walk([&](Operation *op) { 182 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) 183 loadAndStoreOps.push_back(op); 184 else if (isa<AffineIfOp>(op)) 185 hasIfOp = true; 186 }); 187 return !hasIfOp; 188 } 189 190 /// Returns the maximum loop depth at which we could fuse producer loop 191 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. 192 // TODO: Generalize this check for sibling and more generic fusion scenarios. 193 // TODO: Support forward slice fusion. 194 static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps, 195 ArrayRef<Operation *> dstOps) { 196 if (dstOps.empty()) 197 // Expected at least one memory operation. 198 // TODO: Revisit this case with a specific example. 199 return 0; 200 201 // Filter out ops in 'dstOps' that do not use the producer-consumer memref so 202 // that they are not considered for analysis. 203 DenseSet<Value> producerConsumerMemrefs; 204 gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs); 205 SmallVector<Operation *, 4> targetDstOps; 206 for (Operation *dstOp : dstOps) { 207 auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp); 208 Value memref = loadOp ? loadOp.getMemRef() 209 : cast<AffineWriteOpInterface>(dstOp).getMemRef(); 210 if (producerConsumerMemrefs.count(memref) > 0) 211 targetDstOps.push_back(dstOp); 212 } 213 214 assert(!targetDstOps.empty() && 215 "No dependences between 'srcForOp' and 'dstForOp'?"); 216 217 // Compute the innermost common loop depth for loads and stores. 218 unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps); 219 220 // Return common loop depth for loads if there are no store ops. 221 if (all_of(targetDstOps, 222 [&](Operation *op) { return isa<AffineReadOpInterface>(op); })) 223 return loopDepth; 224 225 // Check dependences on all pairs of ops in 'targetDstOps' and store the 226 // minimum loop depth at which a dependence is satisfied. 227 for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) { 228 auto *srcOpInst = targetDstOps[i]; 229 MemRefAccess srcAccess(srcOpInst); 230 for (unsigned j = 0; j < e; ++j) { 231 auto *dstOpInst = targetDstOps[j]; 232 MemRefAccess dstAccess(dstOpInst); 233 234 unsigned numCommonLoops = 235 getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); 236 for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { 237 FlatAffineValueConstraints dependenceConstraints; 238 // TODO: Cache dependence analysis results, check cache here. 239 DependenceResult result = checkMemrefAccessDependence( 240 srcAccess, dstAccess, d, &dependenceConstraints, 241 /*dependenceComponents=*/nullptr); 242 if (hasDependence(result)) { 243 // Store minimum loop depth and break because we want the min 'd' at 244 // which there is a dependence. 245 loopDepth = std::min(loopDepth, d - 1); 246 break; 247 } 248 } 249 } 250 } 251 252 return loopDepth; 253 } 254 255 // TODO: Prevent fusion of loop nests with side-effecting operations. 256 // TODO: This pass performs some computation that is the same for all the depths 257 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes 258 // all the depths at once or only the legal maximal depth for maximal fusion. 259 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, 260 unsigned dstLoopDepth, 261 ComputationSliceState *srcSlice, 262 FusionStrategy fusionStrategy) { 263 // Return 'failure' if 'dstLoopDepth == 0'. 264 if (dstLoopDepth == 0) { 265 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n"); 266 return FusionResult::FailPrecondition; 267 } 268 // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. 269 auto *block = srcForOp->getBlock(); 270 if (block != dstForOp->getBlock()) { 271 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n"); 272 return FusionResult::FailPrecondition; 273 } 274 275 // Return 'failure' if no valid insertion point for fused loop nest in 'block' 276 // exists which would preserve dependences. 277 if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { 278 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n"); 279 return FusionResult::FailBlockDependence; 280 } 281 282 // Check if 'srcForOp' precedes 'dstForOp' in 'block'. 283 bool isSrcForOpBeforeDstForOp = 284 srcForOp->isBeforeInBlock(dstForOp.getOperation()); 285 // 'forOpA' executes before 'forOpB' in 'block'. 286 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 287 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 288 289 // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. 290 SmallVector<Operation *, 4> opsA; 291 if (!gatherLoadsAndStores(forOpA, opsA)) { 292 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); 293 return FusionResult::FailPrecondition; 294 } 295 296 // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. 297 SmallVector<Operation *, 4> opsB; 298 if (!gatherLoadsAndStores(forOpB, opsB)) { 299 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); 300 return FusionResult::FailPrecondition; 301 } 302 303 // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve 304 // loop dependences. 305 // TODO: Enable this check for sibling and more generic loop fusion 306 // strategies. 307 if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) { 308 // TODO: 'getMaxLoopDepth' does not support forward slice fusion. 309 assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); 310 if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { 311 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); 312 return FusionResult::FailFusionDependence; 313 } 314 } 315 316 // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. 317 unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( 318 *srcForOp.getOperation(), *dstForOp.getOperation()); 319 320 // Filter out ops in 'opsA' to compute the slice union based on the 321 // assumptions made by the fusion strategy. 322 SmallVector<Operation *, 4> strategyOpsA; 323 switch (fusionStrategy.getStrategy()) { 324 case FusionStrategy::Generic: 325 // Generic fusion. Take into account all the memory operations to compute 326 // the slice union. 327 strategyOpsA.append(opsA.begin(), opsA.end()); 328 break; 329 case FusionStrategy::ProducerConsumer: 330 // Producer-consumer fusion (AffineLoopFusion pass) only takes into 331 // account stores in 'srcForOp' to compute the slice union. 332 for (Operation *op : opsA) { 333 if (isa<AffineWriteOpInterface>(op)) 334 strategyOpsA.push_back(op); 335 } 336 break; 337 case FusionStrategy::Sibling: 338 // Sibling fusion (AffineLoopFusion pass) only takes into account the loads 339 // to 'memref' in 'srcForOp' to compute the slice union. 340 for (Operation *op : opsA) { 341 auto load = dyn_cast<AffineReadOpInterface>(op); 342 if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef()) 343 strategyOpsA.push_back(op); 344 } 345 break; 346 } 347 348 // Compute union of computation slices computed between all pairs of ops 349 // from 'forOpA' and 'forOpB'. 350 SliceComputationResult sliceComputationResult = 351 mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops, 352 isSrcForOpBeforeDstForOp, srcSlice); 353 if (sliceComputationResult.value == SliceComputationResult::GenericFailure) { 354 LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); 355 return FusionResult::FailPrecondition; 356 } 357 if (sliceComputationResult.value == 358 SliceComputationResult::IncorrectSliceFailure) { 359 LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n"); 360 return FusionResult::FailIncorrectSlice; 361 } 362 363 return FusionResult::Success; 364 } 365 366 /// Patch the loop body of a forOp that is a single iteration reduction loop 367 /// into its containing block. 368 LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp, 369 bool siblingFusionUser) { 370 // Check if the reduction loop is a single iteration loop. 371 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 372 if (!tripCount || tripCount.getValue() != 1) 373 return failure(); 374 auto iterOperands = forOp.getIterOperands(); 375 auto *parentOp = forOp->getParentOp(); 376 if (!isa<AffineForOp>(parentOp)) 377 return failure(); 378 auto newOperands = forOp.getBody()->getTerminator()->getOperands(); 379 OpBuilder b(parentOp); 380 // Replace the parent loop and add iteroperands and results from the `forOp`. 381 AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>(); 382 AffineForOp newLoop = replaceForOpWithNewYields( 383 b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs()); 384 385 // For sibling-fusion users, collect operations that use the results of the 386 // `forOp` outside the new parent loop that has absorbed all its iter args 387 // and operands. These operations will be moved later after the results 388 // have been replaced. 389 SetVector<Operation *> forwardSlice; 390 if (siblingFusionUser) { 391 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) { 392 SetVector<Operation *> tmpForwardSlice; 393 getForwardSlice(forOp.getResult(i), &tmpForwardSlice); 394 forwardSlice.set_union(tmpForwardSlice); 395 } 396 } 397 // Update the results of the `forOp` in the new loop. 398 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) { 399 forOp.getResult(i).replaceAllUsesWith( 400 newLoop.getResult(i + parentOp->getNumResults())); 401 } 402 // For sibling-fusion users, move operations that use the results of the 403 // `forOp` outside the new parent loop 404 if (siblingFusionUser) { 405 topologicalSort(forwardSlice); 406 for (Operation *op : llvm::reverse(forwardSlice)) 407 op->moveAfter(newLoop); 408 } 409 // Replace the induction variable. 410 auto iv = forOp.getInductionVar(); 411 iv.replaceAllUsesWith(newLoop.getInductionVar()); 412 // Replace the iter args. 413 auto forOpIterArgs = forOp.getRegionIterArgs(); 414 for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back( 415 forOpIterArgs.size()))) { 416 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 417 } 418 // Move the loop body operations, except for its terminator, to the loop's 419 // containing block. 420 forOp.getBody()->back().erase(); 421 auto *parentBlock = forOp->getBlock(); 422 parentBlock->getOperations().splice(Block::iterator(forOp), 423 forOp.getBody()->getOperations()); 424 forOp.erase(); 425 parentForOp.erase(); 426 return success(); 427 } 428 429 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point 430 /// and source slice loop bounds specified in 'srcSlice'. 431 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, 432 const ComputationSliceState &srcSlice, 433 bool isInnermostSiblingInsertion) { 434 // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'. 435 OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint); 436 BlockAndValueMapping mapper; 437 b.clone(*srcForOp, mapper); 438 439 // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'. 440 SmallVector<AffineForOp, 4> sliceLoops; 441 for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) { 442 auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]); 443 if (!loopIV) 444 continue; 445 auto forOp = getForInductionVarOwner(loopIV); 446 sliceLoops.push_back(forOp); 447 if (AffineMap lbMap = srcSlice.lbs[i]) { 448 auto lbOperands = srcSlice.lbOperands[i]; 449 canonicalizeMapAndOperands(&lbMap, &lbOperands); 450 forOp.setLowerBound(lbOperands, lbMap); 451 } 452 if (AffineMap ubMap = srcSlice.ubs[i]) { 453 auto ubOperands = srcSlice.ubOperands[i]; 454 canonicalizeMapAndOperands(&ubMap, &ubOperands); 455 forOp.setUpperBound(ubOperands, ubMap); 456 } 457 } 458 459 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 460 auto srcIsUnitSlice = [&]() { 461 return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) && 462 (getSliceIterationCount(sliceTripCountMap) == 1)); 463 }; 464 // Fix up and if possible, eliminate single iteration loops. 465 for (AffineForOp forOp : sliceLoops) { 466 if (isLoopParallelAndContainsReduction(forOp) && 467 isInnermostSiblingInsertion && srcIsUnitSlice()) 468 // Patch reduction loop - only ones that are sibling-fused with the 469 // destination loop - into the parent loop. 470 (void)promoteSingleIterReductionLoop(forOp, true); 471 else 472 // Promote any single iteration slice loops. 473 (void)promoteIfSingleIteration(forOp); 474 } 475 } 476 477 /// Collect loop nest statistics (eg. loop trip count and operation count) 478 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, 479 /// returns false otherwise. 480 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { 481 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { 482 auto *childForOp = forOp.getOperation(); 483 auto *parentForOp = forOp->getParentOp(); 484 if (!llvm::isa<FuncOp>(parentForOp)) { 485 if (!isa<AffineForOp>(parentForOp)) { 486 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); 487 return WalkResult::interrupt(); 488 } 489 // Add mapping to 'forOp' from its parent AffineForOp. 490 stats->loopMap[parentForOp].push_back(forOp); 491 } 492 493 // Record the number of op operations in the body of 'forOp'. 494 unsigned count = 0; 495 stats->opCountMap[childForOp] = 0; 496 for (auto &op : *forOp.getBody()) { 497 if (!isa<AffineForOp, AffineIfOp>(op)) 498 ++count; 499 } 500 stats->opCountMap[childForOp] = count; 501 502 // Record trip count for 'forOp'. Set flag if trip count is not 503 // constant. 504 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); 505 if (!maybeConstTripCount.hasValue()) { 506 // Currently only constant trip count loop nests are supported. 507 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); 508 return WalkResult::interrupt(); 509 } 510 511 stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); 512 return WalkResult::advance(); 513 }); 514 return !walkResult.wasInterrupted(); 515 } 516 517 // Computes the total cost of the loop nest rooted at 'forOp'. 518 // Currently, the total cost is computed by counting the total operation 519 // instance count (i.e. total number of operations in the loop bodyloop 520 // operation count * loop trip count) for the entire loop nest. 521 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops 522 // specified in the map when computing the total op instance count. 523 // NOTEs: 1) This is used to compute the cost of computation slices, which are 524 // sliced along the iteration dimension, and thus reduce the trip count. 525 // If 'computeCostMap' is non-null, the total op count for forOps specified 526 // in the map is increased (not overridden) by adding the op count from the 527 // map to the existing op count for the for loop. This is done before 528 // multiplying by the loop's trip count, and is used to model the cost of 529 // inserting a sliced loop nest of known cost into the loop's body. 530 // 2) This is also used to compute the cost of fusing a slice of some loop nest 531 // within another loop. 532 static int64_t getComputeCostHelper( 533 Operation *forOp, LoopNestStats &stats, 534 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, 535 DenseMap<Operation *, int64_t> *computeCostMap) { 536 // 'opCount' is the total number operations in one iteration of 'forOp' body, 537 // minus terminator op which is a no-op. 538 int64_t opCount = stats.opCountMap[forOp] - 1; 539 if (stats.loopMap.count(forOp) > 0) { 540 for (auto childForOp : stats.loopMap[forOp]) { 541 opCount += getComputeCostHelper(childForOp.getOperation(), stats, 542 tripCountOverrideMap, computeCostMap); 543 } 544 } 545 // Add in additional op instances from slice (if specified in map). 546 if (computeCostMap != nullptr) { 547 auto it = computeCostMap->find(forOp); 548 if (it != computeCostMap->end()) { 549 opCount += it->second; 550 } 551 } 552 // Override trip count (if specified in map). 553 int64_t tripCount = stats.tripCountMap[forOp]; 554 if (tripCountOverrideMap != nullptr) { 555 auto it = tripCountOverrideMap->find(forOp); 556 if (it != tripCountOverrideMap->end()) { 557 tripCount = it->second; 558 } 559 } 560 // Returns the total number of dynamic instances of operations in loop body. 561 return tripCount * opCount; 562 } 563 564 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. 565 /// Currently, the total cost is computed by counting the total operation 566 /// instance count (i.e. total number of operations in the loop body * loop 567 /// trip count) for the entire loop nest. 568 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { 569 return getComputeCostHelper(forOp.getOperation(), stats, 570 /*tripCountOverrideMap=*/nullptr, 571 /*computeCostMap=*/nullptr); 572 } 573 574 /// Computes and returns in 'computeCost', the total compute cost of fusing the 575 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, 576 /// the total cost is computed by counting the total operation instance count 577 /// (i.e. total number of operations in the loop body * loop trip count) for 578 /// the entire loop nest. 579 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, 580 AffineForOp dstForOp, LoopNestStats &dstStats, 581 const ComputationSliceState &slice, 582 int64_t *computeCost) { 583 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 584 DenseMap<Operation *, int64_t> computeCostMap; 585 586 // Build trip count map for computation slice. 587 if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) 588 return false; 589 // Checks whether a store to load forwarding will happen. 590 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); 591 assert(sliceIterationCount > 0); 592 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); 593 auto *insertPointParent = slice.insertPoint->getParentOp(); 594 595 // The store and loads to this memref will disappear. 596 // TODO: Add load coalescing to memref data flow opt pass. 597 if (storeLoadFwdGuaranteed) { 598 // Subtract from operation count the loads/store we expect load/store 599 // forwarding to remove. 600 unsigned storeCount = 0; 601 llvm::SmallDenseSet<Value, 4> storeMemrefs; 602 srcForOp.walk([&](Operation *op) { 603 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 604 storeMemrefs.insert(storeOp.getMemRef()); 605 ++storeCount; 606 } 607 }); 608 // Subtract out any store ops in single-iteration src slice loop nest. 609 if (storeCount > 0) 610 computeCostMap[insertPointParent] = -storeCount; 611 // Subtract out any load users of 'storeMemrefs' nested below 612 // 'insertPointParent'. 613 for (auto value : storeMemrefs) { 614 for (auto *user : value.getUsers()) { 615 if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) { 616 SmallVector<AffineForOp, 4> loops; 617 // Check if any loop in loop nest surrounding 'user' is 618 // 'insertPointParent'. 619 getLoopIVs(*user, &loops); 620 if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { 621 if (auto forOp = 622 dyn_cast_or_null<AffineForOp>(user->getParentOp())) { 623 if (computeCostMap.count(forOp) == 0) 624 computeCostMap[forOp] = 0; 625 computeCostMap[forOp] -= 1; 626 } 627 } 628 } 629 } 630 } 631 } 632 633 // Compute op instance count for the src loop nest with iteration slicing. 634 int64_t sliceComputeCost = getComputeCostHelper( 635 srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); 636 637 // Compute cost of fusion for this depth. 638 computeCostMap[insertPointParent] = sliceComputeCost; 639 640 *computeCost = 641 getComputeCostHelper(dstForOp.getOperation(), dstStats, 642 /*tripCountOverrideMap=*/nullptr, &computeCostMap); 643 return true; 644 } 645 646 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a 647 /// producer-consumer dependence between write ops in 'srcOps' and read ops in 648 /// 'dstOps'. 649 void mlir::gatherProducerConsumerMemrefs( 650 ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps, 651 DenseSet<Value> &producerConsumerMemrefs) { 652 // Gather memrefs from stores in 'srcOps'. 653 DenseSet<Value> srcStoreMemRefs; 654 for (Operation *op : srcOps) 655 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) 656 srcStoreMemRefs.insert(storeOp.getMemRef()); 657 658 // Compute the intersection between memrefs from stores in 'srcOps' and 659 // memrefs from loads in 'dstOps'. 660 for (Operation *op : dstOps) 661 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) 662 if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) 663 producerConsumerMemrefs.insert(loadOp.getMemRef()); 664 } 665