//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced by the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}
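
// Note: `resolvedRefs` caches symbol resolutions; a reference that has
// already been looked up (even unsuccessfully) reuses the cached entry, so
// repeated references to the same symbol cost only one symbol table lookup.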

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace
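
// The constructor below seeds `discardableSymNodeUses` with every symbol
// callable that may legally be dropped, and then counts the references to
// those symbols. The counts are kept up to date incrementally as inlining
// mutates the IR, so `isDead` and `hasOneUseAndDiscardable` reduce to simple
// map lookups.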
CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  /// A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}
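
// Inner uses are tracked as counts because `dropCallUses` needs to retire
// them one call at a time as individual calls are inlined away; top-level
// uses are deduplicated in a set since the parent operation is only ever
// erased as a whole.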

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}
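
// `llvm::scc_iterator` enumerates SCCs in post-order, i.e. callees before
// their callers, which is what gives the pass its bottom-up character: by the
// time a caller's SCC is visited, the callees it may inline have already been
// processed and simplified.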

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}
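
// Note that when `traverseNestedCGNodes` is set, calls found inside a nested
// callable are attributed to that nested node rather than to `sourceNode`.
// `processInlinedBlocks` below relies on this, as freshly inlined blocks may
// themselves contain nested callables.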

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}
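
// Note that the ancestor check also bounds mutual recursion: once a cycle of
// calls has been inlined through, the rediscovered call site sits within its
// own target's region and is rejected by the check above.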

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
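    // In-place inlining moves the callee region rather than cloning it (note
    // `shouldCloneInlinedRegion` below), which is only safe because no other
    // uses of the callable remain.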
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source
    // node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given SCC, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // end anonymous namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(defaultPipeline) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments.
  if (defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines) {
    std::string pipeline;
    llvm::raw_string_ostream pipelineOS(pipeline);
    pipelineOS << it.getKey() << "(";
    it.second.printAsTextualPipeline(pipelineOS);
    pipelineOS << ")";
    opPipelineStrs.addValue(pipeline);
  }
  this->opPipelines.emplace_back(std::move(opPipelines));
}
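
// Each serialized entry takes the form `<op-name>(<pipeline>)`, mirroring
// what `initializeOptions` parses below; e.g. a pipeline registered for a
// hypothetical `my.func` op would round-trip as `my.func(canonicalize)`.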

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the
    // future, we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // Ensure that there are enough pipeline maps for the optimizer to run in
  // parallel. Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // An atomic failure variable for the async executors.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
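    // `compare_exchange_strong` atomically flips the first inactive flag to
    // active, giving this executor exclusive use of one of the pipeline maps.
    // A free slot is expected to exist since `opPipelines` was grown to the
    // thread count above.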
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the default if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (StringRef pipeline : opPipelineStrs) {
    // Skip empty pipelines.
    if (pipeline.empty())
      continue;

    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
    size_t pipelineStart = pipeline.find_first_of('(');
    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
      return failure();
    StringRef opName = pipeline.take_front(pipelineStart);
    OpPassManager pm(opName);
    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
      return failure();
    pipelines.try_emplace(opName, std::move(pm));
  }
  opPipelines.assign({std::move(pipelines)});

  return success();
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}