//===- FunctionSpecialization.cpp - Function Specialization ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass specialises functions that are called with constant parameters.
// Constant parameters like function pointers and constant globals are
// propagated to the callee by specializing the function. The main benefit of
// this pass at the moment is that indirect calls are transformed into direct
// calls, which creates inlining opportunities that the inliner would otherwise
// not have had. That is why, by design, function specialisation is run before
// the inliner in the optimisation pipeline. Otherwise, we would only benefit
// from constant argument passing, which is a valid use case too, but one that
// hasn't been explored much in terms of performance uplifts, cost-model and
// compile-time impact.
//
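// As an illustration of that main benefit, consider specializing a callee on
// a known function pointer argument. The IR below is only a made-up sketch
// (the names @compute and @plus_one do not come from a test case): before
// specialization there is an indirect call,
//
//   define internal i32 @compute(i32 %x, i32 (i32)* %binop) {
//     %r = call i32 %binop(i32 %x)          ; indirect call
//     ret i32 %r
//   }
//
// and after specializing @compute for call sites that pass @plus_one, the
// clone contains a direct call that the inliner can then consider:
//
//   define internal i32 @compute.1(i32 %x, i32 (i32)* %binop) {
//     %r = call i32 @plus_one(i32 %x)       ; direct call, now inlinable
//     ret i32 %r
//   }
//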
// Current limitations:
// - It does not yet handle integer ranges. We do support "literal constants",
//   but that is off by default under an option.
// - Only one argument per function is specialised.
// - The cost-model could be refined further; it mainly focuses on inlining
//   benefits.
// - We are not yet caching analysis results, but profiling and checking where
//   extra compile time is spent didn't suggest this to be a problem.
//
// Ideas:
// - With a function specialization attribute for arguments, we could have
//   a direct way to steer function specialization, avoiding the cost-model,
//   and thus control compile-times / code-size.
//
// Todos:
// - Specializing recursive functions relies on running the transformation a
//   number of times, which is controlled by the option
//   `func-specialization-max-iters`. Thus, increasing this value, and with it
//   the number of iterations, linearly increases the number of times recursive
//   functions get specialized; see also the discussion in
//   https://reviews.llvm.org/D106426 for details. Perhaps there is a more
//   compile-time-friendly way to control/limit the number of specialisations
//   for recursive functions.
// - Don't transform the function if function specialization does not trigger;
//   the SCCPSolver may make IR changes.
//
// References:
// - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
//   it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include <cmath>

using namespace llvm;

#define DEBUG_TYPE "function-specialization"

STATISTIC(NumFuncSpecialized, "Number of functions specialized");

static cl::opt<bool> ForceFunctionSpecialization(
    "force-function-specialization", cl::init(false), cl::Hidden,
    cl::desc("Force function specialization for every call site with a "
             "constant argument"));

static cl::opt<unsigned> FuncSpecializationMaxIters(
    "func-specialization-max-iters", cl::Hidden,
    cl::desc(
        "The maximum number of iterations function specialization is run"),
    cl::init(1));

static cl::opt<unsigned> MaxClonesThreshold(
    "func-specialization-max-clones", cl::Hidden,
    cl::desc("The maximum number of clones allowed for a single function "
             "specialization"),
    cl::init(3));

static cl::opt<unsigned> SmallFunctionThreshold(
    "func-specialization-size-threshold", cl::Hidden,
    cl::desc("Don't specialize functions that have less than this threshold "
             "number of instructions"),
    cl::init(100));

static cl::opt<unsigned>
    AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
                          cl::desc("Average loop iteration count cost"),
                          cl::init(10));

static cl::opt<bool> SpecializeOnAddresses(
    "func-specialization-on-address", cl::init(false), cl::Hidden,
    cl::desc("Enable function specialization on the address of global values"));

// TODO: This needs checking to see the impact on compile-times, which is why
// this is off by default for now.
static cl::opt<bool> EnableSpecializationForLiteralConstant(
    "function-specialization-for-literal-constant", cl::init(false),
    cl::Hidden,
    cl::desc("Enable specialization of functions that take a literal constant "
             "as an argument."));

namespace {
// Bookkeeping struct to pass data from the analysis and profitability phase
// to the actual transform helper functions.
struct SpecializationInfo {
  ArgInfo Arg;          // Stores the {formal,actual} argument pair.
  InstructionCost Gain; // Profitability: Gain = Bonus - Cost.

  SpecializationInfo(Argument *A, Constant *C, InstructionCost G)
      : Arg(A, C), Gain(G) {}
};
} // Anonymous namespace

using FuncList = SmallVectorImpl<Function *>;
using ConstList = SmallVector<Constant *>;
using SpecializationList = SmallVector<SpecializationInfo>;

// Helper to check if \p LV is either a constant or a constant
// range with a single element. This should cover exactly the same cases as the
// old ValueLatticeElement::isConstant() and is intended to be used in the
// transition to ValueLatticeElement.
static bool isConstant(const ValueLatticeElement &LV) {
  return LV.isConstant() ||
         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
}

// Helper to check if \p LV is overdefined, i.e. neither unknown/undef nor a
// constant.
static bool isOverdefined(const ValueLatticeElement &LV) {
  return !LV.isUnknownOrUndef() && !isConstant(LV);
}

static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
  Value *StoreValue = nullptr;
  for (auto *User : Alloca->users()) {
    // We can't use llvm::isAllocaPromotable() as that would fail because of
    // the usage in the CallInst, which is what we check here.
    if (User == Call)
      continue;
    if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
      if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
        return nullptr;
      continue;
    }

    if (auto *Store = dyn_cast<StoreInst>(User)) {
      // This is a duplicate store, bail out.
      if (StoreValue || Store->isVolatile())
        return nullptr;
      StoreValue = Store->getValueOperand();
      continue;
    }
    // Bail if there is any other unknown usage.
    return nullptr;
  }
  return dyn_cast_or_null<Constant>(StoreValue);
}

// A constant stack value is an AllocaInst that has a single constant
// value stored to it. Return this constant if such an alloca stack value
// is a function argument.
static Constant *getConstantStackValue(CallInst *Call, Value *Val,
                                       SCCPSolver &Solver) {
  if (!Val)
    return nullptr;
  Val = Val->stripPointerCasts();
  if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
    return ConstVal;
  auto *Alloca = dyn_cast<AllocaInst>(Val);
  if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
    return nullptr;
  return getPromotableAlloca(Alloca, Call);
}

// To support specializing recursive functions, it is important to propagate
// constant arguments because after a first iteration of specialisation, a
// reduced example may look like this:
//
//   define internal void @RecursiveFn(i32* arg1) {
//     %temp = alloca i32, align 4
//     store i32 2, i32* %temp, align 4
//     call void @RecursiveFn.1(i32* nonnull %temp)
//     ret void
//   }
//
// Before the next iteration, we need to propagate the constant like so,
// which allows further specialization in following iterations:
//
//   @funcspec.arg = internal constant i32 2
//
//   define internal void @RecursiveFn(i32* arg1) {
//     call void @RecursiveFn.1(i32* nonnull @funcspec.arg)
//     ret void
//   }
//
static void constantArgPropagation(FuncList &WorkList, Module &M,
                                   SCCPSolver &Solver) {
  // Iterate over the argument-tracked functions to see if there
  // are any new constant values for the call instruction via
  // stack variables.
  for (auto *F : WorkList) {
    // TODO: Generalize for any read only arguments.
212 if (F->arg_size() != 1) 213 continue; 214 215 auto &Arg = *F->arg_begin(); 216 if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy()) 217 continue; 218 219 for (auto *User : F->users()) { 220 auto *Call = dyn_cast<CallInst>(User); 221 if (!Call) 222 break; 223 auto *ArgOp = Call->getArgOperand(0); 224 auto *ArgOpType = ArgOp->getType(); 225 auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver); 226 if (!ConstVal) 227 break; 228 229 Value *GV = new GlobalVariable(M, ConstVal->getType(), true, 230 GlobalValue::InternalLinkage, ConstVal, 231 "funcspec.arg"); 232 233 if (ArgOpType != ConstVal->getType()) 234 GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType()); 235 236 Call->setArgOperand(0, GV); 237 238 // Add the changed CallInst to Solver Worklist 239 Solver.visitCall(*Call); 240 } 241 } 242 } 243 244 // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics 245 // interfere with the constantArgPropagation optimization. 246 static void removeSSACopy(Function &F) { 247 for (BasicBlock &BB : F) { 248 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 249 auto *II = dyn_cast<IntrinsicInst>(&Inst); 250 if (!II) 251 continue; 252 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 253 continue; 254 Inst.replaceAllUsesWith(II->getOperand(0)); 255 Inst.eraseFromParent(); 256 } 257 } 258 } 259 260 static void removeSSACopy(Module &M) { 261 for (Function &F : M) 262 removeSSACopy(F); 263 } 264 265 namespace { 266 class FunctionSpecializer { 267 268 /// The IPSCCP Solver. 269 SCCPSolver &Solver; 270 271 /// Analyses used to help determine if a function should be specialized. 272 std::function<AssumptionCache &(Function &)> GetAC; 273 std::function<TargetTransformInfo &(Function &)> GetTTI; 274 std::function<TargetLibraryInfo &(Function &)> GetTLI; 275 276 SmallPtrSet<Function *, 4> SpecializedFuncs; 277 SmallPtrSet<Function *, 4> FullySpecialized; 278 SmallVector<Instruction *> ReplacedWithConstant; 279 280 public: 281 FunctionSpecializer(SCCPSolver &Solver, 282 std::function<AssumptionCache &(Function &)> GetAC, 283 std::function<TargetTransformInfo &(Function &)> GetTTI, 284 std::function<TargetLibraryInfo &(Function &)> GetTLI) 285 : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} 286 287 ~FunctionSpecializer() { 288 // Eliminate dead code. 289 removeDeadInstructions(); 290 removeDeadFunctions(); 291 } 292 293 /// Attempt to specialize functions in the module to enable constant 294 /// propagation across function boundaries. 295 /// 296 /// \returns true if at least one function is specialized. 
  bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) {
    bool Changed = false;
    for (auto *F : Candidates) {
      if (!isCandidateFunction(F))
        continue;

      auto Cost = getSpecializationCost(F);
      if (!Cost.isValid()) {
        LLVM_DEBUG(
            dbgs() << "FnSpecialization: Invalid specialisation cost.\n");
        continue;
      }

      LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
                        << F->getName() << " is " << Cost << "\n");

      SpecializationList Specializations;
      calculateGains(F, Cost, Specializations);
      if (Specializations.empty()) {
        LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
        continue;
      }

      for (SpecializationInfo &S : Specializations) {
        specializeFunction(F, S, WorkList);
        Changed = true;
      }
    }

    updateSpecializedFuncs(Candidates, WorkList);
    NumFuncSpecialized += NbFunctionsSpecialized;
    return Changed;
  }

  void removeDeadInstructions() {
    for (auto *I : ReplacedWithConstant) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction " << *I
                        << "\n");
      I->eraseFromParent();
    }
    ReplacedWithConstant.clear();
  }

  void removeDeadFunctions() {
    for (auto *F : FullySpecialized) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function "
                        << F->getName() << "\n");
      F->eraseFromParent();
    }
    FullySpecialized.clear();
  }

  bool tryToReplaceWithConstant(Value *V) {
    if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
        V->user_empty())
      return false;

    const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
    if (isOverdefined(IV))
      return false;
    auto *Const =
        isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());

    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V
                      << "\nFnSpecialization: with " << *Const << "\n");

    // Record uses of V to avoid visiting irrelevant uses of Const later.
    SmallVector<Instruction *> UseInsts;
    for (auto *U : V->users())
      if (auto *I = dyn_cast<Instruction>(U))
        if (Solver.isBlockExecutable(I->getParent()))
          UseInsts.push_back(I);

    V->replaceAllUsesWith(Const);

    for (auto *I : UseInsts)
      Solver.visit(I);

    // Remove the instruction from the block and the solver.
    if (auto *I = dyn_cast<Instruction>(V)) {
      if (I->isSafeToRemove()) {
        ReplacedWithConstant.push_back(I);
        Solver.removeLatticeValueFor(I);
      }
    }
    return true;
  }

private:
  // The number of functions specialised, used for collecting statistics and
  // also in the cost model.
  unsigned NbFunctionsSpecialized = 0;

  /// Clone the function \p F and remove the ssa_copy intrinsics added by
  /// the SCCPSolver in the cloned version.
  Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) {
    Function *Clone = CloneFunction(F, Mappings);
    removeSSACopy(*Clone);
    return Clone;
  }

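  // For intuition about the gain computation in calculateGains() below, here
  // is a made-up example (the numbers are purely illustrative): for a function
  // of 200 instructions, with no specializations performed yet (Penalty = 1),
  // getSpecializationCost() returns
  //
  //   Cost = NumInsts * InstrCost * Penalty = 200 * InstrCost * 1
  //
  // and a candidate constant C for formal argument A is only kept if
  //
  //   Gain = getSpecializationBonus(A, C) - Cost > 0,
  //
  // i.e. if the estimated benefit of knowing A == C (constant-folded users,
  // scaled up inside loops, plus the inlining bonus for promoted indirect
  // calls) outweighs the cost of duplicating the function.
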
  /// This function decides whether it's worthwhile to specialize function \p F
  /// based on the known constant values its arguments can take on. That is, it
  /// calculates a gain and returns a list of actual arguments that are deemed
  /// profitable to specialize. Specialization is performed on the first
  /// interesting argument. Specializations based on additional arguments will
  /// be evaluated on following iterations of the main IPSCCP solve loop.
  void calculateGains(Function *F, InstructionCost Cost,
                      SpecializationList &WorkList) {
    // Determine if we should specialize the function based on the values the
    // argument can take on. If specialization is not profitable, we continue
    // on to the next argument.
    for (Argument &FormalArg : F->args()) {
      // Determine if this argument is interesting. If we know the argument can
      // take on any constant values, they are collected in ActualArgs.
      ConstList ActualArgs;
      if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
        LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
                          << FormalArg.getNameOrAsOperand()
                          << " is not interesting\n");
        continue;
      }

      for (auto *ActualArg : ActualArgs) {
        InstructionCost Gain =
            ForceFunctionSpecialization
                ? 1
                : getSpecializationBonus(&FormalArg, ActualArg) - Cost;

        if (Gain <= 0)
          continue;
        WorkList.push_back({&FormalArg, ActualArg, Gain});
      }

      if (WorkList.empty())
        continue;

      // Sort the candidates in descending order.
      llvm::stable_sort(WorkList, [](const SpecializationInfo &L,
                                     const SpecializationInfo &R) {
        return L.Gain > R.Gain;
      });

      // Truncate the worklist to 'MaxClonesThreshold' candidates if
      // necessary.
      if (WorkList.size() > MaxClonesThreshold) {
        LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceeds "
                          << "the maximum number of clones threshold.\n"
                          << "FnSpecialization: Truncating worklist to "
                          << MaxClonesThreshold << " candidates.\n");
        WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
      }

      LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
                        << F->getName() << "\n";
                 for (SpecializationInfo &S : WorkList) {
                   dbgs() << "FnSpecialization: FormalArg = "
                          << S.Arg.Formal->getNameOrAsOperand()
                          << ", ActualArg = "
                          << S.Arg.Actual->getNameOrAsOperand()
                          << ", Gain = " << S.Gain << "\n";
                 });

      // FIXME: Only one argument per function.
      break;
    }
  }

  bool isCandidateFunction(Function *F) {
    // Do not specialize the cloned function again.
    if (SpecializedFuncs.contains(F))
      return false;

    // If we're optimizing the function for size, we shouldn't specialize it.
    if (F->hasOptSize() ||
        shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
      return false;

    // Exit if the function is not executable. There's no point in specializing
    // a dead function.
    if (!Solver.isBlockExecutable(&F->getEntryBlock()))
      return false;

    // There is no point in specializing a function that will eventually be
    // inlined anyway.
    if (F->hasFnAttribute(Attribute::AlwaysInline))
      return false;

    LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
                      << "\n");
    return true;
  }

  void specializeFunction(Function *F, SpecializationInfo &S,
                          FuncList &WorkList) {
    ValueToValueMapTy Mappings;
    Function *Clone = cloneCandidateFunction(F, Mappings);

    // Rewrite calls to the function so that they call the clone instead.
    rewriteCallSites(Clone, S.Arg, Mappings);

    // Initialize the lattice state of the arguments of the function clone,
    // marking the argument on which we specialized the function constant
    // with the given value.
    Solver.markArgInFuncSpecialization(Clone, S.Arg);

    // Add the clone to the worklist of specialized functions.
    WorkList.push_back(Clone);
    NbFunctionsSpecialized++;

    // If the function has been completely specialized, the original function
    // is no longer needed. Mark it unreachable.
    if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
          if (auto *CS = dyn_cast<CallBase>(U))
            return CS->getFunction() == F;
          return false;
        })) {
      Solver.markFunctionUnreachable(F);
      FullySpecialized.insert(F);
    }
  }

  /// Compute and return the cost of specializing function \p F.
  InstructionCost getSpecializationCost(Function *F) {
    // Compute the code metrics for the function.
    SmallPtrSet<const Value *, 32> EphValues;
    CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
    CodeMetrics Metrics;
    for (BasicBlock &BB : *F)
      Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);

    // If the code metrics reveal that we shouldn't duplicate the function, we
    // shouldn't specialize it. Set the specialization cost to Invalid.
    // Likewise, if the instruction count suggests that the function is small
    // enough to get inlined anyway, we shouldn't specialize it either.
    if (Metrics.notDuplicatable ||
        (!ForceFunctionSpecialization &&
         Metrics.NumInsts < SmallFunctionThreshold)) {
      InstructionCost C{};
      C.setInvalid();
      return C;
    }

    // Otherwise, set the specialization cost to be the cost of all the
    // instructions in the function, multiplied by a penalty that grows with
    // the number of functions specialized so far.
    unsigned Penalty = NbFunctionsSpecialized + 1;
    return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
  }

  InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
                               LoopInfo &LI) {
    auto *I = dyn_cast_or_null<Instruction>(U);
    // If the user is not an instruction, we do not know how to evaluate it.
    // Keep the minimum possible cost for now so that it doesn't affect
    // specialization.
    if (!I)
      return std::numeric_limits<unsigned>::min();

    auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);

    // Traverse recursively if there are more uses.
    // TODO: Any other instructions to be added here?
    if (I->mayReadFromMemory() || I->isCast())
      for (auto *User : I->users())
        Cost += getUserBonus(User, TTI, LI);

    // Increase the bonus if the user is inside a loop, scaling it by the
    // average iteration count for each level of loop nesting.
    auto LoopDepth = LI.getLoopDepth(I->getParent());
    Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
    return Cost;
  }

  /// Compute a bonus for replacing argument \p A with constant \p C.
  InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
    Function *F = A->getParent();
    DominatorTree DT(*F);
    LoopInfo LI(DT);
    auto &TTI = (GetTTI)(*F);
    LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                      << C->getNameOrAsOperand() << "\n");

    InstructionCost TotalCost = 0;
    for (auto *U : A->users()) {
      TotalCost += getUserBonus(U, TTI, LI);
      LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
                 TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
    }

    // The below heuristic is only concerned with exposing inlining
    // opportunities via indirect call promotion. If the argument is not a
    // function pointer, give up.
587 if (!isa<PointerType>(A->getType()) || 588 !isa<FunctionType>(A->getType()->getPointerElementType())) 589 return TotalCost; 590 591 // Since the argument is a function pointer, its incoming constant values 592 // should be functions or constant expressions. The code below attempts to 593 // look through cast expressions to find the function that will be called. 594 Value *CalledValue = C; 595 while (isa<ConstantExpr>(CalledValue) && 596 cast<ConstantExpr>(CalledValue)->isCast()) 597 CalledValue = cast<User>(CalledValue)->getOperand(0); 598 Function *CalledFunction = dyn_cast<Function>(CalledValue); 599 if (!CalledFunction) 600 return TotalCost; 601 602 // Get TTI for the called function (used for the inline cost). 603 auto &CalleeTTI = (GetTTI)(*CalledFunction); 604 605 // Look at all the call sites whose called value is the argument. 606 // Specializing the function on the argument would allow these indirect 607 // calls to be promoted to direct calls. If the indirect call promotion 608 // would likely enable the called function to be inlined, specializing is a 609 // good idea. 610 int Bonus = 0; 611 for (User *U : A->users()) { 612 if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) 613 continue; 614 auto *CS = cast<CallBase>(U); 615 if (CS->getCalledOperand() != A) 616 continue; 617 618 // Get the cost of inlining the called function at this call site. Note 619 // that this is only an estimate. The called function may eventually 620 // change in a way that leads to it not being inlined here, even though 621 // inlining looks profitable now. For example, one of its called 622 // functions may be inlined into it, making the called function too large 623 // to be inlined into this call site. 624 // 625 // We apply a boost for performing indirect call promotion by increasing 626 // the default threshold by the threshold for indirect calls. 627 auto Params = getInlineParams(); 628 Params.DefaultThreshold += InlineConstants::IndirectCallThreshold; 629 InlineCost IC = 630 getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI); 631 632 // We clamp the bonus for this call to be between zero and the default 633 // threshold. 634 if (IC.isAlways()) 635 Bonus += Params.DefaultThreshold; 636 else if (IC.isVariable() && IC.getCostDelta() > 0) 637 Bonus += IC.getCostDelta(); 638 639 LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus 640 << " for user " << *U << "\n"); 641 } 642 643 return TotalCost + Bonus; 644 } 645 646 /// Determine if we should specialize a function based on the incoming values 647 /// of the given argument. 648 /// 649 /// This function implements the goal-directed heuristic. It determines if 650 /// specializing the function based on the incoming values of argument \p A 651 /// would result in any significant optimization opportunities. If 652 /// optimization opportunities exist, the constant values of \p A on which to 653 /// specialize the function are collected in \p Constants. 654 /// 655 /// \returns true if the function should be specialized on the given 656 /// argument. 657 bool isArgumentInteresting(Argument *A, ConstList &Constants) { 658 // For now, don't attempt to specialize functions based on the values of 659 // composite types. 660 if (!A->getType()->isSingleValueType() || A->user_empty()) 661 return false; 662 663 // If the argument isn't overdefined, there's nothing to do. It should 664 // already be constant. 
    if (!Solver.getLatticeValueFor(A).isOverdefined()) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument "
                        << A->getNameOrAsOperand()
                        << " is already constant?\n");
      return false;
    }

    // Collect the constant values that the argument can take on. If the
    // argument can't take on any constant values, we aren't going to
    // specialize the function. While it's possible to specialize the function
    // based on non-constant arguments, there's likely not much benefit to
    // constant propagation in doing so.
    //
    // TODO 1: currently it won't specialize if there are over the threshold of
    // calls using the same argument, e.g. foo(a) x 4 and foo(b) x 1, but it
    // might be beneficial to take the occurrences into account in the cost
    // model, so we would need to find the unique constants.
    //
    // TODO 2: this currently does not support constant ranges, i.e. integer
    // ranges.
    //
    getPossibleConstants(A, Constants);

    if (Constants.empty())
      return false;

    LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
                      << A->getNameOrAsOperand() << "\n");
    return true;
  }

  /// Collect in \p Constants all the constant values that argument \p A can
  /// take on.
  void getPossibleConstants(Argument *A, ConstList &Constants) {
    Function *F = A->getParent();

    // Iterate over all the call sites of the argument's parent function.
    for (User *U : F->users()) {
      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
        continue;
      auto &CS = *cast<CallBase>(U);
      // If the call site has the minsize attribute set, it won't be
      // specialized.
      if (CS.hasFnAttr(Attribute::MinSize))
        continue;

      // If the parent of the call site will never be executed, we don't need
      // to worry about the passed value.
      if (!Solver.isBlockExecutable(CS.getParent()))
        continue;

      auto *V = CS.getArgOperand(A->getArgNo());
      if (isa<PoisonValue>(V))
        return;

      // For now, constant expressions are fine, but only if their operand is
      // a function (e.g. a cast of a function pointer).
      if (auto *CE = dyn_cast<ConstantExpr>(V))
        if (!isa<Function>(CE->getOperand(0)))
          return;

      // TrackValueOfGlobalVariable only tracks scalar global variables.
      if (auto *GV = dyn_cast<GlobalVariable>(V)) {
        // Check if we want to specialize on the address of non-constant
        // global values.
        if (!GV->isConstant())
          if (!SpecializeOnAddresses)
            return;

        if (!GV->getValueType()->isSingleValueType())
          return;
      }

      if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
                               EnableSpecializationForLiteralConstant))
        Constants.push_back(cast<Constant>(V));
    }
  }

  /// Rewrite calls to function \p F to call function \p Clone instead.
  ///
  /// This function modifies calls to function \p F as long as the actual
  /// argument matches the one in \p Arg. Note that for recursive calls we
  /// need to compare against the cloned formal argument.
  ///
  /// Call sites that have been marked with the MinSize function attribute
  /// won't be specialized and rewritten.
  void rewriteCallSites(Function *Clone, const ArgInfo &Arg,
                        ValueToValueMapTy &Mappings) {
    Function *F = Arg.Formal->getParent();
    unsigned ArgNo = Arg.Formal->getArgNo();
    SmallVector<CallBase *, 4> CallSitesToRewrite;
    for (auto *U : F->users()) {
      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
        continue;
      auto &CS = *cast<CallBase>(U);
      if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
        continue;
      CallSitesToRewrite.push_back(&CS);
    }

    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of "
                      << F->getName() << " with " << Clone->getName() << "\n");

    for (auto *CS : CallSitesToRewrite) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: "
                        << CS->getFunction()->getName() << " ->" << *CS
                        << "\n");
      if (/* recursive call */
          (CS->getFunction() == Clone &&
           CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]) ||
          /* normal call */
          CS->getArgOperand(ArgNo) == Arg.Actual) {
        CS->setCalledFunction(Clone);
        Solver.markOverdefined(CS);
      }
    }
  }

  void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) {
    for (auto *F : WorkList) {
      SpecializedFuncs.insert(F);

      // Initialize the state of the newly created functions, marking them
      // argument-tracked and executable.
      if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked))
        Solver.addTrackedFunction(F);

      Solver.addArgumentTrackedFunction(F);
      Candidates.push_back(F);
      Solver.markBlockExecutable(&F->front());

      // Replace the function arguments for the specialized functions.
      for (Argument &Arg : F->args())
        if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
          LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
                            << Arg.getNameOrAsOperand() << "\n");
    }
  }
};
} // namespace

bool llvm::runFunctionSpecialization(
    Module &M, const DataLayout &DL,
    std::function<TargetLibraryInfo &(Function &)> GetTLI,
    std::function<TargetTransformInfo &(Function &)> GetTTI,
    std::function<AssumptionCache &(Function &)> GetAC,
    function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
  SCCPSolver Solver(DL, GetTLI, M.getContext());
  FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
  bool Changed = false;

  // Loop over all functions, marking arguments to those with their addresses
  // taken or that are external as overdefined.
  for (Function &F : M) {
    if (F.isDeclaration())
      continue;
    if (F.hasFnAttribute(Attribute::NoDuplicate))
      continue;

    LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
                      << "\n");
    Solver.addAnalysis(F, GetAnalysis(F));

    // Determine if we can track the function's arguments. If so, add the
    // function to the solver's set of argument-tracked functions.
    if (canTrackArgumentsInterprocedurally(&F)) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
      Solver.addArgumentTrackedFunction(&F);
      continue;
    } else {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
                        << "FnSpecialization: Doesn't have local linkage, or "
                        << "has its address taken\n");
    }

    // Assume the function is called.
    Solver.markBlockExecutable(&F.front());

    // Assume nothing about the incoming arguments.
    for (Argument &AI : F.args())
      Solver.markOverdefined(&AI);
  }

  // Determine if we can track any of the module's global variables. If so, add
  // the global variables we can track to the solver's set of tracked global
  // variables.
  for (GlobalVariable &G : M.globals()) {
    G.removeDeadConstantUsers();
    if (canTrackGlobalVariableInterprocedurally(&G))
      Solver.trackValueOfGlobalVariable(&G);
  }

  auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
  SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
                                        TrackedFuncs.end());

  // No tracked functions, so nothing to do: don't run the solver and just
  // remove any ssa_copy intrinsics that may have been introduced.
  if (TrackedFuncs.empty()) {
    removeSSACopy(M);
    return false;
  }

  // Solve for constants.
  auto RunSCCPSolver = [&](auto &WorkList) {
    bool ResolvedUndefs = true;

    while (ResolvedUndefs) {
      // That the solver is not run unnecessarily is checked in the regression
      // test nothing-to-do.ll, so if this debug message is changed, that test
      // needs updating too.
      LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");

      Solver.solve();
      LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
      ResolvedUndefs = false;
      for (Function *F : WorkList)
        if (Solver.resolvedUndefsIn(*F))
          ResolvedUndefs = true;
    }

    for (auto *F : WorkList) {
      for (BasicBlock &BB : *F) {
        if (!Solver.isBlockExecutable(&BB))
          continue;
        // FIXME: The solver may make changes to the function here, so set
        // Changed, even if later function specialization does not trigger.
        for (auto &I : make_early_inc_range(BB))
          Changed |= FS.tryToReplaceWithConstant(&I);
      }
    }
  };

#ifndef NDEBUG
  LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
  for (auto *F : FuncDecls)
    LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
#endif

  // Initially resolve the constants in all the argument-tracked functions.
  RunSCCPSolver(FuncDecls);

  SmallVector<Function *, 2> WorkList;
  unsigned I = 0;
  while (FuncSpecializationMaxIters != I++ &&
         FS.specializeFunctions(FuncDecls, WorkList)) {
    LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n");

    // Run the solver for the specialized functions.
    RunSCCPSolver(WorkList);

    // Replace some unresolved constant arguments.
    constantArgPropagation(FuncDecls, M, Solver);

    WorkList.clear();
    Changed = true;
  }

  LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = "
                    << NumFuncSpecialized << "\n");

  // Remove any ssa_copy intrinsics that may have been introduced.
  removeSSACopy(M);
  return Changed;
}