//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}
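
// For illustration only: callers typically pair this helper with a map of
// resolved references and a counting callback. `useCounts` below is a
// hypothetical accumulator, not part of this pass:
//
//   DenseMap<Attribute, CallGraphNode *> resolved;
//   DenseMap<CallGraphNode *, int> useCounts;
//   walkReferencedSymbolNodes(op, cg, symbolTable, resolved,
//                             [&](CallGraphNode *node, Operation *user) {
//                               ++useCounts[node];
//                             });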

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// they become use-empty. It directly tracks and manages a use-list for all of
/// the call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes, keyed by symbol reference, that are always
  // known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
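  // Note: resetting `uses` below drops the stale per-node counts before the
  // walk rebuilds both the per-user tallies and the global
  // `discardableSymNodeUses` totals in a single pass.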
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
static bool inlineHistoryIncludes(
    CallGraphNode *node, Optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(inlineHistoryID.value() < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[inlineHistoryID.value()].first == node)
      return true;
    inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
  }
  return false;
}
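
// For intuition: suppose a call site targets callee A with no history (None).
// When it is inlined, entry 0 = {A, None} is recorded, and every call
// discovered in the inlined body carries history ID 0. If such a call targets
// A again, inlineHistoryIncludes(A, 0, ...) walks the chain 0 -> None, finds
// A, and the inliner skips the call, bounding recursive self-inlining.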

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = Optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
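  // Invariant: callHistory[i] holds the inline history entry under which
  // calls[i] was discovered; the initial calls collected above start at None,
  // which the debug output prints as "root".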
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(h.value()) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n  with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n  call {" << call
                              << "}\n  with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source
    // node...
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // ...then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given SCC, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(std::move(defaultPipeline)) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments.
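  // Note: the temporary pass manager below is anchored on the placeholder op
  // name "__mlir_fake_pm_op" purely so the builder-provided pipeline can be
  // serialized into `defaultPipelineStr`, keeping the textual pass option in
  // sync with the actual default pipeline.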
  if (defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines)
    opPipelineList.addValue(it.second);
  this->opPipelines.emplace_back(std::move(opPipelines));
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling, as we may remove the node during simplification. In the
    // future, we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
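  // The nested pipelines may have added or removed symbol references (for
  // example, canonicalization can delete dead calls), so the counts cached
  // earlier may be stale.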
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = ctx->getNumThreads();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // A pool of flags indicating which pass managers are currently claimed by a
  // worker.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the default if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  // TODO: Use a generic pass manager for default pipelines, and remove this.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (OpPassManager pipeline : opPipelineList)
    if (!pipeline.empty())
      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
  opPipelines.assign({std::move(pipelines)});

  return success();
}
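
// For illustration: assuming the tablegen-declared option names
// "default-pipeline" and "max-iterations" (defined outside this file, so
// treat these names as hypothetical), the pass could be driven from mlir-opt
// roughly as:
//   mlir-opt -inline='default-pipeline=canonicalize max-iterations=4' in.mlir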

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}
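
// Example (illustrative client code, not part of this file's API surface):
// building the pass with a custom default pipeline via the factory above:
//   pm.addPass(mlir::createInlinerPass(
//       /*opPipelines=*/{}, [](OpPassManager &opm) {
//         opm.addPass(mlir::createCanonicalizerPass());
//       }));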