1 //===- FunctionSpecialization.cpp - Function Specialization ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This specialises functions with constant parameters (e.g. functions,
10 // globals). Constant parameters like function pointers and constant globals
11 // are propagated to the callee by specializing the function.
12 //
13 // Current limitations:
14 // - It does not yet handle integer ranges.
15 // - Only 1 argument per function is specialised,
16 // - The cost-model could be further looked into,
17 // - We are not yet caching analysis results.
18 //
19 // Ideas:
20 // - With a function specialization attribute for arguments, we could have
21 //   a direct way to steer function specialization, avoiding the cost-model,
22 //   and thus control compile-times / code-size.
23 //
// Todos:
// - Limit the number of times a recursive function gets specialized when
//   `func-specialization-max-iters` increases linearly. See discussion in
//   https://reviews.llvm.org/D106426 for details.
// - Don't transform the function if no function specialization happens.
31 //
32 //===----------------------------------------------------------------------===//
33 
34 #include "llvm/ADT/Statistic.h"
35 #include "llvm/Analysis/AssumptionCache.h"
36 #include "llvm/Analysis/CodeMetrics.h"
37 #include "llvm/Analysis/DomTreeUpdater.h"
38 #include "llvm/Analysis/InlineCost.h"
39 #include "llvm/Analysis/LoopInfo.h"
40 #include "llvm/Analysis/TargetLibraryInfo.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/Transforms/Scalar/SCCP.h"
43 #include "llvm/Transforms/Utils/Cloning.h"
44 #include "llvm/Transforms/Utils/SizeOpts.h"
45 #include <cmath>
46 
47 using namespace llvm;
48 
49 #define DEBUG_TYPE "function-specialization"
50 
51 STATISTIC(NumFuncSpecialized, "Number of functions specialized");
52 
53 static cl::opt<bool> ForceFunctionSpecialization(
54     "force-function-specialization", cl::init(false), cl::Hidden,
55     cl::desc("Force function specialization for every call site with a "
56              "constant argument"));
57 
58 static cl::opt<unsigned> FuncSpecializationMaxIters(
59     "func-specialization-max-iters", cl::Hidden,
60     cl::desc("The maximum number of iterations function specialization is run"),
61     cl::init(1));
62 
63 static cl::opt<unsigned> MaxConstantsThreshold(
64     "func-specialization-max-constants", cl::Hidden,
65     cl::desc("The maximum number of clones allowed for a single function "
66              "specialization"),
67     cl::init(3));
68 
69 static cl::opt<unsigned> SmallFunctionThreshold(
70     "func-specialization-size-threshold", cl::Hidden,
71     cl::desc("Don't specialize functions that have less than this theshold "
72              "number of instructions"),
73     cl::init(100));
74 
75 static cl::opt<unsigned>
76     AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
77                           cl::desc("Average loop iteration count cost"),
78                           cl::init(10));
79 
80 // TODO: This needs checking to see the impact on compile-times, which is why
81 // this is off by default for now.
82 static cl::opt<bool> EnableSpecializationForLiteralConstant(
83     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
84     cl::desc("Enable specialization of functions that take a literal constant "
85              "as an argument."));
86 
87 // Helper to check if \p LV is either a constant or a constant
88 // range with a single element. This should cover exactly the same cases as the
89 // old ValueLatticeElement::isConstant() and is intended to be used in the
90 // transition to ValueLatticeElement.
91 static bool isConstant(const ValueLatticeElement &LV) {
92   return LV.isConstant() ||
93          (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
94 }
95 
96 // Helper to check if \p LV is either overdefined or a constant int.
97 static bool isOverdefined(const ValueLatticeElement &LV) {
98   return !LV.isUnknownOrUndef() && !isConstant(LV);
99 }
100 
101 static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
102   Value *StoreValue = nullptr;
103   for (auto *User : Alloca->users()) {
104     // We can't use llvm::isAllocaPromotable() as that would fail because of
105     // the usage in the CallInst, which is what we check here.
106     if (User == Call)
107       continue;
108     if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
109       if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
110         return nullptr;
111       continue;
112     }
113 
114     if (auto *Store = dyn_cast<StoreInst>(User)) {
115       // This is a duplicate store, bail out.
116       if (StoreValue || Store->isVolatile())
117         return nullptr;
118       StoreValue = Store->getValueOperand();
119       continue;
120     }
121     // Bail if there is any other unknown usage.
122     return nullptr;
123   }
124   return dyn_cast_or_null<Constant>(StoreValue);
125 }
126 
127 // A constant stack value is an AllocaInst that has a single constant
128 // value stored to it. Return this constant if such an alloca stack value
129 // is a function argument.
130 static Constant *getConstantStackValue(CallInst *Call, Value *Val,
131                                        SCCPSolver &Solver) {
132   if (!Val)
133     return nullptr;
134   Val = Val->stripPointerCasts();
135   if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
136     return ConstVal;
137   auto *Alloca = dyn_cast<AllocaInst>(Val);
138   if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
139     return nullptr;
140   return getPromotableAlloca(Alloca, Call);
141 }
142 
143 // To support specializing recursive functions, it is important to propagate
144 // constant arguments because after a first iteration of specialisation, a
145 // reduced example may look like this:
146 //
147 //     define internal void @RecursiveFn(i32* arg1) {
148 //       %temp = alloca i32, align 4
149 //       store i32 2 i32* %temp, align 4
150 //       call void @RecursiveFn.1(i32* nonnull %temp)
151 //       ret void
152 //     }
153 //
154 // Before a next iteration, we need to propagate the constant like so
155 // which allows further specialization in next iterations.
156 //
157 //     @funcspec.arg = internal constant i32 2
158 //
159 //     define internal void @someFunc(i32* arg1) {
160 //       call void @otherFunc(i32* nonnull @funcspec.arg)
161 //       ret void
162 //     }
163 //
164 static void constantArgPropagation(SmallVectorImpl<Function *> &WorkList,
165                                    Module &M, SCCPSolver &Solver) {
166   // Iterate over the argument tracked functions see if there
167   // are any new constant values for the call instruction via
168   // stack variables.
169   for (auto *F : WorkList) {
170     // TODO: Generalize for any read only arguments.
171     if (F->arg_size() != 1)
172       continue;
173 
174     auto &Arg = *F->arg_begin();
175     if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
176       continue;
177 
178     for (auto *User : F->users()) {
179       auto *Call = dyn_cast<CallInst>(User);
180       if (!Call)
181         break;
182       auto *ArgOp = Call->getArgOperand(0);
183       auto *ArgOpType = ArgOp->getType();
184       auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
185       if (!ConstVal)
186         break;
187 
188       Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
189                                      GlobalValue::InternalLinkage, ConstVal,
190                                      "funcspec.arg");
191 
192       if (ArgOpType != ConstVal->getType())
193         GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
194 
195       Call->setArgOperand(0, GV);
196 
197       // Add the changed CallInst to Solver Worklist
198       Solver.visitCall(*Call);
199     }
200   }
201 }
202 
203 // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
204 // interfere with the constantArgPropagation optimization.
205 static void removeSSACopy(Function &F) {
206   for (BasicBlock &BB : F) {
207     for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
208       Instruction *Inst = &*BI++;
209       auto *II = dyn_cast<IntrinsicInst>(Inst);
210       if (!II)
211         continue;
212       if (II->getIntrinsicID() != Intrinsic::ssa_copy)
213         continue;
214       Inst->replaceAllUsesWith(II->getOperand(0));
215       Inst->eraseFromParent();
216     }
217   }
218 }
219 
220 static void removeSSACopy(Module &M) {
221   for (Function &F : M)
222     removeSSACopy(F);
223 }
224 
225 class FunctionSpecializer {
226 
227   /// The IPSCCP Solver.
228   SCCPSolver &Solver;
229 
230   /// Analyses used to help determine if a function should be specialized.
231   std::function<AssumptionCache &(Function &)> GetAC;
232   std::function<TargetTransformInfo &(Function &)> GetTTI;
233   std::function<TargetLibraryInfo &(Function &)> GetTLI;
234 
235   SmallPtrSet<Function *, 2> SpecializedFuncs;
236 
237 public:
238   FunctionSpecializer(SCCPSolver &Solver,
239                       std::function<AssumptionCache &(Function &)> GetAC,
240                       std::function<TargetTransformInfo &(Function &)> GetTTI,
241                       std::function<TargetLibraryInfo &(Function &)> GetTLI)
242       : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {}
243 
244   /// Attempt to specialize functions in the module to enable constant
245   /// propagation across function boundaries.
246   ///
247   /// \returns true if at least one function is specialized.
248   bool
249   specializeFunctions(SmallVectorImpl<Function *> &FuncDecls,
250                       SmallVectorImpl<Function *> &CurrentSpecializations) {
251 
252     // Attempt to specialize the argument-tracked functions.
253     bool Changed = false;
254     for (auto *F : FuncDecls) {
255       if (specializeFunction(F, CurrentSpecializations)) {
256         Changed = true;
257         LLVM_DEBUG(dbgs() << "FnSpecialization: Can specialize this func.\n");
258       } else {
259         LLVM_DEBUG(
260             dbgs() << "FnSpecialization: Cannot specialize this func.\n");
261       }
262     }
263 
264     for (auto *SpecializedFunc : CurrentSpecializations) {
265       SpecializedFuncs.insert(SpecializedFunc);
266 
267       // Initialize the state of the newly created functions, marking them
268       // argument-tracked and executable.
269       if (SpecializedFunc->hasExactDefinition() &&
270           !SpecializedFunc->hasFnAttribute(Attribute::Naked))
271         Solver.addTrackedFunction(SpecializedFunc);
272       Solver.addArgumentTrackedFunction(SpecializedFunc);
273       FuncDecls.push_back(SpecializedFunc);
274       Solver.markBlockExecutable(&SpecializedFunc->front());
275 
276       // Replace the function arguments for the specialized functions.
277       for (Argument &Arg : SpecializedFunc->args())
278         if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
279           LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
280                             << Arg.getName() << "\n");
281     }
282 
283     NumFuncSpecialized += NbFunctionsSpecialized;
284     return Changed;
285   }
286 
287   bool tryToReplaceWithConstant(Value *V) {
288     if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
289         V->user_empty())
290       return false;
291 
292     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
293     if (isOverdefined(IV))
294       return false;
295     auto *Const =
296         isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
297     V->replaceAllUsesWith(Const);
298 
299     for (auto *U : Const->users())
300       if (auto *I = dyn_cast<Instruction>(U))
301         if (Solver.isBlockExecutable(I->getParent()))
302           Solver.visit(I);
303 
304     // Remove the instruction from Block and Solver.
305     if (auto *I = dyn_cast<Instruction>(V)) {
306       if (I->isSafeToRemove()) {
307         I->eraseFromParent();
308         Solver.removeLatticeValueFor(I);
309       }
310     }
311     return true;
312   }
313 
314 private:
315   // The number of functions specialised, used for collecting statistics and
316   // also in the cost model.
317   unsigned NbFunctionsSpecialized = 0;
318 
319   /// Clone the function \p F and remove the ssa_copy intrinsics added by
320   /// the SCCPSolver in the cloned version.
321   Function *cloneCandidateFunction(Function *F) {
322     ValueToValueMapTy EmptyMap;
323     Function *Clone = CloneFunction(F, EmptyMap);
324     removeSSACopy(*Clone);
325     return Clone;
326   }
327 
328   /// This function decides whether to specialize function \p F based on the
329   /// known constant values its arguments can take on. Specialization is
330   /// performed on the first interesting argument. Specializations based on
331   /// additional arguments will be evaluated on following iterations of the
332   /// main IPSCCP solve loop. \returns true if the function is specialized and
333   /// false otherwise.
334   bool specializeFunction(Function *F,
335                           SmallVectorImpl<Function *> &Specializations) {
336 
337     // Do not specialize the cloned function again.
338     if (SpecializedFuncs.contains(F)) {
339       return false;
340     }
341 
342     // If we're optimizing the function for size, we shouldn't specialize it.
343     if (F->hasOptSize() ||
344         shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
345       return false;
346 
347     // Exit if the function is not executable. There's no point in specializing
348     // a dead function.
349     if (!Solver.isBlockExecutable(&F->getEntryBlock()))
350       return false;
351 
352     // It wastes time to specialize a function which would get inlined finally.
353     if (F->hasFnAttribute(Attribute::AlwaysInline))
354       return false;
355 
356     LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
357                       << "\n");
358 
359     // Determine if it would be profitable to create a specialization of the
360     // function where the argument takes on the given constant value. If so,
361     // add the constant to Constants.
362     auto FnSpecCost = getSpecializationCost(F);
363     if (!FnSpecCost.isValid()) {
364       LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialisation cost.\n");
365       return false;
366     }
367 
368     LLVM_DEBUG(dbgs() << "FnSpecialization: func specialisation cost: ";
369                FnSpecCost.print(dbgs()); dbgs() << "\n");
370 
371     // Determine if we should specialize the function based on the values the
372     // argument can take on. If specialization is not profitable, we continue
373     // on to the next argument.
374     for (Argument &A : F->args()) {
375       LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing arg: " << A.getName()
376                         << "\n");
377       // True if this will be a partial specialization. We will need to keep
378       // the original function around in addition to the added specializations.
379       bool IsPartial = true;
380 
381       // Determine if this argument is interesting. If we know the argument can
382       // take on any constant values, they are collected in Constants. If the
383       // argument can only ever equal a constant value in Constants, the
384       // function will be completely specialized, and the IsPartial flag will
385       // be set to false by isArgumentInteresting (that function only adds
386       // values to the Constants list that are deemed profitable).
387       SmallVector<Constant *, 4> Constants;
388       if (!isArgumentInteresting(&A, Constants, FnSpecCost, IsPartial)) {
389         LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n");
390         continue;
391       }
392 
393       assert(!Constants.empty() && "No constants on which to specialize");
394       LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is interesting!\n"
395                         << "FnSpecialization: Specializing '" << F->getName()
396                         << "' on argument: " << A << "\n"
397                         << "FnSpecialization: Constants are:\n\n";
398                  for (unsigned I = 0; I < Constants.size(); ++I) dbgs()
399                  << *Constants[I] << "\n";
400                  dbgs() << "FnSpecialization: End of constants\n\n");
401 
402       // Create a version of the function in which the argument is marked
403       // constant with the given value.
404       for (auto *C : Constants) {
405         // Clone the function. We leave the ValueToValueMap empty to allow
406         // IPSCCP to propagate the constant arguments.
407         Function *Clone = cloneCandidateFunction(F);
408         Argument *ClonedArg = Clone->arg_begin() + A.getArgNo();
409 
410         // Rewrite calls to the function so that they call the clone instead.
411         rewriteCallSites(F, Clone, *ClonedArg, C);
412 
413         // Initialize the lattice state of the arguments of the function clone,
414         // marking the argument on which we specialized the function constant
415         // with the given value.
416         Solver.markArgInFuncSpecialization(F, ClonedArg, C);
417 
418         // Mark all the specialized functions
419         Specializations.push_back(Clone);
420         NbFunctionsSpecialized++;
421       }
422 
423       // If the function has been completely specialized, the original function
424       // is no longer needed. Mark it unreachable.
425       if (!IsPartial)
426         Solver.markFunctionUnreachable(F);
427 
428       // FIXME: Only one argument per function.
429       return true;
430     }
431 
432     return false;
433   }
434 
435   /// Compute the cost of specializing function \p F.
436   InstructionCost getSpecializationCost(Function *F) {
437     // Compute the code metrics for the function.
438     SmallPtrSet<const Value *, 32> EphValues;
439     CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
440     CodeMetrics Metrics;
441     for (BasicBlock &BB : *F)
442       Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
443 
444     // If the code metrics reveal that we shouldn't duplicate the function, we
445     // shouldn't specialize it. Set the specialization cost to Invalid.
446     // Or if the lines of codes implies that this function is easy to get
447     // inlined so that we shouldn't specialize it.
448     if (Metrics.notDuplicatable ||
449         (!ForceFunctionSpecialization &&
450          Metrics.NumInsts < SmallFunctionThreshold)) {
451       InstructionCost C{};
452       C.setInvalid();
453       return C;
454     }
455 
456     // Otherwise, set the specialization cost to be the cost of all the
457     // instructions in the function and penalty for specializing more functions.
458     unsigned Penalty = NbFunctionsSpecialized + 1;
459     return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
460   }
461 
462   InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
463                                LoopInfo &LI) {
464     auto *I = dyn_cast_or_null<Instruction>(U);
465     // If not an instruction we do not know how to evaluate.
466     // Keep minimum possible cost for now so that it doesnt affect
467     // specialization.
468     if (!I)
469       return std::numeric_limits<unsigned>::min();
470 
471     auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);
472 
473     // Traverse recursively if there are more uses.
474     // TODO: Any other instructions to be added here?
475     if (I->mayReadFromMemory() || I->isCast())
476       for (auto *User : I->users())
477         Cost += getUserBonus(User, TTI, LI);
478 
479     // Increase the cost if it is inside the loop.
480     auto LoopDepth = LI.getLoopDepth(I->getParent());
481     Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
482     return Cost;
483   }
484 
485   /// Compute a bonus for replacing argument \p A with constant \p C.
486   InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
487     Function *F = A->getParent();
488     DominatorTree DT(*F);
489     LoopInfo LI(DT);
490     auto &TTI = (GetTTI)(*F);
491     LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for: " << *A
492                       << "\n");
493 
494     InstructionCost TotalCost = 0;
495     for (auto *U : A->users()) {
496       TotalCost += getUserBonus(U, TTI, LI);
497       LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
498                  TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
499     }
500 
501     // The below heuristic is only concerned with exposing inlining
502     // opportunities via indirect call promotion. If the argument is not a
503     // function pointer, give up.
504     if (!isa<PointerType>(A->getType()) ||
505         !isa<FunctionType>(A->getType()->getPointerElementType()))
506       return TotalCost;
507 
508     // Since the argument is a function pointer, its incoming constant values
509     // should be functions or constant expressions. The code below attempts to
510     // look through cast expressions to find the function that will be called.
511     Value *CalledValue = C;
512     while (isa<ConstantExpr>(CalledValue) &&
513            cast<ConstantExpr>(CalledValue)->isCast())
514       CalledValue = cast<User>(CalledValue)->getOperand(0);
515     Function *CalledFunction = dyn_cast<Function>(CalledValue);
516     if (!CalledFunction)
517       return TotalCost;
518 
519     // Get TTI for the called function (used for the inline cost).
520     auto &CalleeTTI = (GetTTI)(*CalledFunction);
521 
522     // Look at all the call sites whose called value is the argument.
523     // Specializing the function on the argument would allow these indirect
524     // calls to be promoted to direct calls. If the indirect call promotion
525     // would likely enable the called function to be inlined, specializing is a
526     // good idea.
527     int Bonus = 0;
528     for (User *U : A->users()) {
529       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
530         continue;
531       auto *CS = cast<CallBase>(U);
532       if (CS->getCalledOperand() != A)
533         continue;
534 
535       // Get the cost of inlining the called function at this call site. Note
536       // that this is only an estimate. The called function may eventually
537       // change in a way that leads to it not being inlined here, even though
538       // inlining looks profitable now. For example, one of its called
539       // functions may be inlined into it, making the called function too large
540       // to be inlined into this call site.
541       //
542       // We apply a boost for performing indirect call promotion by increasing
543       // the default threshold by the threshold for indirect calls.
544       auto Params = getInlineParams();
545       Params.DefaultThreshold += InlineConstants::IndirectCallThreshold;
546       InlineCost IC =
547           getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI);
548 
549       // We clamp the bonus for this call to be between zero and the default
550       // threshold.
551       if (IC.isAlways())
552         Bonus += Params.DefaultThreshold;
553       else if (IC.isVariable() && IC.getCostDelta() > 0)
554         Bonus += IC.getCostDelta();
555     }
556 
557     return TotalCost + Bonus;
558   }
559 
560   /// Determine if we should specialize a function based on the incoming values
561   /// of the given argument.
562   ///
563   /// This function implements the goal-directed heuristic. It determines if
564   /// specializing the function based on the incoming values of argument \p A
565   /// would result in any significant optimization opportunities. If
566   /// optimization opportunities exist, the constant values of \p A on which to
567   /// specialize the function are collected in \p Constants. If the values in
568   /// \p Constants represent the complete set of values that \p A can take on,
569   /// the function will be completely specialized, and the \p IsPartial flag is
570   /// set to false.
571   ///
572   /// \returns true if the function should be specialized on the given
573   /// argument.
574   bool isArgumentInteresting(Argument *A,
575                              SmallVectorImpl<Constant *> &Constants,
576                              const InstructionCost &FnSpecCost,
577                              bool &IsPartial) {
578     // For now, don't attempt to specialize functions based on the values of
579     // composite types.
580     if (!A->getType()->isSingleValueType() || A->user_empty())
581       return false;
582 
583     // If the argument isn't overdefined, there's nothing to do. It should
584     // already be constant.
585     if (!Solver.getLatticeValueFor(A).isOverdefined()) {
586       LLVM_DEBUG(dbgs() << "FnSpecialization: nothing to do, arg is already "
587                         << "constant?\n");
588       return false;
589     }
590 
591     // Collect the constant values that the argument can take on. If the
592     // argument can't take on any constant values, we aren't going to
593     // specialize the function. While it's possible to specialize the function
594     // based on non-constant arguments, there's likely not much benefit to
595     // constant propagation in doing so.
596     //
597     // TODO 1: currently it won't specialize if there are over the threshold of
598     // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it
599     // might be beneficial to take the occurrences into account in the cost
600     // model, so we would need to find the unique constants.
601     //
602     // TODO 2: this currently does not support constants, i.e. integer ranges.
603     //
604     SmallVector<Constant *, 4> PossibleConstants;
605     bool AllConstant = getPossibleConstants(A, PossibleConstants);
606     if (PossibleConstants.empty()) {
607       LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
608       return false;
609     }
610     if (PossibleConstants.size() > MaxConstantsThreshold) {
611       LLVM_DEBUG(dbgs() << "FnSpecialization: number of constants found exceed "
612                         << "the maximum number of constants threshold.\n");
613       return false;
614     }
615 
616     for (auto *C : PossibleConstants) {
617       LLVM_DEBUG(dbgs() << "FnSpecialization: Constant: " << *C << "\n");
618       if (ForceFunctionSpecialization) {
619         LLVM_DEBUG(dbgs() << "FnSpecialization: Forced!\n");
620         Constants.push_back(C);
621         continue;
622       }
623       if (getSpecializationBonus(A, C) > FnSpecCost) {
624         LLVM_DEBUG(dbgs() << "FnSpecialization: profitable!\n");
625         Constants.push_back(C);
626       } else {
627         LLVM_DEBUG(dbgs() << "FnSpecialization: not profitable\n");
628       }
629     }
630 
631     // None of the constant values the argument can take on were deemed good
632     // candidates on which to specialize the function.
633     if (Constants.empty())
634       return false;
635 
636     // This will be a partial specialization if some of the constants were
637     // rejected due to their profitability.
638     IsPartial = !AllConstant || PossibleConstants.size() != Constants.size();
639 
640     return true;
641   }
642 
643   /// Collect in \p Constants all the constant values that argument \p A can
644   /// take on.
645   ///
646   /// \returns true if all of the values the argument can take on are constant
647   /// (e.g., the argument's parent function cannot be called with an
648   /// overdefined value).
649   bool getPossibleConstants(Argument *A,
650                             SmallVectorImpl<Constant *> &Constants) {
651     Function *F = A->getParent();
652     bool AllConstant = true;
653 
654     // Iterate over all the call sites of the argument's parent function.
655     for (User *U : F->users()) {
656       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
657         continue;
658       auto &CS = *cast<CallBase>(U);
659 
660       // If the parent of the call site will never be executed, we don't need
661       // to worry about the passed value.
662       if (!Solver.isBlockExecutable(CS.getParent()))
663         continue;
664 
665       auto *V = CS.getArgOperand(A->getArgNo());
666       // TrackValueOfGlobalVariable only tracks scalar global variables.
667       if (auto *GV = dyn_cast<GlobalVariable>(V)) {
668         if (!GV->getValueType()->isSingleValueType()) {
669           return false;
670         }
671       }
672 
673       if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
674                                EnableSpecializationForLiteralConstant))
675         Constants.push_back(cast<Constant>(V));
676       else
677         AllConstant = false;
678     }
679 
680     // If the argument can only take on constant values, AllConstant will be
681     // true.
682     return AllConstant;
683   }
684 
685   /// Rewrite calls to function \p F to call function \p Clone instead.
686   ///
687   /// This function modifies calls to function \p F whose argument at index \p
688   /// ArgNo is equal to constant \p C. The calls are rewritten to call function
689   /// \p Clone instead.
690   void rewriteCallSites(Function *F, Function *Clone, Argument &Arg,
691                         Constant *C) {
692     unsigned ArgNo = Arg.getArgNo();
693     SmallVector<CallBase *, 4> CallSitesToRewrite;
694     for (auto *U : F->users()) {
695       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
696         continue;
697       auto &CS = *cast<CallBase>(U);
698       if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
699         continue;
700       CallSitesToRewrite.push_back(&CS);
701     }
702     for (auto *CS : CallSitesToRewrite) {
703       if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) ||
704           CS->getArgOperand(ArgNo) == C) {
705         CS->setCalledFunction(Clone);
706         Solver.markOverdefined(CS);
707       }
708     }
709   }
710 };
711 
712 bool llvm::runFunctionSpecialization(
713     Module &M, const DataLayout &DL,
714     std::function<TargetLibraryInfo &(Function &)> GetTLI,
715     std::function<TargetTransformInfo &(Function &)> GetTTI,
716     std::function<AssumptionCache &(Function &)> GetAC,
717     function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
718   SCCPSolver Solver(DL, GetTLI, M.getContext());
719   FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
720   bool Changed = false;
721 
722   // Loop over all functions, marking arguments to those with their addresses
723   // taken or that are external as overdefined.
724   for (Function &F : M) {
725     if (F.isDeclaration())
726       continue;
727     if (F.hasFnAttribute(Attribute::NoDuplicate))
728       continue;
729 
730     LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
731                       << "\n");
732     Solver.addAnalysis(F, GetAnalysis(F));
733 
734     // Determine if we can track the function's arguments. If so, add the
735     // function to the solver's set of argument-tracked functions.
736     if (canTrackArgumentsInterprocedurally(&F)) {
737       LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
738       Solver.addArgumentTrackedFunction(&F);
739       continue;
740     } else {
741       LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
742                         << "FnSpecialization: Doesn't have local linkage, or "
743                         << "has its address taken\n");
744     }
745 
746     // Assume the function is called.
747     Solver.markBlockExecutable(&F.front());
748 
749     // Assume nothing about the incoming arguments.
750     for (Argument &AI : F.args())
751       Solver.markOverdefined(&AI);
752   }
753 
754   // Determine if we can track any of the module's global variables. If so, add
755   // the global variables we can track to the solver's set of tracked global
756   // variables.
757   for (GlobalVariable &G : M.globals()) {
758     G.removeDeadConstantUsers();
759     if (canTrackGlobalVariableInterprocedurally(&G))
760       Solver.trackValueOfGlobalVariable(&G);
761   }
762 
763   // Solve for constants.
764   auto RunSCCPSolver = [&](auto &WorkList) {
765     bool ResolvedUndefs = true;
766 
767     while (ResolvedUndefs) {
768       LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");
769       Solver.solve();
770       LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
771       ResolvedUndefs = false;
772       for (Function *F : WorkList)
773         if (Solver.resolvedUndefsIn(*F))
774           ResolvedUndefs = true;
775     }
776 
777     for (auto *F : WorkList) {
778       for (BasicBlock &BB : *F) {
779         if (!Solver.isBlockExecutable(&BB))
780           continue;
781         for (auto &I : make_early_inc_range(BB))
782           // FIXME: The solver may make changes to the function here, so set Changed, even if later
783           // function specialization does not trigger.
784           Changed |= FS.tryToReplaceWithConstant(&I);
785       }
786     }
787   };
788 
789   auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
790   SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
791                                         TrackedFuncs.end());
792 #ifndef NDEBUG
793   LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
794   for (auto *F : FuncDecls)
795     LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
796 #endif
797 
798   // Initially resolve the constants in all the argument tracked functions.
799   RunSCCPSolver(FuncDecls);
800 
801   SmallVector<Function *, 2> CurrentSpecializations;
802   unsigned I = 0;
803   while (FuncSpecializationMaxIters != I++ &&
804          FS.specializeFunctions(FuncDecls, CurrentSpecializations)) {
805 
806     // Run the solver for the specialized functions.
807     RunSCCPSolver(CurrentSpecializations);
808 
809     // Replace some unresolved constant arguments
810     constantArgPropagation(FuncDecls, M, Solver);
811 
812     CurrentSpecializations.clear();
813     Changed = true;
814   }
815 
816   // Clean up the IR by removing ssa_copy intrinsics.
817   removeSSACopy(M);
818   return Changed;
819 }
820