1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop fusion transformation utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" 16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 18 #include "mlir/Dialect/Affine/Analysis/Utils.h" 19 #include "mlir/Dialect/Affine/IR/AffineOps.h" 20 #include "mlir/Dialect/Affine/LoopUtils.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/BuiltinOps.h" 27 #include "mlir/IR/Operation.h" 28 #include "llvm/ADT/DenseMap.h" 29 #include "llvm/ADT/SmallVector.h" 30 #include "llvm/Support/Debug.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 #define DEBUG_TYPE "loop-fusion-utils" 34 35 using namespace mlir; 36 37 // Gathers all load and store memref accesses in 'opA' into 'values', where 38 // 'values[memref] == true' for each store operation. 
39 static void getLoadAndStoreMemRefAccesses(Operation *opA, 40 DenseMap<Value, bool> &values) { 41 opA->walk([&](Operation *op) { 42 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 43 if (values.count(loadOp.getMemRef()) == 0) 44 values[loadOp.getMemRef()] = false; 45 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 46 values[storeOp.getMemRef()] = true; 47 } 48 }); 49 } 50 51 /// Returns true if 'op' is a load or store operation which access a memref 52 /// accessed 'values' and at least one of the access is a store operation. 53 /// Returns false otherwise. 54 static bool isDependentLoadOrStoreOp(Operation *op, 55 DenseMap<Value, bool> &values) { 56 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { 57 return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()]; 58 } 59 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 60 return values.count(storeOp.getMemRef()) > 0; 61 } 62 return false; 63 } 64 65 // Returns the first operation in range ('opA', 'opB') which has a data 66 // dependence on 'opA'. Returns 'nullptr' of no dependence exists. 67 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { 68 // Record memref values from all loads/store in loop nest rooted at 'opA'. 69 // Map from memref value to bool which is true if store, false otherwise. 70 DenseMap<Value, bool> values; 71 getLoadAndStoreMemRefAccesses(opA, values); 72 73 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data 74 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref 75 // and at least one of the accesses is a store). 
76 Operation *firstDepOp = nullptr; 77 for (Block::iterator it = std::next(Block::iterator(opA)); 78 it != Block::iterator(opB); ++it) { 79 Operation *opX = &(*it); 80 opX->walk([&](Operation *op) { 81 if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) 82 firstDepOp = opX; 83 }); 84 if (firstDepOp) 85 break; 86 } 87 return firstDepOp; 88 } 89 90 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there 91 // exists a data dependence from 'opX' to 'opB'. 92 // Returns 'nullptr' of no dependence exists. 93 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { 94 // Record memref values from all loads/store in loop nest rooted at 'opB'. 95 // Map from memref value to bool which is true if store, false otherwise. 96 DenseMap<Value, bool> values; 97 getLoadAndStoreMemRefAccesses(opB, values); 98 99 // For each 'opX' in block in range ('opA', 'opB') in reverse order, 100 // check if there is a data dependence from 'opX' to 'opB': 101 // *) 'opX' and 'opB' access the same memref and at least one of the accesses 102 // is a store. 103 // *) 'opX' produces an SSA Value which is used by 'opB'. 104 Operation *lastDepOp = nullptr; 105 for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); 106 it != Block::reverse_iterator(opA); ++it) { 107 Operation *opX = &(*it); 108 opX->walk([&](Operation *op) { 109 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) { 110 if (isDependentLoadOrStoreOp(op, values)) { 111 lastDepOp = opX; 112 return WalkResult::interrupt(); 113 } 114 return WalkResult::advance(); 115 } 116 for (auto value : op->getResults()) { 117 for (Operation *user : value.getUsers()) { 118 SmallVector<AffineForOp, 4> loops; 119 // Check if any loop in loop nest surrounding 'user' is 'opB'. 
120 getLoopIVs(*user, &loops); 121 if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { 122 lastDepOp = opX; 123 return WalkResult::interrupt(); 124 } 125 } 126 } 127 return WalkResult::advance(); 128 }); 129 if (lastDepOp) 130 break; 131 } 132 return lastDepOp; 133 } 134 135 // Computes and returns an insertion point operation, before which the 136 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving 137 // dependences. Returns nullptr if no such insertion point is found. 138 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, 139 AffineForOp dstForOp) { 140 bool isSrcForOpBeforeDstForOp = 141 srcForOp->isBeforeInBlock(dstForOp.getOperation()); 142 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 143 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 144 145 auto *firstDepOpA = 146 getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 147 auto *lastDepOpB = 148 getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 149 // Block: 150 // ... 151 // |-- opA 152 // | ... 153 // | lastDepOpB --| 154 // | ... | 155 // |-> firstDepOpA | 156 // ... | 157 // opB <--------- 158 // 159 // Valid insertion point range: (lastDepOpB, firstDepOpA) 160 // 161 if (firstDepOpA != nullptr) { 162 if (lastDepOpB != nullptr) { 163 if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) 164 // No valid insertion point exists which preserves dependences. 165 return nullptr; 166 } 167 // Return insertion point in valid range closest to 'opB'. 168 // TODO: Consider other insertion points in valid range. 169 return firstDepOpA; 170 } 171 // No dependences from 'opA' to operation in range ('opA', 'opB'), return 172 // 'opB' insertion point. 173 return forOpB.getOperation(); 174 } 175 176 // Gathers all load and store ops in loop nest rooted at 'forOp' into 177 // 'loadAndStoreOps'. 
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      // affine.if is unsupported by the fusion utilities; callers treat a
      // 'false' return as a fusion precondition failure.
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse producer loop
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    // 'dstOp' is known to be a load or store (gathered by
    // gatherLoadsAndStores), so the cast below is safe.
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps,
             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    auto *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      // Probe each depth from outermost inward; 'numCommonLoops + 1' covers
      // the dependence at the block level just below the common loops.
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineValueConstraints dependenceConstraints;
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints,
            /*dependenceComponents=*/nullptr);
        if (hasDependence(result)) {
          // Store minimum loop depth and break because we want the min 'd' at
          // which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}

// TODO: Prevent fusion of loop nests with side-effecting operations.
// TODO: This pass performs some computation that is the same for all the depths
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
// all the depths at once or only the legal maximal depth for maximal fusion.
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                unsigned dstLoopDepth,
                                ComputationSliceState *srcSlice,
                                FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
  // exists which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp =
      srcForOp->isBeforeInBlock(dstForOp.getOperation());
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
      *srcForOp.getOperation(), *dstForOp.getOperation());

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
    // to 'memref' in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute union of computation slices computed between all pairs of ops
  // from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult =
      mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
                              isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}

