//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom-up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffects.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Returns true if this operation can be discarded if it is a symbol and has
/// no uses. 'allUsesVisible' signals whether the parent symbol table is hidden
/// from above, i.e. whether all uses of the symbol are visible.
static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
  if (!SymbolTable::isSymbol(op))
    return false;

  // TODO: This is essentially the same logic as in SymbolDCE. Remove this when
  // we have a 'Symbol' interface.
  // Private symbols are always initially considered dead.
  SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
  if (visibility == mlir::SymbolTable::Visibility::Private)
    return true;
  // We only include nested visibility here if all uses are visible.
  if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
    return true;
  // Otherwise, public symbols are never removable.
  return false;
}

/// Walk all of the symbol table operations nested within 'op'. The provided
/// callback is invoked with each symbol table operation and a boolean
/// signaling if all of the uses within that symbol table are visible.
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
                             function_ref<void(Operation *, bool)> callback) {
  if (op->hasTrait<OpTrait::SymbolTable>()) {
    allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
                        SymbolTable::getSymbolVisibility(op) ==
                            SymbolTable::Visibility::Private;
    callback(op, allSymUsesVisible);
  } else {
    // Otherwise, if 'op' is not a symbol table, any nested symbols are
    // guaranteed to be hidden, meaning all of their uses are visible.
    allSymUsesVisible = true;
  }

  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &nested : block)
        walkSymbolTables(&nested, allSymUsesVisible, callback);
}

/// Walk all of the used symbol callgraph nodes referenced by the given op.
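/// 'resolvedRefs' caches the resolution of each symbol reference, so repeated
/// references to the same symbol are only resolved once; references that do
/// not resolve to a callable are cached as null and skipped.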
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
                                                            use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// they become use-empty. It directly tracks and manages a use-list for all of
/// the callgraph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy of
  /// 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg) {
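  // The use list is seeded in two steps: first find the symbol nodes that are
  // discardable when use-empty, then count the uses held by each node in the
  // graph.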
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Block &block : symbolTableOp->getRegion(0)) {
      for (Operation &op : block) {
        // If this is a callgraph operation, check to see if it is discardable.
        if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
          if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
            if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
              discardableSymNodeUses.try_emplace(node, 0);
            continue;
          }
        }
        // Otherwise, check for any referenced nodes. These will be always-live.
        walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
                                  [](CallGraphNode *, Operation *) {});
      }
    }
  };
  walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!SymbolTable::isSymbol(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!SymbolTable::isSymbol(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
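  // Top-level references (i.e. those held in the parent operation's own
  // attribute list) are tracked as a set, while references from nested
  // operations are counted individually so that dropCallUses can decrement
  // them one call at a time.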
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

/// Run a given transformation over the SCCs of the callgraph in a bottom-up
/// traversal.
static void runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
  std::vector<CallGraphNode *> currentSCCVec;
  auto cgi = llvm::scc_begin(&cg);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCCVec = *cgi;
    ++cgi;
    sccTransformer(currentSCCVec);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO(riverriddle) Support inlining nested call references.
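        // Only flat symbol references are resolved below; calls through
        // nested references are skipped for now.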
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, skip regions that form their own
      // callgraph nodes.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg)
      : InlinerInterface(context), cg(cg) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Delete the given node and remove it from the current scc and the callgraph.
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
                       MutableArrayRef<CallGraphNode *> currentSCC) {
  // Erase the parent operation and remove it from the various lists.
  node->getCallableRegion()->getParentOp()->erase();
  cg.eraseNode(node);

  // Replace this node in the currentSCC with the external node.
  auto it = llvm::find(currentSCC, node);
  if (it != currentSCC.end())
    *it = cg.getExternalNode();
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
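/// Nodes that are already dead are deleted eagerly while collecting calls, and
/// nodes whose only use was inlined in-place are deleted once all of the
/// candidate calls have been processed.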
static LogicalResult
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                 MutableArrayRef<CallGraphNode *> currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // If this node is dead, just delete it now.
    if (useList.isDead(node))
      deleteNode(node, useList, cg, currentSCC);
    else
      collectCallOps(*node->getCallableRegion(), node, cg, calls,
                     /*traverseNestedCGNodes=*/false);
  }
  if (calls.empty())
    return failure();

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall &it = calls[i];
    LLVM_DEBUG({
      llvm::dbgs() << "* Considering inlining call: ";
      it.call.dump();
    });
    if (!shouldInline(it))
      continue;
    CallOpInterface call = it.call;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is
    // discardable, then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult))
      continue;
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source node
    // and then erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes)
    deleteNode(node, useList, cg, currentSCC);
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
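/// Only nodes that are isolated from above and have no child nodes are
/// processed, as the nodes are canonicalized in parallel.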
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            MutableArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, as it has no valid callable
    // region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children require
    // special handling as we may remove the node during canonicalization. In
    // the future, we should be able to handle this case with proper node
    // deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations for nodes that are not isolated.
    // This avoids potentially mutating the regions of nodes defined above;
    // this is also a stipulation of the 'applyPatternsAndFoldGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  ParallelDiagnosticHandler canonicalizationHandler(context);
  llvm::parallel::for_each_n(
      llvm::parallel::par, /*Begin=*/size_t(0),
      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
        // Set the order for this thread so that diagnostics will be properly
        // ordered.
        canonicalizationHandler.setOrderIDForThread(index);

        // Apply the canonicalization patterns to this region.
        auto *node = nodesToCanonicalize[index];
        applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);

        // Make sure to reset the order ID for the diagnostic handler, as this
        // thread may be used in a different context.
        canonicalizationHandler.eraseOrderIDForThread();
      });

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given scc, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList,
                 MutableArrayRef<CallGraphNode *> currentSCC,
                 MLIRContext *context,
                 const OwningRewritePatternList &canonPatterns);
};
} // end anonymous namespace

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Collect a set of canonicalization patterns to use when simplifying
  // callable regions within an SCC.
  OwningRewritePatternList canonPatterns;
  for (auto *op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(canonPatterns, context);

  // Run the inline transform in post-order over the SCCs in the callgraph.
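  // Each SCC is visited after its callees, so inlining decisions propagate
  // from the leaves of the callgraph towards the roots.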
  Inliner inliner(context, cg);
  CGUseList useList(getOperation(), cg);
  runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
    inlineSCC(inliner, useList, scc, context, canonPatterns);
  });
}

void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                            MutableArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the scc. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
  }
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}