//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

static llvm::cl::opt<bool> disableCanonicalization(
    "mlir-disable-inline-simplify",
    llvm::cl::desc("Disable running simplifications during inlining"),
    llvm::cl::ReallyHidden, llvm::cl::init(false));

static llvm::cl::opt<unsigned> maxInliningIterations(
    "mlir-max-inline-iterations",
    llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
    llvm::cl::ReallyHidden, llvm::cl::init(4));

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Returns true if this operation can be discarded if it is a symbol and has
/// no uses. 'allUsesVisible' is true if the parent symbol table is hidden from
/// above, meaning that all possible uses of a symbol within it are visible.
static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
  if (!SymbolTable::isSymbol(op))
    return false;

  // TODO: This is essentially the same logic from SymbolDCE. Remove this when
  // we have a 'Symbol' interface.
  // Private symbols are always initially considered dead.
  SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
  if (visibility == mlir::SymbolTable::Visibility::Private)
    return true;
  // We only include nested visibility here if all uses are visible.
  if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
    return true;
  // Otherwise, public symbols are never removable.
  return false;
}
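
// As an illustrative sketch of the rules above (hypothetical IR, not tied to
// any particular input), given a module such as:
//
//   module {
//     func @pub() { ... }                                       // public
//     func @priv() attributes {sym_visibility = "private"} { ... }
//   }
//
// '@priv' is discardable once its symbol uses reach zero, whereas '@pub' may
// always be referenced from outside the module and is never discarded here.
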
/// Walk all of the symbol table operations nested within 'op', invoking the
/// provided callback with each symbol table operation and a boolean signaling
/// if all of the uses within that symbol table are visible.
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
                             function_ref<void(Operation *, bool)> callback) {
  if (op->hasTrait<OpTrait::SymbolTable>()) {
    allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
                        SymbolTable::getSymbolVisibility(op) ==
                            SymbolTable::Visibility::Private;
    callback(op, allSymUsesVisible);
  } else {
    // Otherwise, 'op' is not a symbol table: any nested symbols are guaranteed
    // to be hidden from above, so all of their uses are visible.
    allSymUsesVisible = true;
  }

  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &nested : block)
        walkSymbolTables(&nested, allSymUsesVisible, callback);
}

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
                                                            use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;
};
} // end anonymous namespace
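
// To illustrate the bookkeeping above (a hypothetical example): if the body of
// a callable @caller contains two calls to a discardable @helper, then
// nodeUses[caller].innerUses[helper] == 2 and both references are also counted
// in discardableSymNodeUses[helper]. A reference to @helper appearing in
// @caller's own attribute list would instead be recorded once in
// nodeUses[caller].topLevelUses.
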
CGUseList::CGUseList(Operation *op, CallGraph &cg) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Block &block : symbolTableOp->getRegion(0)) {
      for (Operation &op : block) {
        // If this is a callgraph operation, check to see if it is discardable.
        if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
          if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
            if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
              discardableSymNodeUses.try_emplace(node, 0);
            continue;
          }
        }
        // Otherwise, check for any referenced nodes. These will be
        // always-live.
        walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
                                  [](CallGraphNode *, Operation *) {});
      }
    }
  };
  walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!SymbolTable::isSymbol(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!SymbolTable::isSymbol(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}
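
// For instance (a hypothetical scenario): a private @helper whose only
// remaining use is a single call site satisfies hasOneUseAndDiscardable().
// The inliner below exploits this to inline the body in place, i.e. by moving
// it rather than cloning it, and then deletes the now-unused callee.
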
void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static void runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
  std::vector<CallGraphNode *> currentSCCVec;
  auto cgi = llvm::scc_begin(&cg);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCCVec = *cgi;
    ++cgi;
    sccTransformer(currentSCCVec);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace
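
// To make the bottom-up SCC order above concrete (a hypothetical callgraph):
// if @a calls @b, @b calls @c, and @c calls @b, then {@b, @c} form one SCC and
// {@a} another. The {@b, @c} SCC is transformed first, so that @a later sees
// callees that have already been inlined into and simplified.
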
/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO(riverriddle) Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. Nested callgraph
      // nodes are only followed if `traverseNestedCGNodes` is true.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg)
      : InlinerInterface(context), cg(cg) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursive calls.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}
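
// As a concrete case for the ancestor check above: given a directly recursive
// function (e.g. a hypothetical @fact that calls @fact), the callable region
// of the target is an ancestor of the call itself, so the call is rejected
// rather than being expanded indefinitely.
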
/// Delete the given node and remove it from the current SCC and the callgraph.
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
                       MutableArrayRef<CallGraphNode *> currentSCC) {
  // Erase the parent operation and remove it from the various lists.
  node->getCallableRegion()->getParentOp()->erase();
  cg.eraseNode(node);

  // Replace this node in the currentSCC with the external node.
  auto it = llvm::find(currentSCC, node);
  if (it != currentSCC.end())
    *it = cg.getExternalNode();
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                 MutableArrayRef<CallGraphNode *> currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // If this node is dead, just delete it now.
    if (useList.isDead(node))
      deleteNode(node, useList, cg, currentSCC);
    else
      collectCallOps(*node->getCallableRegion(), node, cg, calls,
                     /*traverseNestedCGNodes=*/false);
  }
  if (calls.empty())
    return failure();

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall &it = calls[i];
    LLVM_DEBUG({
      llvm::dbgs() << "* Considering inlining call: ";
      it.call.dump();
    });
    if (!shouldInline(it))
      continue;
    CallOpInterface call = it.call;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult))
      continue;
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source
    // node, and then erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes)
    deleteNode(node, useList, cg, currentSCC);
  calls.clear();
  return success(inlinedAnyCalls);
}
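
// Note the interplay between the loop above and Inliner::processInlinedBlocks:
// when a call is inlined, any calls contained in the inlined blocks are
// appended to `calls`, so (for a hypothetical chain where @a calls @b and @b
// calls @c) inlining @b into @a immediately exposes the call to @c for
// consideration in this same loop, without waiting for another iteration over
// the SCC.
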
/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            MutableArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children require
    // special handling, as we may remove the node during canonicalization. In
    // the future, we should be able to handle this case with proper node
    // deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations for nodes that are not isolated.
    // This avoids potentially mutating the regions of nodes defined above;
    // this is also a stipulation of the 'applyPatternsGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  ParallelDiagnosticHandler canonicalizationHandler(context);
  llvm::parallel::for_each_n(
      llvm::parallel::par, /*Begin=*/size_t(0),
      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
        // Set the order for this thread so that diagnostics will be properly
        // ordered.
        canonicalizationHandler.setOrderIDForThread(index);

        // Apply the canonicalization patterns to this region.
        auto *node = nodesToCanonicalize[index];
        applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);

        // Make sure to reset the order ID for the diagnostic handler, as this
        // thread may be used in a different context.
        canonicalizationHandler.eraseOrderIDForThread();
      });

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}
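
// Canonicalization matters for the fixed-point loop below chiefly because it
// can devirtualize calls. As a hypothetical example, folding a constant
// function reference forwarded into an indirect call can turn it into a
// direct call, which only then becomes resolvable in the callgraph and thus
// inlinable on the next iteration.
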
/// Attempt to inline calls within the given SCC, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls.
static void inlineSCC(Inliner &inliner, CGUseList &useList,
                      MutableArrayRef<CallGraphNode *> currentSCC,
                      MLIRContext *context,
                      const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the SCC. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
  }
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
struct InlinerPass : public OperationPass<InlinerPass> {
  void runOnOperation() override {
    CallGraph &cg = getAnalysis<CallGraph>();
    auto *context = &getContext();

    // The inliner should only be run on operations that define a symbol table,
    // as the callgraph will need to resolve references.
    Operation *op = getOperation();
    if (!op->hasTrait<OpTrait::SymbolTable>()) {
      op->emitOpError() << " was scheduled to run under the inliner, but does "
                           "not define a symbol table";
      return signalPassFailure();
    }

    // Collect a set of canonicalization patterns to use when simplifying
    // callable regions within an SCC.
    OwningRewritePatternList canonPatterns;
    for (auto *op : context->getRegisteredOperations())
      op->getCanonicalizationPatterns(canonPatterns, context);

    // Run the inline transform in post-order over the SCCs in the callgraph.
    Inliner inliner(context, cg);
    CGUseList useList(getOperation(), cg);
    runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
      inlineSCC(inliner, useList, scc, context, canonPatterns);
    });
  }
};
} // end anonymous namespace

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}

static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
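
// A minimal usage sketch (illustrative only; assumes an `MLIRContext *context`
// and a `ModuleOp module` are already set up by the caller):
//
//   PassManager pm(context);
//   pm.addPass(createInlinerPass());
//   if (failed(pm.run(module)))
//     module.emitError("inlining failed");
//
// The same pass is also available from the command line via `mlir-opt -inline`.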