//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffects.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced with the given op.
35 static void walkReferencedSymbolNodes( 36 Operation *op, CallGraph &cg, 37 DenseMap<Attribute, CallGraphNode *> &resolvedRefs, 38 function_ref<void(CallGraphNode *, Operation *)> callback) { 39 auto symbolUses = SymbolTable::getSymbolUses(op); 40 assert(symbolUses && "expected uses to be valid"); 41 42 Operation *symbolTableOp = op->getParentOp(); 43 for (const SymbolTable::SymbolUse &use : *symbolUses) { 44 auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr}); 45 CallGraphNode *&node = refIt.first->second; 46 47 // If this is the first instance of this reference, try to resolve a 48 // callgraph node for it. 49 if (refIt.second) { 50 auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp, 51 use.getSymbolRef()); 52 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp); 53 if (!callableOp) 54 continue; 55 node = cg.lookupNode(callableOp.getCallableRegion()); 56 } 57 if (node) 58 callback(node, use.getUser()); 59 } 60 } 61 62 //===----------------------------------------------------------------------===// 63 // CGUseList 64 65 namespace { 66 /// This struct tracks the uses of callgraph nodes that can be dropped when 67 /// use_empty. It directly tracks and manages a use-list for all of the 68 /// call-graph nodes. This is necessary because many callgraph nodes are 69 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use` 70 /// class. 71 struct CGUseList { 72 /// This struct tracks the uses of callgraph nodes within a specific 73 /// operation. 74 struct CGUser { 75 /// Any nodes referenced in the top-level attribute list of this user. We 76 /// use a set here because the number of references does not matter. 77 DenseSet<CallGraphNode *> topLevelUses; 78 79 /// Uses of nodes referenced by nested operations. 
80 DenseMap<CallGraphNode *, int> innerUses; 81 }; 82 83 CGUseList(Operation *op, CallGraph &cg); 84 85 /// Drop uses of nodes referred to by the given call operation that resides 86 /// within 'userNode'. 87 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg); 88 89 /// Remove the given node from the use list. 90 void eraseNode(CallGraphNode *node); 91 92 /// Returns true if the given callgraph node has no uses and can be pruned. 93 bool isDead(CallGraphNode *node) const; 94 95 /// Returns true if the given callgraph node has a single use and can be 96 /// discarded. 97 bool hasOneUseAndDiscardable(CallGraphNode *node) const; 98 99 /// Recompute the uses held by the given callgraph node. 100 void recomputeUses(CallGraphNode *node, CallGraph &cg); 101 102 /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy 103 /// of 'lhs' into 'rhs'. 104 void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs); 105 106 private: 107 /// Decrement the uses of discardable nodes referenced by the given user. 108 void decrementDiscardableUses(CGUser &uses); 109 110 /// A mapping between a discardable callgraph node (that is a symbol) and the 111 /// number of uses for this node. 112 DenseMap<CallGraphNode *, int> discardableSymNodeUses; 113 /// A mapping between a callgraph node and the symbol callgraph nodes that it 114 /// uses. 115 DenseMap<CallGraphNode *, CGUser> nodeUses; 116 }; 117 } // end anonymous namespace 118 119 CGUseList::CGUseList(Operation *op, CallGraph &cg) { 120 /// A set of callgraph nodes that are always known to be live during inlining. 121 DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes; 122 123 // Walk each of the symbol tables looking for discardable callgraph nodes. 
124 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { 125 for (Block &block : symbolTableOp->getRegion(0)) { 126 for (Operation &op : block) { 127 // If this is a callgraph operation, check to see if it is discardable. 128 if (auto callable = dyn_cast<CallableOpInterface>(&op)) { 129 if (auto *node = cg.lookupNode(callable.getCallableRegion())) { 130 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op); 131 if (symbol && (allUsesVisible || symbol.isPrivate()) && 132 symbol.canDiscardOnUseEmpty()) { 133 discardableSymNodeUses.try_emplace(node, 0); 134 } 135 continue; 136 } 137 } 138 // Otherwise, check for any referenced nodes. These will be always-live. 139 walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes, 140 [](CallGraphNode *, Operation *) {}); 141 } 142 } 143 }; 144 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 145 walkFn); 146 147 // Drop the use information for any discardable nodes that are always live. 148 for (auto &it : alwaysLiveNodes) 149 discardableSymNodeUses.erase(it.second); 150 151 // Compute the uses for each of the callable nodes in the graph. 152 for (CallGraphNode *node : cg) 153 recomputeUses(node, cg); 154 } 155 156 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp, 157 CallGraph &cg) { 158 auto &userRefs = nodeUses[userNode].innerUses; 159 auto walkFn = [&](CallGraphNode *node, Operation *user) { 160 auto parentIt = userRefs.find(node); 161 if (parentIt == userRefs.end()) 162 return; 163 --parentIt->second; 164 --discardableSymNodeUses[node]; 165 }; 166 DenseMap<Attribute, CallGraphNode *> resolvedRefs; 167 walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn); 168 } 169 170 void CGUseList::eraseNode(CallGraphNode *node) { 171 // Drop all child nodes. 172 for (auto &edge : *node) 173 if (edge.isChild()) 174 eraseNode(edge.getTarget()); 175 176 // Drop the uses held by this node and erase it. 
177 auto useIt = nodeUses.find(node); 178 assert(useIt != nodeUses.end() && "expected node to be valid"); 179 decrementDiscardableUses(useIt->getSecond()); 180 nodeUses.erase(useIt); 181 discardableSymNodeUses.erase(node); 182 } 183 184 bool CGUseList::isDead(CallGraphNode *node) const { 185 // If the parent operation isn't a symbol, simply check normal SSA deadness. 186 Operation *nodeOp = node->getCallableRegion()->getParentOp(); 187 if (!isa<SymbolOpInterface>(nodeOp)) 188 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty(); 189 190 // Otherwise, check the number of symbol uses. 191 auto symbolIt = discardableSymNodeUses.find(node); 192 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0; 193 } 194 195 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const { 196 // If this isn't a symbol node, check for side-effects and SSA use count. 197 Operation *nodeOp = node->getCallableRegion()->getParentOp(); 198 if (!isa<SymbolOpInterface>(nodeOp)) 199 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse(); 200 201 // Otherwise, check the number of symbol uses. 202 auto symbolIt = discardableSymNodeUses.find(node); 203 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1; 204 } 205 206 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) { 207 Operation *parentOp = node->getCallableRegion()->getParentOp(); 208 CGUser &uses = nodeUses[node]; 209 decrementDiscardableUses(uses); 210 211 // Collect the new discardable uses within this node. 
212 uses = CGUser(); 213 DenseMap<Attribute, CallGraphNode *> resolvedRefs; 214 auto walkFn = [&](CallGraphNode *refNode, Operation *user) { 215 auto discardSymIt = discardableSymNodeUses.find(refNode); 216 if (discardSymIt == discardableSymNodeUses.end()) 217 return; 218 219 if (user != parentOp) 220 ++uses.innerUses[refNode]; 221 else if (!uses.topLevelUses.insert(refNode).second) 222 return; 223 ++discardSymIt->second; 224 }; 225 walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn); 226 } 227 228 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) { 229 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs]; 230 for (auto &useIt : lhsUses.innerUses) { 231 rhsUses.innerUses[useIt.first] += useIt.second; 232 discardableSymNodeUses[useIt.first] += useIt.second; 233 } 234 } 235 236 void CGUseList::decrementDiscardableUses(CGUser &uses) { 237 for (CallGraphNode *node : uses.topLevelUses) 238 --discardableSymNodeUses[node]; 239 for (auto &it : uses.innerUses) 240 discardableSymNodeUses[it.first] -= it.second; 241 } 242 243 //===----------------------------------------------------------------------===// 244 // CallGraph traversal 245 //===----------------------------------------------------------------------===// 246 247 /// Run a given transformation over the SCCs of the callgraph in a bottom up 248 /// traversal. 249 static void runTransformOnCGSCCs( 250 const CallGraph &cg, 251 function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) { 252 std::vector<CallGraphNode *> currentSCCVec; 253 auto cgi = llvm::scc_begin(&cg); 254 while (!cgi.isAtEnd()) { 255 // Copy the current SCC and increment so that the transformer can modify the 256 // SCC without invalidating our iterator. 257 currentSCCVec = *cgi; 258 ++cgi; 259 sccTransformer(currentSCCVec); 260 } 261 } 262 263 namespace { 264 /// This struct represents a resolved call to a given callgraph node. 
Given that 265 /// the call does not actually contain a direct reference to the 266 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 267 /// explicitly. 268 struct ResolvedCall { 269 ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode, 270 CallGraphNode *targetNode) 271 : call(call), sourceNode(sourceNode), targetNode(targetNode) {} 272 CallOpInterface call; 273 CallGraphNode *sourceNode, *targetNode; 274 }; 275 } // end anonymous namespace 276 277 /// Collect all of the callable operations within the given range of blocks. If 278 /// `traverseNestedCGNodes` is true, this will also collect call operations 279 /// inside of nested callgraph nodes. 280 static void collectCallOps(iterator_range<Region::iterator> blocks, 281 CallGraphNode *sourceNode, CallGraph &cg, 282 SmallVectorImpl<ResolvedCall> &calls, 283 bool traverseNestedCGNodes) { 284 SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist; 285 auto addToWorklist = [&](CallGraphNode *node, 286 iterator_range<Region::iterator> blocks) { 287 for (Block &block : blocks) 288 worklist.emplace_back(&block, node); 289 }; 290 291 addToWorklist(sourceNode, blocks); 292 while (!worklist.empty()) { 293 Block *block; 294 std::tie(block, sourceNode) = worklist.pop_back_val(); 295 296 for (Operation &op : *block) { 297 if (auto call = dyn_cast<CallOpInterface>(op)) { 298 // TODO(riverriddle) Support inlining nested call references. 299 CallInterfaceCallable callable = call.getCallableForCallee(); 300 if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { 301 if (!symRef.isa<FlatSymbolRefAttr>()) 302 continue; 303 } 304 305 CallGraphNode *targetNode = cg.resolveCallable(call); 306 if (!targetNode->isExternal()) 307 calls.emplace_back(call, sourceNode, targetNode); 308 continue; 309 } 310 311 // If this is not a call, traverse the nested regions. If 312 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 313 // regions. 
314 for (auto &nestedRegion : op.getRegions()) { 315 CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion); 316 if (traverseNestedCGNodes || !nestedNode) 317 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion); 318 } 319 } 320 } 321 } 322 323 //===----------------------------------------------------------------------===// 324 // Inliner 325 //===----------------------------------------------------------------------===// 326 namespace { 327 /// This class provides a specialization of the main inlining interface. 328 struct Inliner : public InlinerInterface { 329 Inliner(MLIRContext *context, CallGraph &cg) 330 : InlinerInterface(context), cg(cg) {} 331 332 /// Process a set of blocks that have been inlined. This callback is invoked 333 /// *before* inlined terminator operations have been processed. 334 void 335 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 336 // Find the closest callgraph node from the first block. 337 CallGraphNode *node; 338 Region *region = inlinedBlocks.begin()->getParent(); 339 while (!(node = cg.lookupNode(region))) { 340 region = region->getParentRegion(); 341 assert(region && "expected valid parent node"); 342 } 343 344 collectCallOps(inlinedBlocks, node, cg, calls, 345 /*traverseNestedCGNodes=*/true); 346 } 347 348 /// The current set of call instructions to consider for inlining. 349 SmallVector<ResolvedCall, 8> calls; 350 351 /// The callgraph being operated on. 352 CallGraph &cg; 353 }; 354 } // namespace 355 356 /// Returns true if the given call should be inlined. 357 static bool shouldInline(ResolvedCall &resolvedCall) { 358 // Don't allow inlining terminator calls. We currently don't support this 359 // case. 360 if (resolvedCall.call.getOperation()->isKnownTerminator()) 361 return false; 362 363 // Don't allow inlining if the target is an ancestor of the call. This 364 // prevents inlining recursively. 
365 if (resolvedCall.targetNode->getCallableRegion()->isAncestor( 366 resolvedCall.call.getParentRegion())) 367 return false; 368 369 // Otherwise, inline. 370 return true; 371 } 372 373 /// Delete the given node and remove it from the current scc and the callgraph. 374 static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg, 375 MutableArrayRef<CallGraphNode *> currentSCC) { 376 // Erase the parent operation and remove it from the various lists. 377 node->getCallableRegion()->getParentOp()->erase(); 378 cg.eraseNode(node); 379 380 // Replace this node in the currentSCC with the external node. 381 auto it = llvm::find(currentSCC, node); 382 if (it != currentSCC.end()) 383 *it = cg.getExternalNode(); 384 } 385 386 /// Attempt to inline calls within the given scc. This function returns 387 /// success if any calls were inlined, failure otherwise. 388 static LogicalResult 389 inlineCallsInSCC(Inliner &inliner, CGUseList &useList, 390 MutableArrayRef<CallGraphNode *> currentSCC) { 391 CallGraph &cg = inliner.cg; 392 auto &calls = inliner.calls; 393 394 // Collect all of the direct calls within the nodes of the current SCC. We 395 // don't traverse nested callgraph nodes, because they are handled separately 396 // likely within a different SCC. 397 for (CallGraphNode *node : currentSCC) { 398 if (node->isExternal()) 399 continue; 400 401 // If this node is dead, just delete it now. 402 if (useList.isDead(node)) 403 deleteNode(node, useList, cg, currentSCC); 404 else 405 collectCallOps(*node->getCallableRegion(), node, cg, calls, 406 /*traverseNestedCGNodes=*/false); 407 } 408 if (calls.empty()) 409 return failure(); 410 411 // A set of dead nodes to remove after inlining. 412 SmallVector<CallGraphNode *, 1> deadNodes; 413 414 // Try to inline each of the call operations. Don't cache the end iterator 415 // here as more calls may be added during inlining. 
416 bool inlinedAnyCalls = false; 417 for (unsigned i = 0; i != calls.size(); ++i) { 418 ResolvedCall it = calls[i]; 419 bool doInline = shouldInline(it); 420 LLVM_DEBUG({ 421 if (doInline) 422 llvm::dbgs() << "* Inlining call: "; 423 else 424 llvm::dbgs() << "* Not inlining call: "; 425 it.call.dump(); 426 }); 427 if (!doInline) 428 continue; 429 CallOpInterface call = it.call; 430 Region *targetRegion = it.targetNode->getCallableRegion(); 431 432 // If this is the last call to the target node and the node is discardable, 433 // then inline it in-place and delete the node if successful. 434 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode); 435 436 LogicalResult inlineResult = inlineCall( 437 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), 438 targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); 439 if (failed(inlineResult)) 440 continue; 441 inlinedAnyCalls = true; 442 443 // If the inlining was successful, Merge the new uses into the source node. 444 useList.dropCallUses(it.sourceNode, call.getOperation(), cg); 445 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode); 446 447 // then erase the call. 448 call.erase(); 449 450 // If we inlined in place, mark the node for deletion. 451 if (inlineInPlace) { 452 useList.eraseNode(it.targetNode); 453 deadNodes.push_back(it.targetNode); 454 } 455 } 456 457 for (CallGraphNode *node : deadNodes) 458 deleteNode(node, useList, cg, currentSCC); 459 calls.clear(); 460 return success(inlinedAnyCalls); 461 } 462 463 /// Canonicalize the nodes within the given SCC with the given set of 464 /// canonicalization patterns. 465 static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, 466 MutableArrayRef<CallGraphNode *> currentSCC, 467 MLIRContext *context, 468 const OwningRewritePatternList &canonPatterns) { 469 // Collect the sets of nodes to canonicalize. 
470 SmallVector<CallGraphNode *, 4> nodesToCanonicalize; 471 for (auto *node : currentSCC) { 472 // Don't canonicalize the external node, it has no valid callable region. 473 if (node->isExternal()) 474 continue; 475 476 // Don't canonicalize nodes with children. Nodes with children 477 // require special handling as we may remove the node during 478 // canonicalization. In the future, we should be able to handle this 479 // case with proper node deletion tracking. 480 if (node->hasChildren()) 481 continue; 482 483 // We also won't apply canonicalizations for nodes that are not 484 // isolated. This avoids potentially mutating the regions of nodes defined 485 // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily' 486 // driver. 487 auto *region = node->getCallableRegion(); 488 if (!region->getParentOp()->isKnownIsolatedFromAbove()) 489 continue; 490 nodesToCanonicalize.push_back(node); 491 } 492 if (nodesToCanonicalize.empty()) 493 return; 494 495 // Canonicalize each of the nodes within the SCC in parallel. 496 // NOTE: This is simple now, because we don't enable canonicalizing nodes 497 // within children. When we remove this restriction, this logic will need to 498 // be reworked. 499 ParallelDiagnosticHandler canonicalizationHandler(context); 500 llvm::parallel::for_each_n( 501 llvm::parallel::par, /*Begin=*/size_t(0), 502 /*End=*/nodesToCanonicalize.size(), [&](size_t index) { 503 // Set the order for this thread so that diagnostics will be properly 504 // ordered. 505 canonicalizationHandler.setOrderIDForThread(index); 506 507 // Apply the canonicalization patterns to this region. 508 auto *node = nodesToCanonicalize[index]; 509 applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); 510 511 // Make sure to reset the order ID for the diagnostic handler, as this 512 // thread may be used in a different context. 
513 canonicalizationHandler.eraseOrderIDForThread(); 514 }); 515 516 // Recompute the uses held by each of the nodes. 517 for (CallGraphNode *node : nodesToCanonicalize) 518 useList.recomputeUses(node, cg); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // InlinerPass 523 //===----------------------------------------------------------------------===// 524 525 namespace { 526 struct InlinerPass : public InlinerBase<InlinerPass> { 527 void runOnOperation() override; 528 529 /// Attempt to inline calls within the given scc, and run canonicalizations 530 /// with the given patterns, until a fixed point is reached. This allows for 531 /// the inlining of newly devirtualized calls. 532 void inlineSCC(Inliner &inliner, CGUseList &useList, 533 MutableArrayRef<CallGraphNode *> currentSCC, 534 MLIRContext *context, 535 const OwningRewritePatternList &canonPatterns); 536 }; 537 } // end anonymous namespace 538 539 void InlinerPass::runOnOperation() { 540 CallGraph &cg = getAnalysis<CallGraph>(); 541 auto *context = &getContext(); 542 543 // The inliner should only be run on operations that define a symbol table, 544 // as the callgraph will need to resolve references. 545 Operation *op = getOperation(); 546 if (!op->hasTrait<OpTrait::SymbolTable>()) { 547 op->emitOpError() << " was scheduled to run under the inliner, but does " 548 "not define a symbol table"; 549 return signalPassFailure(); 550 } 551 552 // Collect a set of canonicalization patterns to use when simplifying 553 // callable regions within an SCC. 554 OwningRewritePatternList canonPatterns; 555 for (auto *op : context->getRegisteredOperations()) 556 op->getCanonicalizationPatterns(canonPatterns, context); 557 558 // Run the inline transform in post-order over the SCCs in the callgraph. 
559 Inliner inliner(context, cg); 560 CGUseList useList(getOperation(), cg); 561 runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) { 562 inlineSCC(inliner, useList, scc, context, canonPatterns); 563 }); 564 } 565 566 void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, 567 MutableArrayRef<CallGraphNode *> currentSCC, 568 MLIRContext *context, 569 const OwningRewritePatternList &canonPatterns) { 570 // If we successfully inlined any calls, run some simplifications on the 571 // nodes of the scc. Continue attempting to inline until we reach a fixed 572 // point, or a maximum iteration count. We canonicalize here as it may 573 // devirtualize new calls, as well as give us a better cost model. 574 unsigned iterationCount = 0; 575 while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) { 576 // If we aren't allowing simplifications or the max iteration count was 577 // reached, then bail out early. 578 if (disableCanonicalization || ++iterationCount >= maxInliningIterations) 579 break; 580 canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns); 581 } 582 } 583 584 std::unique_ptr<Pass> mlir::createInlinerPass() { 585 return std::make_unique<InlinerPass>(); 586 } 587