//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
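  /// This is invoked after a call has been successfully inlined, just before
  /// the call operation itself is erased, so that symbol use counts remain
  /// accurate.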
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
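  // (A 'child' edge points to a callgraph node whose callable region is
  // nested within this node; the parent operation owns those regions, so the
  // children's use information must be dropped before the parent is erased.)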
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}

  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
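  /// This is invoked once per SCC as the parent scc_iterator is advanced.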
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static void
runTransformOnCGSCCs(const CallGraph &cg,
                     function_ref<void(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    sccTransformer(currentSCC);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
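    // Dead nodes are only queued here; they are removed from the SCC and
    // marked for deletion once the inlining loop below has finished.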
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is
    // discardable, then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // The inlining was successful: merge the new uses into the source node,
    // then erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the set of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node; it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children require
    // special handling, as we may remove the node during canonicalization. In
    // the future, we should be able to handle this case with proper node
    // deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations to nodes that are not isolated
    // from above. This avoids potentially mutating the regions of nodes
    // defined above; it is also a requirement of the
    // 'applyPatternsAndFoldGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
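  // ParallelDiagnosticHandler buffers any diagnostics emitted on the worker
  // threads and replays them in the deterministic order established by the
  // order IDs set below, so multi-threaded runs produce the same diagnostic
  // output as single-threaded ones.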
  if (context->isMultithreadingEnabled()) {
    ParallelDiagnosticHandler canonicalizationHandler(context);
    llvm::parallelForEachN(
        /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
          // Set the order for this thread so that diagnostics will be
          // properly ordered.
          canonicalizationHandler.setOrderIDForThread(index);

          // Apply the canonicalization patterns to this region.
          auto *node = nodesToCanonicalize[index];
          applyPatternsAndFoldGreedily(*node->getCallableRegion(),
                                       canonPatterns);

          // Make sure to reset the order ID for the diagnostic handler, as
          // this thread may be used in a different context.
          canonicalizationHandler.eraseOrderIDForThread();
        });
  } else {
    for (CallGraphNode *node : nodesToCanonicalize)
      applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
  }

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given SCC, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList,
                 CallGraphSCC &currentSCC, MLIRContext *context,
                 const OwningRewritePatternList &canonPatterns);
};
} // end anonymous namespace

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Collect a set of canonicalization patterns to use when simplifying
  // callable regions within an SCC.
  OwningRewritePatternList canonPatterns;
  for (auto *op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(canonPatterns, context);

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    inlineSCC(inliner, useList, scc, context, canonPatterns);
  });

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the SCC. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
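  // Note: `disableCanonicalization` and `maxInliningIterations` are pass
  // options declared on the tablegen-generated InlinerBase.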
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
  }
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}