//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up
// over the Strongly Connected Components (SCCs) of the CallGraph. This
// enables a more incremental propagation of inlining decisions from the
// leaves to the roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// they become use-empty. It directly tracks and manages a use-list for all
/// of the call-graph nodes. This is necessary because many callgraph nodes
/// are referenced by SymbolRefAttr, which has no mechanism akin to the SSA
/// `Use` class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
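    /// Unlike `topLevelUses`, these are reference counted: operations nested
    /// within the user may reference the same node multiple times, and each
    /// such reference must be dropped individually as calls are inlined away.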
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during
  // inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
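  // This seeds the discardable-symbol use counts initialized above with the
  // references actually found within each callable region.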
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
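  // Reset the existing use list first so that stale entries from a previous
  // computation cannot leak into the new counts.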
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks.
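/// Each call is resolved against `symbolTable` and, if its target is not
/// external, appended to `calls` along with its source and target nodes.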
/// If `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call operations to consider for inlining.
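  /// Calls discovered while inlining (see `processInlinedBlocks`) are
  /// appended here, so this worklist can grow while it is being processed.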
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is
    // discardable, then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
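    // The callable region was moved rather than cloned, so the original
    // callable is now empty and can no longer be called.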
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given SCC, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // end anonymous namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(defaultPipeline) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments.
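  // Render the default pipeline into its textual form so that it can be
  // round-tripped through the pass options.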
  if (defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines) {
    std::string pipeline;
    llvm::raw_string_ostream pipelineOS(pipeline);
    pipelineOS << it.getKey() << "(";
    it.second.printAsTextualPipeline(pipelineOS);
    pipelineOS << ")";
    opPipelineStrs.addValue(pipeline);
  }
  this->opPipelines.emplace_back(std::move(opPipelines));
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling, as we may remove the node during simplification. In the
    // future, we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
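  // A failure from any of the nested pipelines aborts the optimization of
  // the entire SCC.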
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // Ensure that there are enough pipeline maps for the optimizer to run in
  // parallel. Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // Per-pass-manager "active" flags, used to claim a pass manager for the
  // current thread.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the default if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (StringRef pipeline : opPipelineStrs) {
    // Skip empty pipelines.
    if (pipeline.empty())
      continue;

    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
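    // Entries that don't match this form cause option initialization to
    // fail.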
    size_t pipelineStart = pipeline.find_first_of('(');
    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
      return failure();
    StringRef opName = pipeline.take_front(pipelineStart);
    OpPassManager pm(opName);
    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
      return failure();
    pipelines.try_emplace(opName, std::move(pm));
  }
  opPipelines.assign({std::move(pipelines)});

  return success();
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}