1 //===- FunctionSpecialization.cpp - Function Specialization ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This specialises functions with constant parameters. Constant parameters
10 // like function pointers and constant globals are propagated to the callee by
11 // specializing the function. The main benefit of this pass at the moment is
12 // that indirect calls are transformed into direct calls, which provides inline
13 // opportunities that the inliner would not have been able to achieve. That's
14 // why function specialisation is run before the inliner in the optimisation
15 // pipeline; that is by design. Otherwise, we would only benefit from constant
16 // passing, which is a valid use-case too, but hasn't been explored much in
17 // terms of performance uplifts, cost-model and compile-time impact.
18 //
19 // Current limitations:
20 // - It does not yet handle integer ranges. We do support "literal constants",
21 //   but that's off by default under an option.
22 // - Only 1 argument per function is specialised,
23 // - The cost-model could be further looked into (it mainly focuses on inlining
24 //   benefits),
25 // - We are not yet caching analysis results, but profiling and checking where
26 //   extra compile time is spent didn't suggest this to be a problem.
27 //
28 // Ideas:
29 // - With a function specialization attribute for arguments, we could have
30 //   a direct way to steer function specialization, avoiding the cost-model,
31 //   and thus control compile-times / code-size.
32 //
33 // Todos:
34 // - Specializing recursive functions relies on running the transformation a
35 //   number of times, which is controlled by option
//   `func-specialization-max-iters`. Thus, increasing this value (and hence the
//   number of iterations) linearly increases the number of times recursive
//   functions get specialized; see also the discussion in
39 //   https://reviews.llvm.org/D106426 for details. Perhaps there is a
40 //   compile-time friendlier way to control/limit the number of specialisations
41 //   for recursive functions.
42 // - Don't transform the function if function specialization does not trigger;
43 //   the SCCPSolver may make IR changes.
44 //
45 // References:
46 // - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
47 //   it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
48 //
49 //===----------------------------------------------------------------------===//
50 
51 #include "llvm/ADT/Statistic.h"
52 #include "llvm/Analysis/CodeMetrics.h"
53 #include "llvm/Analysis/InlineCost.h"
54 #include "llvm/Analysis/LoopInfo.h"
55 #include "llvm/Analysis/TargetTransformInfo.h"
56 #include "llvm/Analysis/ValueLattice.h"
57 #include "llvm/Analysis/ValueLatticeUtils.h"
58 #include "llvm/IR/IntrinsicInst.h"
59 #include "llvm/Transforms/Scalar/SCCP.h"
60 #include "llvm/Transforms/Utils/Cloning.h"
61 #include "llvm/Transforms/Utils/SCCPSolver.h"
62 #include "llvm/Transforms/Utils/SizeOpts.h"
63 #include <cmath>
64 
65 using namespace llvm;
66 
67 #define DEBUG_TYPE "function-specialization"
68 
69 STATISTIC(NumFuncSpecialized, "Number of functions specialized");
70 
71 static cl::opt<bool> ForceFunctionSpecialization(
72     "force-function-specialization", cl::init(false), cl::Hidden,
73     cl::desc("Force function specialization for every call site with a "
74              "constant argument"));
75 
76 static cl::opt<unsigned> FuncSpecializationMaxIters(
77     "func-specialization-max-iters", cl::Hidden,
78     cl::desc("The maximum number of iterations function specialization is run"),
79     cl::init(1));
80 
81 static cl::opt<unsigned> MaxClonesThreshold(
82     "func-specialization-max-clones", cl::Hidden,
83     cl::desc("The maximum number of clones allowed for a single function "
84              "specialization"),
85     cl::init(3));
86 
87 static cl::opt<unsigned> SmallFunctionThreshold(
88     "func-specialization-size-threshold", cl::Hidden,
89     cl::desc("Don't specialize functions that have less than this theshold "
90              "number of instructions"),
91     cl::init(100));
92 
93 static cl::opt<unsigned>
94     AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
95                           cl::desc("Average loop iteration count cost"),
96                           cl::init(10));
97 
98 static cl::opt<bool> SpecializeOnAddresses(
99     "func-specialization-on-address", cl::init(false), cl::Hidden,
100     cl::desc("Enable function specialization on the address of global values"));
101 
102 // TODO: This needs checking to see the impact on compile-times, which is why
103 // this is off by default for now.
104 static cl::opt<bool> EnableSpecializationForLiteralConstant(
105     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
106     cl::desc("Enable specialization of functions that take a literal constant "
107              "as an argument."));
108 
109 namespace {
110 // Bookkeeping struct to pass data from the analysis and profitability phase
111 // to the actual transform helper functions.
112 struct SpecializationInfo {
113   ArgInfo Arg;          // Stores the {formal,actual} argument pair.
114   InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
115 
116   SpecializationInfo(Argument *A, Constant *C, InstructionCost G)
117       : Arg(A, C), Gain(G){};
118 };
119 } // Anonymous namespace
120 
121 using FuncList = SmallVectorImpl<Function *>;
122 using ConstList = SmallVector<Constant *>;
123 using SpecializationList = SmallVector<SpecializationInfo>;
124 
125 // Helper to check if \p LV is either a constant or a constant
126 // range with a single element. This should cover exactly the same cases as the
127 // old ValueLatticeElement::isConstant() and is intended to be used in the
128 // transition to ValueLatticeElement.
129 static bool isConstant(const ValueLatticeElement &LV) {
130   return LV.isConstant() ||
131          (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
132 }
133 
134 // Helper to check if \p LV is either overdefined or a constant int.
135 static bool isOverdefined(const ValueLatticeElement &LV) {
136   return !LV.isUnknownOrUndef() && !isConstant(LV);
137 }
138 
139 static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
140   Value *StoreValue = nullptr;
141   for (auto *User : Alloca->users()) {
142     // We can't use llvm::isAllocaPromotable() as that would fail because of
143     // the usage in the CallInst, which is what we check here.
144     if (User == Call)
145       continue;
146     if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
147       if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
148         return nullptr;
149       continue;
150     }
151 
152     if (auto *Store = dyn_cast<StoreInst>(User)) {
153       // This is a duplicate store, bail out.
154       if (StoreValue || Store->isVolatile())
155         return nullptr;
156       StoreValue = Store->getValueOperand();
157       continue;
158     }
159     // Bail if there is any other unknown usage.
160     return nullptr;
161   }
162   return dyn_cast_or_null<Constant>(StoreValue);
163 }
164 
165 // A constant stack value is an AllocaInst that has a single constant
166 // value stored to it. Return this constant if such an alloca stack value
167 // is a function argument.
168 static Constant *getConstantStackValue(CallInst *Call, Value *Val,
169                                        SCCPSolver &Solver) {
170   if (!Val)
171     return nullptr;
172   Val = Val->stripPointerCasts();
173   if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
174     return ConstVal;
175   auto *Alloca = dyn_cast<AllocaInst>(Val);
176   if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
177     return nullptr;
178   return getPromotableAlloca(Alloca, Call);
179 }
180 
181 // To support specializing recursive functions, it is important to propagate
182 // constant arguments because after a first iteration of specialisation, a
183 // reduced example may look like this:
184 //
185 //     define internal void @RecursiveFn(i32* arg1) {
186 //       %temp = alloca i32, align 4
187 //       store i32 2 i32* %temp, align 4
188 //       call void @RecursiveFn.1(i32* nonnull %temp)
189 //       ret void
190 //     }
191 //
192 // Before a next iteration, we need to propagate the constant like so
193 // which allows further specialization in next iterations.
194 //
195 //     @funcspec.arg = internal constant i32 2
196 //
197 //     define internal void @someFunc(i32* arg1) {
198 //       call void @otherFunc(i32* nonnull @funcspec.arg)
199 //       ret void
200 //     }
201 //
202 static void constantArgPropagation(FuncList &WorkList,
203                                    Module &M, SCCPSolver &Solver) {
204   // Iterate over the argument tracked functions see if there
205   // are any new constant values for the call instruction via
206   // stack variables.
207   for (auto *F : WorkList) {
208     // TODO: Generalize for any read only arguments.
209     if (F->arg_size() != 1)
210       continue;
211 
212     auto &Arg = *F->arg_begin();
213     if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
214       continue;
215 
216     for (auto *User : F->users()) {
217       auto *Call = dyn_cast<CallInst>(User);
218       if (!Call)
219         break;
220       auto *ArgOp = Call->getArgOperand(0);
221       auto *ArgOpType = ArgOp->getType();
222       auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
223       if (!ConstVal)
224         break;
225 
226       Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
227                                      GlobalValue::InternalLinkage, ConstVal,
228                                      "funcspec.arg");
229 
230       if (ArgOpType != ConstVal->getType())
231         GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
232 
233       Call->setArgOperand(0, GV);
234 
235       // Add the changed CallInst to Solver Worklist
236       Solver.visitCall(*Call);
237     }
238   }
239 }
240 
241 // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
242 // interfere with the constantArgPropagation optimization.
243 static void removeSSACopy(Function &F) {
244   for (BasicBlock &BB : F) {
245     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
246       auto *II = dyn_cast<IntrinsicInst>(&Inst);
247       if (!II)
248         continue;
249       if (II->getIntrinsicID() != Intrinsic::ssa_copy)
250         continue;
251       Inst.replaceAllUsesWith(II->getOperand(0));
252       Inst.eraseFromParent();
253     }
254   }
255 }
256 
257 static void removeSSACopy(Module &M) {
258   for (Function &F : M)
259     removeSSACopy(F);
260 }
261 
262 namespace {
263 class FunctionSpecializer {
264 
265   /// The IPSCCP Solver.
266   SCCPSolver &Solver;
267 
268   /// Analyses used to help determine if a function should be specialized.
269   std::function<AssumptionCache &(Function &)> GetAC;
270   std::function<TargetTransformInfo &(Function &)> GetTTI;
271   std::function<TargetLibraryInfo &(Function &)> GetTLI;
272 
273   SmallPtrSet<Function *, 4> SpecializedFuncs;
274   SmallPtrSet<Function *, 4> FullySpecialized;
275   SmallVector<Instruction *> ReplacedWithConstant;
276 
277 public:
278   FunctionSpecializer(SCCPSolver &Solver,
279                       std::function<AssumptionCache &(Function &)> GetAC,
280                       std::function<TargetTransformInfo &(Function &)> GetTTI,
281                       std::function<TargetLibraryInfo &(Function &)> GetTLI)
282       : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {}
283 
284   ~FunctionSpecializer() {
285     // Eliminate dead code.
286     removeDeadInstructions();
287     removeDeadFunctions();
288   }
289 
290   /// Attempt to specialize functions in the module to enable constant
291   /// propagation across function boundaries.
292   ///
293   /// \returns true if at least one function is specialized.
294   bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) {
295     bool Changed = false;
296     for (auto *F : Candidates) {
297       if (!isCandidateFunction(F))
298         continue;
299 
300       auto Cost = getSpecializationCost(F);
301       if (!Cost.isValid()) {
302         LLVM_DEBUG(
303             dbgs() << "FnSpecialization: Invalid specialisation cost.\n");
304         continue;
305       }
306 
307       LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
308                         << F->getName() << " is " << Cost << "\n");
309 
310       SpecializationList Specializations;
311       calculateGains(F, Cost, Specializations);
312       if (Specializations.empty()) {
313         LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
314         continue;
315       }
316 
317       for (SpecializationInfo &S : Specializations) {
318         specializeFunction(F, S, WorkList);
319         Changed = true;
320       }
321     }
322 
323     updateSpecializedFuncs(Candidates, WorkList);
324     NumFuncSpecialized += NbFunctionsSpecialized;
325     return Changed;
326   }
327 
328   void removeDeadInstructions() {
329     for (auto *I : ReplacedWithConstant) {
330       LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction "
331                         << *I << "\n");
332       I->eraseFromParent();
333     }
334     ReplacedWithConstant.clear();
335   }
336 
337   void removeDeadFunctions() {
338     for (auto *F : FullySpecialized) {
339       LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function "
340                         << F->getName() << "\n");
341       F->eraseFromParent();
342     }
343     FullySpecialized.clear();
344   }
345 
346   bool tryToReplaceWithConstant(Value *V) {
347     if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
348         V->user_empty())
349       return false;
350 
351     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
352     if (isOverdefined(IV))
353       return false;
354     auto *Const =
355         isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
356 
357     LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V
358                       << "\nFnSpecialization: with " << *Const << "\n");
359 
360     // Record uses of V to avoid visiting irrelevant uses of const later.
361     SmallVector<Instruction *> UseInsts;
362     for (auto *U : V->users())
363       if (auto *I = dyn_cast<Instruction>(U))
364         if (Solver.isBlockExecutable(I->getParent()))
365           UseInsts.push_back(I);
366 
367     V->replaceAllUsesWith(Const);
368 
369     for (auto *I : UseInsts)
370       Solver.visit(I);
371 
372     // Remove the instruction from Block and Solver.
373     if (auto *I = dyn_cast<Instruction>(V)) {
374       if (I->isSafeToRemove()) {
375         ReplacedWithConstant.push_back(I);
376         Solver.removeLatticeValueFor(I);
377       }
378     }
379     return true;
380   }
381 
382 private:
383   // The number of functions specialised, used for collecting statistics and
384   // also in the cost model.
385   unsigned NbFunctionsSpecialized = 0;
386 
387   /// Clone the function \p F and remove the ssa_copy intrinsics added by
388   /// the SCCPSolver in the cloned version.
389   Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) {
390     Function *Clone = CloneFunction(F, Mappings);
391     removeSSACopy(*Clone);
392     return Clone;
393   }
394 
395   /// This function decides whether it's worthwhile to specialize function \p F
396   /// based on the known constant values its arguments can take on, i.e. it
397   /// calculates a gain and returns a list of actual arguments that are deemed
398   /// profitable to specialize. Specialization is performed on the first
399   /// interesting argument. Specializations based on additional arguments will
400   /// be evaluated on following iterations of the main IPSCCP solve loop.
401   void calculateGains(Function *F, InstructionCost Cost,
402                       SpecializationList &WorkList) {
403     // Determine if we should specialize the function based on the values the
404     // argument can take on. If specialization is not profitable, we continue
405     // on to the next argument.
406     for (Argument &FormalArg : F->args()) {
407       // Determine if this argument is interesting. If we know the argument can
408       // take on any constant values, they are collected in Constants.
409       ConstList ActualArgs;
410       if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
411         LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
412                           << FormalArg.getNameOrAsOperand()
413                           << " is not interesting\n");
414         continue;
415       }
416 
417       for (auto *ActualArg : ActualArgs) {
418         InstructionCost Gain =
419             ForceFunctionSpecialization
420                 ? 1
421                 : getSpecializationBonus(&FormalArg, ActualArg) - Cost;
422 
423         if (Gain <= 0)
424           continue;
425         WorkList.push_back({&FormalArg, ActualArg, Gain});
426       }
427 
428       if (WorkList.empty())
429         continue;
430 
431       // Sort the candidates in descending order.
432       llvm::stable_sort(WorkList, [](const SpecializationInfo &L,
433                                      const SpecializationInfo &R) {
434         return L.Gain > R.Gain;
435       });
436 
437       // Truncate the worklist to 'MaxClonesThreshold' candidates if
438       // necessary.
439       if (WorkList.size() > MaxClonesThreshold) {
440         LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
441                           << "the maximum number of clones threshold.\n"
442                           << "FnSpecialization: Truncating worklist to "
443                           << MaxClonesThreshold << " candidates.\n");
444         WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
445       }
446 
447       LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
448                         << F->getName() << "\n";
449                  for (SpecializationInfo &S : WorkList) {
450                    dbgs() << "FnSpecialization:   FormalArg = "
451                           << S.Arg.Formal->getNameOrAsOperand()
452                           << ", ActualArg = "
453                           << S.Arg.Actual->getNameOrAsOperand()
454                           << ", Gain = " << S.Gain << "\n";
455                  });
456 
457       // FIXME: Only one argument per function.
458       break;
459     }
460   }
461 
462   bool isCandidateFunction(Function *F) {
463     // Do not specialize the cloned function again.
464     if (SpecializedFuncs.contains(F))
465       return false;
466 
467     // If we're optimizing the function for size, we shouldn't specialize it.
468     if (F->hasOptSize() ||
469         shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
470       return false;
471 
472     // Exit if the function is not executable. There's no point in specializing
473     // a dead function.
474     if (!Solver.isBlockExecutable(&F->getEntryBlock()))
475       return false;
476 
477     // It wastes time to specialize a function which would get inlined finally.
478     if (F->hasFnAttribute(Attribute::AlwaysInline))
479       return false;
480 
481     LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
482                       << "\n");
483     return true;
484   }
485 
486   void specializeFunction(Function *F, SpecializationInfo &S,
487                           FuncList &WorkList) {
488     ValueToValueMapTy Mappings;
489     Function *Clone = cloneCandidateFunction(F, Mappings);
490 
491     // Rewrite calls to the function so that they call the clone instead.
492     rewriteCallSites(Clone, S.Arg, Mappings);
493 
494     // Initialize the lattice state of the arguments of the function clone,
495     // marking the argument on which we specialized the function constant
496     // with the given value.
497     Solver.markArgInFuncSpecialization(Clone, S.Arg);
498 
499     // Mark all the specialized functions
500     WorkList.push_back(Clone);
501     NbFunctionsSpecialized++;
502 
503     // If the function has been completely specialized, the original function
504     // is no longer needed. Mark it unreachable.
505     if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
506           if (auto *CS = dyn_cast<CallBase>(U))
507             return CS->getFunction() == F;
508           return false;
509         })) {
510       Solver.markFunctionUnreachable(F);
511       FullySpecialized.insert(F);
512     }
513   }
514 
515   /// Compute and return the cost of specializing function \p F.
516   InstructionCost getSpecializationCost(Function *F) {
517     // Compute the code metrics for the function.
518     SmallPtrSet<const Value *, 32> EphValues;
519     CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
520     CodeMetrics Metrics;
521     for (BasicBlock &BB : *F)
522       Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
523 
524     // If the code metrics reveal that we shouldn't duplicate the function, we
525     // shouldn't specialize it. Set the specialization cost to Invalid.
526     // Or if the lines of codes implies that this function is easy to get
527     // inlined so that we shouldn't specialize it.
528     if (Metrics.notDuplicatable ||
529         (!ForceFunctionSpecialization &&
530          Metrics.NumInsts < SmallFunctionThreshold)) {
531       InstructionCost C{};
532       C.setInvalid();
533       return C;
534     }
535 
536     // Otherwise, set the specialization cost to be the cost of all the
537     // instructions in the function and penalty for specializing more functions.
538     unsigned Penalty = NbFunctionsSpecialized + 1;
539     return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
540   }
541 
542   InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
543                                LoopInfo &LI) {
544     auto *I = dyn_cast_or_null<Instruction>(U);
545     // If not an instruction we do not know how to evaluate.
546     // Keep minimum possible cost for now so that it doesnt affect
547     // specialization.
548     if (!I)
549       return std::numeric_limits<unsigned>::min();
550 
551     auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);
552 
553     // Traverse recursively if there are more uses.
554     // TODO: Any other instructions to be added here?
555     if (I->mayReadFromMemory() || I->isCast())
556       for (auto *User : I->users())
557         Cost += getUserBonus(User, TTI, LI);
558 
559     // Increase the cost if it is inside the loop.
560     auto LoopDepth = LI.getLoopDepth(I->getParent());
561     Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
562     return Cost;
563   }
564 
565   /// Compute a bonus for replacing argument \p A with constant \p C.
566   InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
567     Function *F = A->getParent();
568     DominatorTree DT(*F);
569     LoopInfo LI(DT);
570     auto &TTI = (GetTTI)(*F);
571     LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
572                       << C->getNameOrAsOperand() << "\n");
573 
574     InstructionCost TotalCost = 0;
575     for (auto *U : A->users()) {
576       TotalCost += getUserBonus(U, TTI, LI);
577       LLVM_DEBUG(dbgs() << "FnSpecialization:   User cost ";
578                  TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
579     }
580 
581     // The below heuristic is only concerned with exposing inlining
582     // opportunities via indirect call promotion. If the argument is not a
583     // (potentially casted) function pointer, give up.
584     Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts());
585     if (!CalledFunction)
586       return TotalCost;
587 
588     // Get TTI for the called function (used for the inline cost).
589     auto &CalleeTTI = (GetTTI)(*CalledFunction);
590 
591     // Look at all the call sites whose called value is the argument.
592     // Specializing the function on the argument would allow these indirect
593     // calls to be promoted to direct calls. If the indirect call promotion
594     // would likely enable the called function to be inlined, specializing is a
595     // good idea.
596     int Bonus = 0;
597     for (User *U : A->users()) {
598       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
599         continue;
600       auto *CS = cast<CallBase>(U);
601       if (CS->getCalledOperand() != A)
602         continue;
603 
604       // Get the cost of inlining the called function at this call site. Note
605       // that this is only an estimate. The called function may eventually
606       // change in a way that leads to it not being inlined here, even though
607       // inlining looks profitable now. For example, one of its called
608       // functions may be inlined into it, making the called function too large
609       // to be inlined into this call site.
610       //
611       // We apply a boost for performing indirect call promotion by increasing
612       // the default threshold by the threshold for indirect calls.
613       auto Params = getInlineParams();
614       Params.DefaultThreshold += InlineConstants::IndirectCallThreshold;
615       InlineCost IC =
616           getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI);
617 
618       // We clamp the bonus for this call to be between zero and the default
619       // threshold.
620       if (IC.isAlways())
621         Bonus += Params.DefaultThreshold;
622       else if (IC.isVariable() && IC.getCostDelta() > 0)
623         Bonus += IC.getCostDelta();
624 
625       LLVM_DEBUG(dbgs() << "FnSpecialization:   Inlining bonus " << Bonus
626                         << " for user " << *U << "\n");
627     }
628 
629     return TotalCost + Bonus;
630   }
631 
632   /// Determine if we should specialize a function based on the incoming values
633   /// of the given argument.
634   ///
635   /// This function implements the goal-directed heuristic. It determines if
636   /// specializing the function based on the incoming values of argument \p A
637   /// would result in any significant optimization opportunities. If
638   /// optimization opportunities exist, the constant values of \p A on which to
639   /// specialize the function are collected in \p Constants.
640   ///
641   /// \returns true if the function should be specialized on the given
642   /// argument.
643   bool isArgumentInteresting(Argument *A, ConstList &Constants) {
644     // For now, don't attempt to specialize functions based on the values of
645     // composite types.
646     if (!A->getType()->isSingleValueType() || A->user_empty())
647       return false;
648 
649     // If the argument isn't overdefined, there's nothing to do. It should
650     // already be constant.
651     if (!Solver.getLatticeValueFor(A).isOverdefined()) {
652       LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument "
653                         << A->getNameOrAsOperand()
654                         << " is already constant?\n");
655       return false;
656     }
657 
658     // Collect the constant values that the argument can take on. If the
659     // argument can't take on any constant values, we aren't going to
660     // specialize the function. While it's possible to specialize the function
661     // based on non-constant arguments, there's likely not much benefit to
662     // constant propagation in doing so.
663     //
664     // TODO 1: currently it won't specialize if there are over the threshold of
665     // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it
666     // might be beneficial to take the occurrences into account in the cost
667     // model, so we would need to find the unique constants.
668     //
669     // TODO 2: this currently does not support constants, i.e. integer ranges.
670     //
671     getPossibleConstants(A, Constants);
672 
673     if (Constants.empty())
674       return false;
675 
676     LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
677                       << A->getNameOrAsOperand() << "\n");
678     return true;
679   }
680 
681   /// Collect in \p Constants all the constant values that argument \p A can
682   /// take on.
683   void getPossibleConstants(Argument *A, ConstList &Constants) {
684     Function *F = A->getParent();
685 
686     // Iterate over all the call sites of the argument's parent function.
687     for (User *U : F->users()) {
688       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
689         continue;
690       auto &CS = *cast<CallBase>(U);
691       // If the call site has attribute minsize set, that callsite won't be
692       // specialized.
693       if (CS.hasFnAttr(Attribute::MinSize))
694         continue;
695 
696       // If the parent of the call site will never be executed, we don't need
697       // to worry about the passed value.
698       if (!Solver.isBlockExecutable(CS.getParent()))
699         continue;
700 
701       auto *V = CS.getArgOperand(A->getArgNo());
702       if (isa<PoisonValue>(V))
703         return;
704 
705       // For now, constant expressions are fine but only if they are function
706       // calls.
707       if (auto *CE = dyn_cast<ConstantExpr>(V))
708         if (!isa<Function>(CE->getOperand(0)))
709           return;
710 
711       // TrackValueOfGlobalVariable only tracks scalar global variables.
712       if (auto *GV = dyn_cast<GlobalVariable>(V)) {
713         // Check if we want to specialize on the address of non-constant
714         // global values.
715         if (!GV->isConstant())
716           if (!SpecializeOnAddresses)
717             return;
718 
719         if (!GV->getValueType()->isSingleValueType())
720           return;
721       }
722 
723       if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
724                                EnableSpecializationForLiteralConstant))
725         Constants.push_back(cast<Constant>(V));
726     }
727   }
728 
729   /// Rewrite calls to function \p F to call function \p Clone instead.
730   ///
731   /// This function modifies calls to function \p F as long as the actual
732   /// argument matches the one in \p Arg. Note that for recursive calls we
733   /// need to compare against the cloned formal argument.
734   ///
735   /// Callsites that have been marked with the MinSize function attribute won't
736   /// be specialized and rewritten.
737   void rewriteCallSites(Function *Clone, const ArgInfo &Arg,
738                         ValueToValueMapTy &Mappings) {
739     Function *F = Arg.Formal->getParent();
740     unsigned ArgNo = Arg.Formal->getArgNo();
741     SmallVector<CallBase *, 4> CallSitesToRewrite;
742     for (auto *U : F->users()) {
743       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
744         continue;
745       auto &CS = *cast<CallBase>(U);
746       if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
747         continue;
748       CallSitesToRewrite.push_back(&CS);
749     }
750 
751     LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of "
752                       << F->getName() << " with "
753                       << Clone->getName() << "\n");
754 
755     for (auto *CS : CallSitesToRewrite) {
756       LLVM_DEBUG(dbgs() << "FnSpecialization:   "
757                         << CS->getFunction()->getName() << " ->"
758                         << *CS << "\n");
759       if (/* recursive call */
760           (CS->getFunction() == Clone &&
761            CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]) ||
762           /* normal call */
763           CS->getArgOperand(ArgNo) == Arg.Actual) {
764         CS->setCalledFunction(Clone);
765         Solver.markOverdefined(CS);
766       }
767     }
768   }
769 
770   void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) {
771     for (auto *F : WorkList) {
772       SpecializedFuncs.insert(F);
773 
774       // Initialize the state of the newly created functions, marking them
775       // argument-tracked and executable.
776       if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked))
777         Solver.addTrackedFunction(F);
778 
779       Solver.addArgumentTrackedFunction(F);
780       Candidates.push_back(F);
781       Solver.markBlockExecutable(&F->front());
782 
783       // Replace the function arguments for the specialized functions.
784       for (Argument &Arg : F->args())
785         if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
786           LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
787                             << Arg.getNameOrAsOperand() << "\n");
788     }
789   }
790 };
791 } // namespace
792 
793 bool llvm::runFunctionSpecialization(
794     Module &M, const DataLayout &DL,
795     std::function<TargetLibraryInfo &(Function &)> GetTLI,
796     std::function<TargetTransformInfo &(Function &)> GetTTI,
797     std::function<AssumptionCache &(Function &)> GetAC,
798     function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
799   SCCPSolver Solver(DL, GetTLI, M.getContext());
800   FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
801   bool Changed = false;
802 
803   // Loop over all functions, marking arguments to those with their addresses
804   // taken or that are external as overdefined.
805   for (Function &F : M) {
806     if (F.isDeclaration())
807       continue;
808     if (F.hasFnAttribute(Attribute::NoDuplicate))
809       continue;
810 
811     LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
812                       << "\n");
813     Solver.addAnalysis(F, GetAnalysis(F));
814 
815     // Determine if we can track the function's arguments. If so, add the
816     // function to the solver's set of argument-tracked functions.
817     if (canTrackArgumentsInterprocedurally(&F)) {
818       LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
819       Solver.addArgumentTrackedFunction(&F);
820       continue;
821     } else {
822       LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
823                         << "FnSpecialization: Doesn't have local linkage, or "
824                         << "has its address taken\n");
825     }
826 
827     // Assume the function is called.
828     Solver.markBlockExecutable(&F.front());
829 
830     // Assume nothing about the incoming arguments.
831     for (Argument &AI : F.args())
832       Solver.markOverdefined(&AI);
833   }
834 
835   // Determine if we can track any of the module's global variables. If so, add
836   // the global variables we can track to the solver's set of tracked global
837   // variables.
838   for (GlobalVariable &G : M.globals()) {
839     G.removeDeadConstantUsers();
840     if (canTrackGlobalVariableInterprocedurally(&G))
841       Solver.trackValueOfGlobalVariable(&G);
842   }
843 
844   auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
845   SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
846                                         TrackedFuncs.end());
847 
848   // No tracked functions, so nothing to do: don't run the solver and remove
849   // the ssa_copy intrinsics that may have been introduced.
850   if (TrackedFuncs.empty()) {
851     removeSSACopy(M);
852     return false;
853   }
854 
855   // Solve for constants.
856   auto RunSCCPSolver = [&](auto &WorkList) {
857     bool ResolvedUndefs = true;
858 
859     while (ResolvedUndefs) {
860       // Not running the solver unnecessary is checked in regression test
861       // nothing-to-do.ll, so if this debug message is changed, this regression
862       // test needs updating too.
863       LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");
864 
865       Solver.solve();
866       LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
867       ResolvedUndefs = false;
868       for (Function *F : WorkList)
869         if (Solver.resolvedUndefsIn(*F))
870           ResolvedUndefs = true;
871     }
872 
873     for (auto *F : WorkList) {
874       for (BasicBlock &BB : *F) {
875         if (!Solver.isBlockExecutable(&BB))
876           continue;
877         // FIXME: The solver may make changes to the function here, so set
878         // Changed, even if later function specialization does not trigger.
879         for (auto &I : make_early_inc_range(BB))
880           Changed |= FS.tryToReplaceWithConstant(&I);
881       }
882     }
883   };
884 
885 #ifndef NDEBUG
886   LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
887   for (auto *F : FuncDecls)
888     LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
889 #endif
890 
891   // Initially resolve the constants in all the argument tracked functions.
892   RunSCCPSolver(FuncDecls);
893 
894   SmallVector<Function *, 2> WorkList;
895   unsigned I = 0;
896   while (FuncSpecializationMaxIters != I++ &&
897          FS.specializeFunctions(FuncDecls, WorkList)) {
898     LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n");
899 
900     // Run the solver for the specialized functions.
901     RunSCCPSolver(WorkList);
902 
903     // Replace some unresolved constant arguments.
904     constantArgPropagation(FuncDecls, M, Solver);
905 
906     WorkList.clear();
907     Changed = true;
908   }
909 
910   LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = "
911                     << NumFuncSpecialized <<"\n");
912 
913   // Remove any ssa_copy intrinsics that may have been introduced.
914   removeSSACopy(M);
915   return Changed;
916 }
917