//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced by the given op.
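///
/// For illustration (op and symbol names are hypothetical): given a module
///
///   module {
///     func @caller() {
///       call @callee() : () -> ()
///       return
///     }
///     func @callee() { return }
///   }
///
/// walking `@caller` resolves the `@callee` reference to the callgraph node
/// for `@callee`'s body region and invokes `callback` once for that use.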
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
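///
/// For example, if a discardable function is called twice from another
/// function, its node records two uses; once both calls have been inlined
/// and dropped, the node becomes use_empty and its callable can be erased.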
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy of
  /// 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}
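
// As an example of the bookkeeping above (syntax illustrative):
//
//   module {
//     func private @f() { return }  // Discardable; symbol uses are counted.
//     func @entry() {               // Public; never treated as discardable.
//       call @f() : () -> ()
//       return
//     }
//   }
//
// Here `@f` starts with a single tracked use from the call in `@entry`.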

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
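///
/// For example, if `@a` calls `@b` and `@b` calls `@c` (with no cycles), the
/// SCCs are visited as {@c}, then {@b}, then {@a}; mutually recursive
/// functions form a single SCC and are visited together.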
static void
runTransformOnCGSCCs(const CallGraph &cg,
                     function_ref<void(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    sccTransformer(currentSCC);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
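///
/// Note that only calls whose callee is a flat symbol reference (e.g.
/// `@callee`) are collected; nested references (e.g. `@outer::@inner`) are
/// skipped, per the TODO in the body below.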
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call operations to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
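///
/// In particular, recursive calls are rejected: if the target's callable
/// region is an ancestor of the call (e.g. a function calling itself),
/// inlining would just re-introduce an identical call.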
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is
    // discardable, then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source
    // node, then erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
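///
/// Only nodes that are isolated from above and have no child nodes are
/// simplified; e.g. a top-level `func` is a candidate, whereas a callable
/// nested in a region that captures values from above is skipped.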
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const FrozenRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children require
    // special handling as we may remove the node during canonicalization. In
    // the future, we should be able to handle this case with proper node
    // deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations for nodes that are not isolated.
    // This avoids potentially mutating the regions of nodes defined above;
    // this is also a stipulation of the 'applyPatternsAndFoldGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  if (context->isMultithreadingEnabled()) {
    ParallelDiagnosticHandler canonicalizationHandler(context);
    llvm::parallelForEachN(
        /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
          // Set the order for this thread so that diagnostics will be
          // properly ordered.
          canonicalizationHandler.setOrderIDForThread(index);

          // Apply the canonicalization patterns to this region.
          auto *node = nodesToCanonicalize[index];
          applyPatternsAndFoldGreedily(*node->getCallableRegion(),
                                       canonPatterns);

          // Make sure to reset the order ID for the diagnostic handler, as
          // this thread may be used in a different context.
          canonicalizationHandler.eraseOrderIDForThread();
        });
  } else {
    for (CallGraphNode *node : nodesToCanonicalize)
      applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
  }

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given SCC, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList,
                 CallGraphSCC &currentSCC, MLIRContext *context,
                 const FrozenRewritePatternList &canonPatterns);
};
} // end anonymous namespace

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Collect a set of canonicalization patterns to use when simplifying
  // callable regions within an SCC.
  OwningRewritePatternList canonPatterns;
  for (auto *op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(canonPatterns, context);
  FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns));

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    inlineSCC(inliner, useList, scc, context, frozenCanonPatterns);
  });

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const FrozenRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the SCC. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
  }
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
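
// Example usage (sketch; `context` and `module` are assumed to be provided by
// the caller):
//
//   PassManager pm(context);
//   pm.addPass(createInlinerPass());
//   if (failed(pm.run(module)))
//     ...; // Handle the error.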