1 //===- FunctionSpecialization.cpp - Function Specialization ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This specialises functions with constant parameters. Constant parameters
10 // like function pointers and constant globals are propagated to the callee by
11 // specializing the function. The main benefit of this pass at the moment is
// that indirect calls are transformed into direct calls, which exposes
// inlining opportunities that the inliner would not otherwise have had. That
// is why, by design, function specialisation is run before the inliner in the
// optimisation pipeline. Otherwise, we would only benefit from constant
16 // passing, which is a valid use-case too, but hasn't been explored much in
17 // terms of performance uplifts, cost-model and compile-time impact.
18 //
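// As a simplified illustration (function names below are made up for this
// example), specializing
//
//   define internal void @work(void ()* %callback) {
//     call void %callback()
//     ret void
//   }
//
// for a call site 'call void @work(void ()* @impl)' yields a clone in which
// the indirect call becomes the direct call 'call void @impl()', which the
// inliner can then consider.
//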
19 // Current limitations:
// - It does not yet handle integer ranges. We do support "literal constants",
//   but that is off by default and controlled by an option.
// - The cost-model could be further refined; it mainly focuses on inlining
//   benefits.
24 //
25 // Ideas:
26 // - With a function specialization attribute for arguments, we could have
27 // a direct way to steer function specialization, avoiding the cost-model,
28 // and thus control compile-times / code-size.
29 //
30 // Todos:
31 // - Specializing recursive functions relies on running the transformation a
32 // number of times, which is controlled by option
//   `func-specialization-max-iters`. Thus, increasing the number of
//   iterations linearly increases the number of times recursive functions get
//   specialized; see also the discussion in https://reviews.llvm.org/D106426
//   for details. Perhaps there is a more compile-time-friendly way to
//   control/limit the number of specialisations for recursive functions.
39 // - Don't transform the function if function specialization does not trigger;
40 // the SCCPSolver may make IR changes.
41 //
42 // References:
43 // - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
44 // it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
45 //
46 //===----------------------------------------------------------------------===//
47
48 #include "llvm/ADT/Statistic.h"
49 #include "llvm/Analysis/CodeMetrics.h"
50 #include "llvm/Analysis/InlineCost.h"
51 #include "llvm/Analysis/LoopInfo.h"
52 #include "llvm/Analysis/TargetTransformInfo.h"
53 #include "llvm/Analysis/ValueLattice.h"
54 #include "llvm/Analysis/ValueLatticeUtils.h"
55 #include "llvm/IR/IntrinsicInst.h"
56 #include "llvm/Transforms/Scalar/SCCP.h"
57 #include "llvm/Transforms/Utils/Cloning.h"
58 #include "llvm/Transforms/Utils/SCCPSolver.h"
59 #include "llvm/Transforms/Utils/SizeOpts.h"
60 #include <cmath>
61
62 using namespace llvm;
63
64 #define DEBUG_TYPE "function-specialization"
65
66 STATISTIC(NumFuncSpecialized, "Number of functions specialized");
67
68 static cl::opt<bool> ForceFunctionSpecialization(
69 "force-function-specialization", cl::init(false), cl::Hidden,
70 cl::desc("Force function specialization for every call site with a "
71 "constant argument"));
72
73 static cl::opt<unsigned> FuncSpecializationMaxIters(
74 "func-specialization-max-iters", cl::Hidden,
75 cl::desc("The maximum number of iterations function specialization is run"),
76 cl::init(1));
77
78 static cl::opt<unsigned> MaxClonesThreshold(
79 "func-specialization-max-clones", cl::Hidden,
80 cl::desc("The maximum number of clones allowed for a single function "
81 "specialization"),
82 cl::init(3));
83
84 static cl::opt<unsigned> SmallFunctionThreshold(
85 "func-specialization-size-threshold", cl::Hidden,
    cl::desc("Don't specialize functions that have less than this threshold "
87 "number of instructions"),
88 cl::init(100));
89
90 static cl::opt<unsigned>
91 AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
92 cl::desc("Average loop iteration count cost"),
93 cl::init(10));
94
95 static cl::opt<bool> SpecializeOnAddresses(
96 "func-specialization-on-address", cl::init(false), cl::Hidden,
97 cl::desc("Enable function specialization on the address of global values"));
98
99 // Disabled by default as it can significantly increase compilation times.
// Running nikic's compile-time tracker on x86 with instruction count as the
// metric shows a 3-4% regression for SPASS while being neutral for all other
// benchmarks of the LLVM test suite.
103 //
104 // https://llvm-compile-time-tracker.com
105 // https://github.com/nikic/llvm-compile-time-tracker
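//
// For illustration only (the function name here is hypothetical): with this
// option enabled, a call such as 'compute(i32 42)' can produce a clone of
// 'compute' in which the parameter is known to be 42, enabling constant
// folding inside the clone.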
106 static cl::opt<bool> EnableSpecializationForLiteralConstant(
107 "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
108 cl::desc("Enable specialization of functions that take a literal constant "
109 "as an argument."));
110
111 namespace {
112 // Bookkeeping struct to pass data from the analysis and profitability phase
113 // to the actual transform helper functions.
114 struct SpecializationInfo {
115 SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs.
116 InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
117 };
118 } // Anonymous namespace
119
120 using FuncList = SmallVectorImpl<Function *>;
121 using CallArgBinding = std::pair<CallBase *, Constant *>;
122 using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>;
123 // We are using MapVector because it guarantees deterministic iteration
124 // order across executions.
125 using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>;
126
127 // Helper to check if \p LV is either a constant or a constant
128 // range with a single element. This should cover exactly the same cases as the
129 // old ValueLatticeElement::isConstant() and is intended to be used in the
130 // transition to ValueLatticeElement.
static bool isConstant(const ValueLatticeElement &LV) {
132 return LV.isConstant() ||
133 (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
134 }
135
// Helper to check if \p LV is overdefined, i.e. neither unknown/undef nor a
// constant.
static bool isOverdefined(const ValueLatticeElement &LV) {
138 return !LV.isUnknownOrUndef() && !isConstant(LV);
139 }
140
static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
142 Value *StoreValue = nullptr;
143 for (auto *User : Alloca->users()) {
    // We can't use llvm::isAllocaPromotable() as it would fail because of
    // the use in the CallInst, which is exactly the use we handle here.
146 if (User == Call)
147 continue;
148 if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
149 if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
150 return nullptr;
151 continue;
152 }
153
154 if (auto *Store = dyn_cast<StoreInst>(User)) {
155 // This is a duplicate store, bail out.
156 if (StoreValue || Store->isVolatile())
157 return nullptr;
158 StoreValue = Store->getValueOperand();
159 continue;
160 }
161 // Bail if there is any other unknown usage.
162 return nullptr;
163 }
164 return dyn_cast_or_null<Constant>(StoreValue);
165 }
166
// A constant stack value is an AllocaInst that has a single constant
// value stored to it. Return that constant when the call argument \p Val is
// such an alloca.
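//
// For example (illustrative IR), the pattern recognized here looks like:
//
//   %x = alloca i32, align 4
//   store i32 3, i32* %x, align 4
//   call void @f(i32* nonnull %x)
//
// in which case the constant returned for %x is 'i32 3'.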
static Constant *getConstantStackValue(CallInst *Call, Value *Val,
171 SCCPSolver &Solver) {
172 if (!Val)
173 return nullptr;
174 Val = Val->stripPointerCasts();
175 if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
176 return ConstVal;
177 auto *Alloca = dyn_cast<AllocaInst>(Val);
178 if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
179 return nullptr;
180 return getPromotableAlloca(Alloca, Call);
181 }
182
183 // To support specializing recursive functions, it is important to propagate
184 // constant arguments because after a first iteration of specialisation, a
185 // reduced example may look like this:
186 //
187 // define internal void @RecursiveFn(i32* arg1) {
188 // %temp = alloca i32, align 4
//   store i32 2, i32* %temp, align 4
190 // call void @RecursiveFn.1(i32* nonnull %temp)
191 // ret void
192 // }
193 //
// Before the next iteration, we need to propagate the constant as shown
// below, which allows further specialization in subsequent iterations.
196 //
197 // @funcspec.arg = internal constant i32 2
198 //
// define internal void @RecursiveFn(i32* arg1) {
//   call void @RecursiveFn.1(i32* nonnull @funcspec.arg)
201 // ret void
202 // }
203 //
static void constantArgPropagation(FuncList &WorkList, Module &M,
205 SCCPSolver &Solver) {
  // Iterate over the argument-tracked functions to see if there
  // are any new constant values for the call instruction via
  // stack variables.
209 for (auto *F : WorkList) {
210
211 for (auto *User : F->users()) {
212
213 auto *Call = dyn_cast<CallInst>(User);
214 if (!Call)
215 continue;
216
217 bool Changed = false;
218 for (const Use &U : Call->args()) {
219 unsigned Idx = Call->getArgOperandNo(&U);
220 Value *ArgOp = Call->getArgOperand(Idx);
221 Type *ArgOpType = ArgOp->getType();
222
223 if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
224 continue;
225
226 auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
227 if (!ConstVal)
228 continue;
229
230 Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
231 GlobalValue::InternalLinkage, ConstVal,
232 "funcspec.arg");
233 if (ArgOpType != ConstVal->getType())
234 GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);
235
236 Call->setArgOperand(Idx, GV);
237 Changed = true;
238 }
239
      // Add the changed CallInst to the solver's worklist.
241 if (Changed)
242 Solver.visitCall(*Call);
243 }
244 }
245 }
246
247 // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
248 // interfere with the constantArgPropagation optimization.
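//
// For reference (illustrative only), such an intrinsic roughly looks like
//
//   %v.copy = call i32 @llvm.ssa.copy.i32(i32 %v)
//
// and is removed here by replacing its uses with its operand and erasing the
// call.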
static void removeSSACopy(Function &F) {
250 for (BasicBlock &BB : F) {
251 for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
252 auto *II = dyn_cast<IntrinsicInst>(&Inst);
253 if (!II)
254 continue;
255 if (II->getIntrinsicID() != Intrinsic::ssa_copy)
256 continue;
257 Inst.replaceAllUsesWith(II->getOperand(0));
258 Inst.eraseFromParent();
259 }
260 }
261 }
262
static void removeSSACopy(Module &M) {
264 for (Function &F : M)
265 removeSSACopy(F);
266 }
267
268 namespace {
269 class FunctionSpecializer {
270
271 /// The IPSCCP Solver.
272 SCCPSolver &Solver;
273
274 /// Analyses used to help determine if a function should be specialized.
275 std::function<AssumptionCache &(Function &)> GetAC;
276 std::function<TargetTransformInfo &(Function &)> GetTTI;
277 std::function<TargetLibraryInfo &(Function &)> GetTLI;
278
279 SmallPtrSet<Function *, 4> SpecializedFuncs;
280 SmallPtrSet<Function *, 4> FullySpecialized;
281 SmallVector<Instruction *> ReplacedWithConstant;
282 DenseMap<Function *, CodeMetrics> FunctionMetrics;
283
284 public:
  FunctionSpecializer(SCCPSolver &Solver,
286 std::function<AssumptionCache &(Function &)> GetAC,
287 std::function<TargetTransformInfo &(Function &)> GetTTI,
288 std::function<TargetLibraryInfo &(Function &)> GetTLI)
289 : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {}
290
  ~FunctionSpecializer() {
292 // Eliminate dead code.
293 removeDeadInstructions();
294 removeDeadFunctions();
295 }
296
297 /// Attempt to specialize functions in the module to enable constant
298 /// propagation across function boundaries.
299 ///
300 /// \returns true if at least one function is specialized.
  bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) {
302 bool Changed = false;
303 for (auto *F : Candidates) {
304 if (!isCandidateFunction(F))
305 continue;
306
307 auto Cost = getSpecializationCost(F);
308 if (!Cost.isValid()) {
309 LLVM_DEBUG(
310 dbgs() << "FnSpecialization: Invalid specialization cost.\n");
311 continue;
312 }
313
314 LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
315 << F->getName() << " is " << Cost << "\n");
316
317 SmallVector<CallSpecBinding, 8> Specializations;
318 if (!calculateGains(F, Cost, Specializations)) {
319 LLVM_DEBUG(dbgs() << "FnSpecialization: No possible constants found\n");
320 continue;
321 }
322
323 Changed = true;
324 for (auto &Entry : Specializations)
325 specializeFunction(F, Entry.second, WorkList);
326 }
327
328 updateSpecializedFuncs(Candidates, WorkList);
329 NumFuncSpecialized += NbFunctionsSpecialized;
330 return Changed;
331 }
332
  void removeDeadInstructions() {
334 for (auto *I : ReplacedWithConstant) {
335 LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction " << *I
336 << "\n");
337 I->eraseFromParent();
338 }
339 ReplacedWithConstant.clear();
340 }
341
  void removeDeadFunctions() {
343 for (auto *F : FullySpecialized) {
344 LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function "
345 << F->getName() << "\n");
346 F->eraseFromParent();
347 }
348 FullySpecialized.clear();
349 }
350
  bool tryToReplaceWithConstant(Value *V) {
352 if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
353 V->user_empty())
354 return false;
355
356 const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
357 if (isOverdefined(IV))
358 return false;
359 auto *Const =
360 isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
361
362 LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V
363 << "\nFnSpecialization: with " << *Const << "\n");
364
    // Record uses of V to avoid visiting irrelevant uses of Const later.
366 SmallVector<Instruction *> UseInsts;
367 for (auto *U : V->users())
368 if (auto *I = dyn_cast<Instruction>(U))
369 if (Solver.isBlockExecutable(I->getParent()))
370 UseInsts.push_back(I);
371
372 V->replaceAllUsesWith(Const);
373
374 for (auto *I : UseInsts)
375 Solver.visit(I);
376
377 // Remove the instruction from Block and Solver.
378 if (auto *I = dyn_cast<Instruction>(V)) {
379 if (I->isSafeToRemove()) {
380 ReplacedWithConstant.push_back(I);
381 Solver.removeLatticeValueFor(I);
382 }
383 }
384 return true;
385 }
386
387 private:
388 // The number of functions specialised, used for collecting statistics and
389 // also in the cost model.
390 unsigned NbFunctionsSpecialized = 0;
391
392 // Compute the code metrics for function \p F.
  CodeMetrics &analyzeFunction(Function *F) {
394 auto I = FunctionMetrics.insert({F, CodeMetrics()});
395 CodeMetrics &Metrics = I.first->second;
396 if (I.second) {
397 // The code metrics were not cached.
398 SmallPtrSet<const Value *, 32> EphValues;
399 CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
400 for (BasicBlock &BB : *F)
401 Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
402
403 LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
404 << F->getName() << " is " << Metrics.NumInsts
405 << " instructions\n");
406 }
407 return Metrics;
408 }
409
410 /// Clone the function \p F and remove the ssa_copy intrinsics added by
411 /// the SCCPSolver in the cloned version.
  Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) {
413 Function *Clone = CloneFunction(F, Mappings);
414 removeSSACopy(*Clone);
415 return Clone;
416 }
417
418 /// This function decides whether it's worthwhile to specialize function
419 /// \p F based on the known constant values its arguments can take on. It
420 /// only discovers potential specialization opportunities without actually
421 /// applying them.
422 ///
423 /// \returns true if any specializations have been found.
  bool calculateGains(Function *F, InstructionCost Cost,
425 SmallVectorImpl<CallSpecBinding> &WorkList) {
426 SpecializationMap Specializations;
427 // Determine if we should specialize the function based on the values the
428 // argument can take on. If specialization is not profitable, we continue
429 // on to the next argument.
430 for (Argument &FormalArg : F->args()) {
      // Determine if this argument is interesting. If we know the argument can
      // take on any constant values, they are collected in ActualArgs.
433 SmallVector<CallArgBinding, 8> ActualArgs;
434 if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
435 LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
436 << FormalArg.getNameOrAsOperand()
437 << " is not interesting\n");
438 continue;
439 }
440
441 for (const auto &Entry : ActualArgs) {
442 CallBase *Call = Entry.first;
443 Constant *ActualArg = Entry.second;
444
445 auto I = Specializations.insert({Call, SpecializationInfo()});
446 SpecializationInfo &S = I.first->second;
447
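        // For a new entry, seed the gain with the negative specialization cost
        // (or a token positive value when specialization is forced); the
        // per-argument bonuses are added on top below.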
448 if (I.second)
          S.Gain = ForceFunctionSpecialization ? 1 : (0 - Cost);
450 if (!ForceFunctionSpecialization)
451 S.Gain += getSpecializationBonus(&FormalArg, ActualArg);
452 S.Args.push_back({&FormalArg, ActualArg});
453 }
454 }
455
456 // Remove unprofitable specializations.
457 Specializations.remove_if(
458 [](const auto &Entry) { return Entry.second.Gain <= 0; });
459
460 // Clear the MapVector and return the underlying vector.
461 WorkList = Specializations.takeVector();
462
463 // Sort the candidates in descending order.
464 llvm::stable_sort(WorkList, [](const auto &L, const auto &R) {
465 return L.second.Gain > R.second.Gain;
466 });
467
468 // Truncate the worklist to 'MaxClonesThreshold' candidates if necessary.
469 if (WorkList.size() > MaxClonesThreshold) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceeds "
                        << "the maximum number of clones threshold.\n"
472 << "FnSpecialization: Truncating worklist to "
473 << MaxClonesThreshold << " candidates.\n");
474 WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
475 }
476
477 LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
478 << F->getName() << "\n";
479 for (const auto &Entry
480 : WorkList) {
481 dbgs() << "FnSpecialization: Gain = " << Entry.second.Gain
482 << "\n";
483 for (const ArgInfo &Arg : Entry.second.Args)
484 dbgs() << "FnSpecialization: FormalArg = "
485 << Arg.Formal->getNameOrAsOperand()
486 << ", ActualArg = "
487 << Arg.Actual->getNameOrAsOperand() << "\n";
488 });
489
490 return !WorkList.empty();
491 }
492
  bool isCandidateFunction(Function *F) {
494 // Do not specialize the cloned function again.
495 if (SpecializedFuncs.contains(F))
496 return false;
497
498 // If we're optimizing the function for size, we shouldn't specialize it.
499 if (F->hasOptSize() ||
500 shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
501 return false;
502
503 // Exit if the function is not executable. There's no point in specializing
504 // a dead function.
505 if (!Solver.isBlockExecutable(&F->getEntryBlock()))
506 return false;
507
    // It is a waste of time to specialize a function that will ultimately be
    // inlined anyway.
509 if (F->hasFnAttribute(Attribute::AlwaysInline))
510 return false;
511
512 LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
513 << "\n");
514 return true;
515 }
516
  void specializeFunction(Function *F, SpecializationInfo &S,
518 FuncList &WorkList) {
519 ValueToValueMapTy Mappings;
520 Function *Clone = cloneCandidateFunction(F, Mappings);
521
522 // Rewrite calls to the function so that they call the clone instead.
523 rewriteCallSites(Clone, S.Args, Mappings);
524
525 // Initialize the lattice state of the arguments of the function clone,
526 // marking the argument on which we specialized the function constant
527 // with the given value.
528 Solver.markArgInFuncSpecialization(Clone, S.Args);
529
    // Add the clone to the worklist of specialized functions.
531 WorkList.push_back(Clone);
532 NbFunctionsSpecialized++;
533
534 // If the function has been completely specialized, the original function
535 // is no longer needed. Mark it unreachable.
536 if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
537 if (auto *CS = dyn_cast<CallBase>(U))
538 return CS->getFunction() == F;
539 return false;
540 })) {
541 Solver.markFunctionUnreachable(F);
542 FullySpecialized.insert(F);
543 }
544 }
545
546 /// Compute and return the cost of specializing function \p F.
  InstructionCost getSpecializationCost(Function *F) {
548 CodeMetrics &Metrics = analyzeFunction(F);
    // If the code metrics reveal that we shouldn't duplicate the function, or
    // if the function is small enough that it is likely to be inlined anyway,
    // we shouldn't specialize it. Set the specialization cost to Invalid.
553 if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
554 (!ForceFunctionSpecialization &&
555 *Metrics.NumInsts.getValue() < SmallFunctionThreshold)) {
556 InstructionCost C{};
557 C.setInvalid();
558 return C;
559 }
560
    // Otherwise, set the specialization cost to the cost of all the
    // instructions in the function, scaled by a penalty that grows with the
    // number of functions already specialized.
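    // For example, a 200-instruction function considered after one prior
    // specialization gets a cost of 200 * InstrCost * 2.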
563 unsigned Penalty = NbFunctionsSpecialized + 1;
564 return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
565 }
566
  InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
568 LoopInfo &LI) {
569 auto *I = dyn_cast_or_null<Instruction>(U);
    // If the user is not an instruction, we do not know how to evaluate it.
    // Keep the minimum possible cost for now so that it doesn't affect
    // the specialization decision.
573 if (!I)
574 return std::numeric_limits<unsigned>::min();
575
576 auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);
577
578 // Traverse recursively if there are more uses.
579 // TODO: Any other instructions to be added here?
580 if (I->mayReadFromMemory() || I->isCast())
581 for (auto *User : I->users())
582 Cost += getUserBonus(User, TTI, LI);
583
    // Increase the cost if the user is inside a loop.
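    // For example, with the default func-specialization-avg-iters-cost of 10,
    // a user at loop depth 2 has its bonus weighted by a factor of 100.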
585 auto LoopDepth = LI.getLoopDepth(I->getParent());
586 Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
587 return Cost;
588 }
589
590 /// Compute a bonus for replacing argument \p A with constant \p C.
  InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
592 Function *F = A->getParent();
593 DominatorTree DT(*F);
594 LoopInfo LI(DT);
595 auto &TTI = (GetTTI)(*F);
596 LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
597 << C->getNameOrAsOperand() << "\n");
598
599 InstructionCost TotalCost = 0;
600 for (auto *U : A->users()) {
601 TotalCost += getUserBonus(U, TTI, LI);
602 LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
603 TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
604 }
605
606 // The below heuristic is only concerned with exposing inlining
607 // opportunities via indirect call promotion. If the argument is not a
608 // (potentially casted) function pointer, give up.
609 Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts());
610 if (!CalledFunction)
611 return TotalCost;
612
613 // Get TTI for the called function (used for the inline cost).
614 auto &CalleeTTI = (GetTTI)(*CalledFunction);
615
616 // Look at all the call sites whose called value is the argument.
617 // Specializing the function on the argument would allow these indirect
618 // calls to be promoted to direct calls. If the indirect call promotion
619 // would likely enable the called function to be inlined, specializing is a
620 // good idea.
621 int Bonus = 0;
622 for (User *U : A->users()) {
623 if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
624 continue;
625 auto *CS = cast<CallBase>(U);
626 if (CS->getCalledOperand() != A)
627 continue;
628
629 // Get the cost of inlining the called function at this call site. Note
630 // that this is only an estimate. The called function may eventually
631 // change in a way that leads to it not being inlined here, even though
632 // inlining looks profitable now. For example, one of its called
633 // functions may be inlined into it, making the called function too large
634 // to be inlined into this call site.
635 //
636 // We apply a boost for performing indirect call promotion by increasing
637 // the default threshold by the threshold for indirect calls.
638 auto Params = getInlineParams();
639 Params.DefaultThreshold += InlineConstants::IndirectCallThreshold;
640 InlineCost IC =
641 getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI);
642
643 // We clamp the bonus for this call to be between zero and the default
644 // threshold.
645 if (IC.isAlways())
646 Bonus += Params.DefaultThreshold;
647 else if (IC.isVariable() && IC.getCostDelta() > 0)
648 Bonus += IC.getCostDelta();
649
650 LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus
651 << " for user " << *U << "\n");
652 }
653
654 return TotalCost + Bonus;
655 }
656
657 /// Determine if we should specialize a function based on the incoming values
658 /// of the given argument.
659 ///
660 /// This function implements the goal-directed heuristic. It determines if
661 /// specializing the function based on the incoming values of argument \p A
662 /// would result in any significant optimization opportunities. If
663 /// optimization opportunities exist, the constant values of \p A on which to
664 /// specialize the function are collected in \p Constants.
665 ///
666 /// \returns true if the function should be specialized on the given
667 /// argument.
  bool isArgumentInteresting(Argument *A,
669 SmallVectorImpl<CallArgBinding> &Constants) {
670 // For now, don't attempt to specialize functions based on the values of
671 // composite types.
672 if (!A->getType()->isSingleValueType() || A->user_empty())
673 return false;
674
675 // If the argument isn't overdefined, there's nothing to do. It should
676 // already be constant.
677 if (!Solver.getLatticeValueFor(A).isOverdefined()) {
678 LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument "
679 << A->getNameOrAsOperand()
680 << " is already constant?\n");
681 return false;
682 }
683
684 // Collect the constant values that the argument can take on. If the
685 // argument can't take on any constant values, we aren't going to
686 // specialize the function. While it's possible to specialize the function
687 // based on non-constant arguments, there's likely not much benefit to
688 // constant propagation in doing so.
689 //
    // TODO 1: currently it won't specialize if the number of calls using the
    // same argument is over the threshold, e.g. foo(a) x 4 and foo(b) x 1; it
    // might be beneficial to take the occurrences into account in the cost
    // model, in which case we would need to find the unique constants.
    //
    // TODO 2: this currently does not support constant ranges, i.e. integer
    // ranges.
696 //
697 getPossibleConstants(A, Constants);
698
699 if (Constants.empty())
700 return false;
701
702 LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
703 << A->getNameOrAsOperand() << "\n");
704 return true;
705 }
706
707 /// Collect in \p Constants all the constant values that argument \p A can
708 /// take on.
  void getPossibleConstants(Argument *A,
710 SmallVectorImpl<CallArgBinding> &Constants) {
711 Function *F = A->getParent();
712
    // The SCCP solver does not record an argument that will be constructed on
    // the stack.
715 if (A->hasByValAttr() && !F->onlyReadsMemory())
716 return;
717
718 // Iterate over all the call sites of the argument's parent function.
719 for (User *U : F->users()) {
720 if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
721 continue;
722 auto &CS = *cast<CallBase>(U);
723 // If the call site has attribute minsize set, that callsite won't be
724 // specialized.
725 if (CS.hasFnAttr(Attribute::MinSize))
726 continue;
727
728 // If the parent of the call site will never be executed, we don't need
729 // to worry about the passed value.
730 if (!Solver.isBlockExecutable(CS.getParent()))
731 continue;
732
733 auto *V = CS.getArgOperand(A->getArgNo());
734 if (isa<PoisonValue>(V))
735 return;
736
737 // TrackValueOfGlobalVariable only tracks scalar global variables.
738 if (auto *GV = dyn_cast<GlobalVariable>(V)) {
739 // Check if we want to specialize on the address of non-constant
740 // global values.
741 if (!GV->isConstant())
742 if (!SpecializeOnAddresses)
743 return;
744
745 if (!GV->getValueType()->isSingleValueType())
746 return;
747 }
748
749 if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
750 EnableSpecializationForLiteralConstant))
751 Constants.push_back({&CS, cast<Constant>(V)});
752 }
753 }
754
755 /// Rewrite calls to function \p F to call function \p Clone instead.
756 ///
757 /// This function modifies calls to function \p F as long as the actual
758 /// arguments match those in \p Args. Note that for recursive calls we
759 /// need to compare against the cloned formal arguments.
760 ///
761 /// Callsites that have been marked with the MinSize function attribute won't
762 /// be specialized and rewritten.
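  ///
  /// For example (hypothetical scenario): inside the clone of a recursive
  /// function, a self-call passes the clone's own parameter rather than the
  /// original formal argument; \p Mappings lets us recognize that operand as
  /// the cloned counterpart of the formal argument and still rewrite the call.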
  void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args,
764 ValueToValueMapTy &Mappings) {
765 assert(!Args.empty() && "Specialization without arguments");
766 Function *F = Args[0].Formal->getParent();
767
768 SmallVector<CallBase *, 8> CallSitesToRewrite;
769 for (auto *U : F->users()) {
770 if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
771 continue;
772 auto &CS = *cast<CallBase>(U);
773 if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
774 continue;
775 CallSitesToRewrite.push_back(&CS);
776 }
777
778 LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of "
779 << F->getName() << " with " << Clone->getName() << "\n");
780
781 for (auto *CS : CallSitesToRewrite) {
782 LLVM_DEBUG(dbgs() << "FnSpecialization: "
783 << CS->getFunction()->getName() << " ->" << *CS
784 << "\n");
785 if (/* recursive call */
786 (CS->getFunction() == Clone &&
787 all_of(Args,
788 [CS, &Mappings](const ArgInfo &Arg) {
789 unsigned ArgNo = Arg.Formal->getArgNo();
790 return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal];
791 })) ||
792 /* normal call */
793 all_of(Args, [CS](const ArgInfo &Arg) {
794 unsigned ArgNo = Arg.Formal->getArgNo();
795 return CS->getArgOperand(ArgNo) == Arg.Actual;
796 })) {
797 CS->setCalledFunction(Clone);
798 Solver.markOverdefined(CS);
799 }
800 }
801 }
802
  void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) {
804 for (auto *F : WorkList) {
805 SpecializedFuncs.insert(F);
806
807 // Initialize the state of the newly created functions, marking them
808 // argument-tracked and executable.
809 if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked))
810 Solver.addTrackedFunction(F);
811
812 Solver.addArgumentTrackedFunction(F);
813 Candidates.push_back(F);
814 Solver.markBlockExecutable(&F->front());
815
816 // Replace the function arguments for the specialized functions.
817 for (Argument &Arg : F->args())
818 if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
819 LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
820 << Arg.getNameOrAsOperand() << "\n");
821 }
822 }
823 };
824 } // namespace
825
bool llvm::runFunctionSpecialization(
827 Module &M, const DataLayout &DL,
828 std::function<TargetLibraryInfo &(Function &)> GetTLI,
829 std::function<TargetTransformInfo &(Function &)> GetTTI,
830 std::function<AssumptionCache &(Function &)> GetAC,
831 function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
832 SCCPSolver Solver(DL, GetTLI, M.getContext());
833 FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
834 bool Changed = false;
835
836 // Loop over all functions, marking arguments to those with their addresses
837 // taken or that are external as overdefined.
838 for (Function &F : M) {
839 if (F.isDeclaration())
840 continue;
841 if (F.hasFnAttribute(Attribute::NoDuplicate))
842 continue;
843
844 LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
845 << "\n");
846 Solver.addAnalysis(F, GetAnalysis(F));
847
848 // Determine if we can track the function's arguments. If so, add the
849 // function to the solver's set of argument-tracked functions.
850 if (canTrackArgumentsInterprocedurally(&F)) {
851 LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
852 Solver.addArgumentTrackedFunction(&F);
853 continue;
854 } else {
855 LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
856 << "FnSpecialization: Doesn't have local linkage, or "
857 << "has its address taken\n");
858 }
859
860 // Assume the function is called.
861 Solver.markBlockExecutable(&F.front());
862
863 // Assume nothing about the incoming arguments.
864 for (Argument &AI : F.args())
865 Solver.markOverdefined(&AI);
866 }
867
868 // Determine if we can track any of the module's global variables. If so, add
869 // the global variables we can track to the solver's set of tracked global
870 // variables.
871 for (GlobalVariable &G : M.globals()) {
872 G.removeDeadConstantUsers();
873 if (canTrackGlobalVariableInterprocedurally(&G))
874 Solver.trackValueOfGlobalVariable(&G);
875 }
876
877 auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
878 SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
879 TrackedFuncs.end());
880
  // No tracked functions, so nothing to do: don't run the solver; just remove
  // the ssa_copy intrinsics that may have been introduced.
883 if (TrackedFuncs.empty()) {
884 removeSSACopy(M);
885 return false;
886 }
887
888 // Solve for constants.
889 auto RunSCCPSolver = [&](auto &WorkList) {
890 bool ResolvedUndefs = true;
891
892 while (ResolvedUndefs) {
      // Not running the solver unnecessarily is checked in the regression test
      // nothing-to-do.ll, so if this debug message is changed, that regression
      // test needs updating too.
896 LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");
897
898 Solver.solve();
899 LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
900 ResolvedUndefs = false;
901 for (Function *F : WorkList)
902 if (Solver.resolvedUndefsIn(*F))
903 ResolvedUndefs = true;
904 }
905
906 for (auto *F : WorkList) {
907 for (BasicBlock &BB : *F) {
908 if (!Solver.isBlockExecutable(&BB))
909 continue;
910 // FIXME: The solver may make changes to the function here, so set
911 // Changed, even if later function specialization does not trigger.
912 for (auto &I : make_early_inc_range(BB))
913 Changed |= FS.tryToReplaceWithConstant(&I);
914 }
915 }
916 };
917
918 #ifndef NDEBUG
919 LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
920 for (auto *F : FuncDecls)
921 LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
922 #endif
923
924 // Initially resolve the constants in all the argument tracked functions.
925 RunSCCPSolver(FuncDecls);
926
927 SmallVector<Function *, 8> WorkList;
928 unsigned I = 0;
929 while (FuncSpecializationMaxIters != I++ &&
930 FS.specializeFunctions(FuncDecls, WorkList)) {
931 LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n");
932
933 // Run the solver for the specialized functions.
934 RunSCCPSolver(WorkList);
935
    // Replace call arguments that are constant stack values with globals so
    // that they can be propagated in the next solver iteration.
937 constantArgPropagation(FuncDecls, M, Solver);
938
939 WorkList.clear();
940 Changed = true;
941 }
942
943 LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = "
944 << NumFuncSpecialized << "\n");
945
946 // Remove any ssa_copy intrinsics that may have been introduced.
947 removeSSACopy(M);
948 return Changed;
949 }
950