/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block. Returns failure() if 'forOp' is not a
/// single-iteration loop or is not directly nested in another AffineForOp.
LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                             bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || tripCount.getValue() != 1)
    return failure();
  auto iterOperands = forOp.getIterOperands();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  // The values yielded by 'forOp' become extra yields of the parent loop.
  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
  OpBuilder b(parentOp);
  // Replace the parent loop and add iteroperands and results from the `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop = replaceForOpWithNewYields(
      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop. The absorbed results
  // are appended after the parent loop's original results.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOp->getNumResults()));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  // Erase the now-empty single-iteration loop and the replaced parent loop.
  forOp.erase();
  parentForOp.erase();
  return success();
}

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'.
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                     const ComputationSliceState &srcSlice,
                     bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  BlockAndValueMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    // IVs not cloned (e.g. already promoted away) have no mapping; skip them.
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    // A null map means the slice keeps the loop's original bound.
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loop - only ones that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}

/// Collect loop nest statistics (eg. loop trip count and operation count)
/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
/// returns false otherwise.
481 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { 482 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { 483 auto *childForOp = forOp.getOperation(); 484 auto *parentForOp = forOp->getParentOp(); 485 if (!llvm::isa<func::FuncOp>(parentForOp)) { 486 if (!isa<AffineForOp>(parentForOp)) { 487 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); 488 return WalkResult::interrupt(); 489 } 490 // Add mapping to 'forOp' from its parent AffineForOp. 491 stats->loopMap[parentForOp].push_back(forOp); 492 } 493 494 // Record the number of op operations in the body of 'forOp'. 495 unsigned count = 0; 496 stats->opCountMap[childForOp] = 0; 497 for (auto &op : *forOp.getBody()) { 498 if (!isa<AffineForOp, AffineIfOp>(op)) 499 ++count; 500 } 501 stats->opCountMap[childForOp] = count; 502 503 // Record trip count for 'forOp'. Set flag if trip count is not 504 // constant. 505 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); 506 if (!maybeConstTripCount.hasValue()) { 507 // Currently only constant trip count loop nests are supported. 508 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); 509 return WalkResult::interrupt(); 510 } 511 512 stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); 513 return WalkResult::advance(); 514 }); 515 return !walkResult.wasInterrupted(); 516 } 517 518 // Computes the total cost of the loop nest rooted at 'forOp'. 519 // Currently, the total cost is computed by counting the total operation 520 // instance count (i.e. total number of operations in the loop bodyloop 521 // operation count * loop trip count) for the entire loop nest. 522 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops 523 // specified in the map when computing the total op instance count. 524 // NOTEs: 1) This is used to compute the cost of computation slices, which are 525 // sliced along the iteration dimension, and thus reduce the trip count. 
526 // If 'computeCostMap' is non-null, the total op count for forOps specified 527 // in the map is increased (not overridden) by adding the op count from the 528 // map to the existing op count for the for loop. This is done before 529 // multiplying by the loop's trip count, and is used to model the cost of 530 // inserting a sliced loop nest of known cost into the loop's body. 531 // 2) This is also used to compute the cost of fusing a slice of some loop nest 532 // within another loop. 533 static int64_t getComputeCostHelper( 534 Operation *forOp, LoopNestStats &stats, 535 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, 536 DenseMap<Operation *, int64_t> *computeCostMap) { 537 // 'opCount' is the total number operations in one iteration of 'forOp' body, 538 // minus terminator op which is a no-op. 539 int64_t opCount = stats.opCountMap[forOp] - 1; 540 if (stats.loopMap.count(forOp) > 0) { 541 for (auto childForOp : stats.loopMap[forOp]) { 542 opCount += getComputeCostHelper(childForOp.getOperation(), stats, 543 tripCountOverrideMap, computeCostMap); 544 } 545 } 546 // Add in additional op instances from slice (if specified in map). 547 if (computeCostMap != nullptr) { 548 auto it = computeCostMap->find(forOp); 549 if (it != computeCostMap->end()) { 550 opCount += it->second; 551 } 552 } 553 // Override trip count (if specified in map). 554 int64_t tripCount = stats.tripCountMap[forOp]; 555 if (tripCountOverrideMap != nullptr) { 556 auto it = tripCountOverrideMap->find(forOp); 557 if (it != tripCountOverrideMap->end()) { 558 tripCount = it->second; 559 } 560 } 561 // Returns the total number of dynamic instances of operations in loop body. 562 return tripCount * opCount; 563 } 564 565 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. 566 /// Currently, the total cost is computed by counting the total operation 567 /// instance count (i.e. 
total number of operations in the loop body * loop 568 /// trip count) for the entire loop nest. 569 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { 570 return getComputeCostHelper(forOp.getOperation(), stats, 571 /*tripCountOverrideMap=*/nullptr, 572 /*computeCostMap=*/nullptr); 573 } 574 575 /// Computes and returns in 'computeCost', the total compute cost of fusing the 576 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, 577 /// the total cost is computed by counting the total operation instance count 578 /// (i.e. total number of operations in the loop body * loop trip count) for 579 /// the entire loop nest. 580 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, 581 AffineForOp dstForOp, LoopNestStats &dstStats, 582 const ComputationSliceState &slice, 583 int64_t *computeCost) { 584 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 585 DenseMap<Operation *, int64_t> computeCostMap; 586 587 // Build trip count map for computation slice. 588 if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) 589 return false; 590 // Checks whether a store to load forwarding will happen. 591 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); 592 assert(sliceIterationCount > 0); 593 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); 594 auto *insertPointParent = slice.insertPoint->getParentOp(); 595 596 // The store and loads to this memref will disappear. 597 // TODO: Add load coalescing to memref data flow opt pass. 598 if (storeLoadFwdGuaranteed) { 599 // Subtract from operation count the loads/store we expect load/store 600 // forwarding to remove. 
601 unsigned storeCount = 0; 602 llvm::SmallDenseSet<Value, 4> storeMemrefs; 603 srcForOp.walk([&](Operation *op) { 604 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { 605 storeMemrefs.insert(storeOp.getMemRef()); 606 ++storeCount; 607 } 608 }); 609 // Subtract out any store ops in single-iteration src slice loop nest. 610 if (storeCount > 0) 611 computeCostMap[insertPointParent] = -storeCount; 612 // Subtract out any load users of 'storeMemrefs' nested below 613 // 'insertPointParent'. 614 for (auto value : storeMemrefs) { 615 for (auto *user : value.getUsers()) { 616 if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) { 617 SmallVector<AffineForOp, 4> loops; 618 // Check if any loop in loop nest surrounding 'user' is 619 // 'insertPointParent'. 620 getLoopIVs(*user, &loops); 621 if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { 622 if (auto forOp = 623 dyn_cast_or_null<AffineForOp>(user->getParentOp())) { 624 if (computeCostMap.count(forOp) == 0) 625 computeCostMap[forOp] = 0; 626 computeCostMap[forOp] -= 1; 627 } 628 } 629 } 630 } 631 } 632 } 633 634 // Compute op instance count for the src loop nest with iteration slicing. 635 int64_t sliceComputeCost = getComputeCostHelper( 636 srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); 637 638 // Compute cost of fusion for this depth. 639 computeCostMap[insertPointParent] = sliceComputeCost; 640 641 *computeCost = 642 getComputeCostHelper(dstForOp.getOperation(), dstStats, 643 /*tripCountOverrideMap=*/nullptr, &computeCostMap); 644 return true; 645 } 646 647 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a 648 /// producer-consumer dependence between write ops in 'srcOps' and read ops in 649 /// 'dstOps'. 650 void mlir::gatherProducerConsumerMemrefs( 651 ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps, 652 DenseSet<Value> &producerConsumerMemrefs) { 653 // Gather memrefs from stores in 'srcOps'. 
654 DenseSet<Value> srcStoreMemRefs; 655 for (Operation *op : srcOps) 656 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) 657 srcStoreMemRefs.insert(storeOp.getMemRef()); 658 659 // Compute the intersection between memrefs from stores in 'srcOps' and 660 // memrefs from loads in 'dstOps'. 661 for (Operation *op : dstOps) 662 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) 663 if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) 664 producerConsumerMemrefs.insert(loadOp.getMemRef()); 665 } 666