1 //===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the IndirectCallPromotion class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/IndirectCallPromotion.h"
14 #include "bolt/Passes/BinaryFunctionCallGraph.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "bolt/Passes/Inliner.h"
17 #include "llvm/Support/CommandLine.h"
18 
19 #define DEBUG_TYPE "ICP"
// Evaluates X only when opts::Verbosity is at least Level.
// The body is wrapped in do { } while (0) so that `DEBUG_VERBOSE(...);`
// expands to exactly one statement and can safely be used, for example, as
// the then-branch of an if/else.
#define DEBUG_VERBOSE(Level, X)                                                \
  do {                                                                         \
    if (opts::Verbosity >= (Level)) {                                          \
      X;                                                                       \
    }                                                                          \
  } while (0)
24 
25 using namespace llvm;
26 using namespace bolt;
27 
28 namespace opts {
29 
30 extern cl::OptionCategory BoltOptCategory;
31 
32 extern cl::opt<IndirectCallPromotionType> ICP;
33 extern cl::opt<unsigned> Verbosity;
34 extern cl::opt<unsigned> ExecutionCountThreshold;
35 
36 static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
37     "icp-jt-remaining-percent-threshold",
38     cl::desc("The percentage threshold against remaining unpromoted indirect "
39              "call count for the promotion for jump tables"),
40     cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
41 
42 static cl::opt<unsigned> ICPJTTotalPercentThreshold(
43     "icp-jt-total-percent-threshold",
44     cl::desc(
45         "The percentage threshold against total count for the promotion for "
46         "jump tables"),
47     cl::init(5), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
48 
49 static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
50     "icp-calls-remaining-percent-threshold",
51     cl::desc("The percentage threshold against remaining unpromoted indirect "
52              "call count for the promotion for calls"),
53     cl::init(50), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
54 
55 static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
56     "icp-calls-total-percent-threshold",
57     cl::desc(
58         "The percentage threshold against total count for the promotion for "
59         "calls"),
60     cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
61 
62 static cl::opt<unsigned> ICPMispredictThreshold(
63     "indirect-call-promotion-mispredict-threshold",
64     cl::desc("misprediction threshold for skipping ICP on an "
65              "indirect call"),
66     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
67 
68 static cl::opt<bool> ICPUseMispredicts(
69     "indirect-call-promotion-use-mispredicts",
70     cl::desc("use misprediction frequency for determining whether or not ICP "
71              "should be applied at a callsite.  The "
72              "-indirect-call-promotion-mispredict-threshold value will be used "
73              "by this heuristic"),
74     cl::ZeroOrMore, cl::cat(BoltOptCategory));
75 
76 static cl::opt<unsigned>
77     ICPTopN("indirect-call-promotion-topn",
78             cl::desc("limit number of targets to consider when doing indirect "
79                      "call promotion. 0 = no limit"),
80             cl::init(3), cl::ZeroOrMore, cl::cat(BoltOptCategory));
81 
82 static cl::opt<unsigned> ICPCallsTopN(
83     "indirect-call-promotion-calls-topn",
84     cl::desc("limit number of targets to consider when doing indirect "
85              "call promotion on calls. 0 = no limit"),
86     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
87 
88 static cl::opt<unsigned> ICPJumpTablesTopN(
89     "indirect-call-promotion-jump-tables-topn",
90     cl::desc("limit number of targets to consider when doing indirect "
91              "call promotion on jump tables. 0 = no limit"),
92     cl::init(0), cl::ZeroOrMore, cl::cat(BoltOptCategory));
93 
94 static cl::opt<bool> EliminateLoads(
95     "icp-eliminate-loads",
96     cl::desc("enable load elimination using memory profiling data when "
97              "performing ICP"),
98     cl::init(true), cl::ZeroOrMore, cl::cat(BoltOptCategory));
99 
100 static cl::opt<unsigned> ICPTopCallsites(
101     "icp-top-callsites",
102     cl::desc("optimize hottest calls until at least this percentage of all "
103              "indirect calls frequency is covered. 0 = all callsites"),
104     cl::init(99), cl::Hidden, cl::ZeroOrMore, cl::cat(BoltOptCategory));
105 
106 static cl::list<std::string>
107     ICPFuncsList("icp-funcs", cl::CommaSeparated,
108                  cl::desc("list of functions to enable ICP for"),
109                  cl::value_desc("func1,func2,func3,..."), cl::Hidden,
110                  cl::cat(BoltOptCategory));
111 
112 static cl::opt<bool>
113     ICPOldCodeSequence("icp-old-code-sequence",
114                        cl::desc("use old code sequence for promoted calls"),
115                        cl::init(false), cl::ZeroOrMore, cl::Hidden,
116                        cl::cat(BoltOptCategory));
117 
118 static cl::opt<bool> ICPJumpTablesByTarget(
119     "icp-jump-tables-targets",
120     cl::desc(
121         "for jump tables, optimize indirect jmp targets instead of indices"),
122     cl::init(false), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
123 
124 static cl::opt<bool> ICPPeelForInline(
125     "icp-inline", cl::desc("only promote call targets eligible for inlining"),
126     cl::init(false), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));
127 
128 } // namespace opts
129 
130 namespace llvm {
131 namespace bolt {
132 
133 namespace {
134 
135 bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
136   bool IsValid = true;
137   for (auto &BFI : BFs) {
138     BinaryFunction &BF = BFI.second;
139     if (!BF.isSimple())
140       continue;
141     for (BinaryBasicBlock *BB : BF.layout()) {
142       auto BI = BB->branch_info_begin();
143       for (BinaryBasicBlock *SuccBB : BB->successors()) {
144         if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
145           if (BB->getKnownExecutionCount() == 0 ||
146               SuccBB->getKnownExecutionCount() == 0) {
147             errs() << "BOLT-WARNING: profile verification failed after ICP for "
148                       "function "
149                    << BF << '\n';
150             IsValid = false;
151           }
152         }
153         ++BI;
154       }
155     }
156   }
157   return IsValid;
158 }
159 
160 } // namespace
161 
162 IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
163                                           const IndirectCallProfile &ICP)
164     : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
165       Branches(ICP.Count) {
166   if (ICP.Symbol) {
167     To.Sym = ICP.Symbol;
168     To.Addr = 0;
169   }
170 }
171 
172 void IndirectCallPromotion::printDecision(
173     llvm::raw_ostream &OS,
174     std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
175   uint64_t TotalCount = 0;
176   uint64_t TotalMispreds = 0;
177   for (const Callsite &S : Targets) {
178     TotalCount += S.Branches;
179     TotalMispreds += S.Mispreds;
180   }
181   if (!TotalCount)
182     TotalCount = 1;
183   if (!TotalMispreds)
184     TotalMispreds = 1;
185 
186   OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
187      << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
188      << "\n";
189 
190   size_t I = 0;
191   for (const Callsite &S : Targets) {
192     OS << "Count = " << S.Branches << ", "
193        << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
194        << "Mispreds = " << S.Mispreds << ", "
195        << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
196     if (I < N)
197       OS << " * to be optimized *";
198     if (!S.JTIndices.empty()) {
199       OS << " Indices:";
200       for (const uint64_t Idx : S.JTIndices)
201         OS << " " << Idx;
202     }
203     OS << "\n";
204     I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
205   }
206 }
207 
208 // Get list of targets for a given call sorted by most frequently
209 // called first.
210 std::vector<IndirectCallPromotion::Callsite>
211 IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
212                                       const MCInst &Inst) const {
213   BinaryFunction &BF = *BB.getFunction();
214   const BinaryContext &BC = BF.getBinaryContext();
215   std::vector<Callsite> Targets;
216 
217   if (const JumpTable *JT = BF.getJumpTable(Inst)) {
218     // Don't support PIC jump tables for now
219     if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
220       return Targets;
221     const Location From(BF.getSymbol());
222     const std::pair<size_t, size_t> Range =
223         JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
224     assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
225     JumpTable::JumpInfo DefaultJI;
226     const JumpTable::JumpInfo *JI =
227         JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
228     const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
229     assert(JT->Type == JumpTable::JTT_PIC ||
230            JT->EntrySize == BC.AsmInfo->getCodePointerSize());
231     for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
232       MCSymbol *Entry = JT->Entries[I];
233       assert(BF.getBasicBlockForLabel(Entry) ||
234              Entry == BF.getFunctionEndLabel() ||
235              Entry == BF.getFunctionColdEndLabel());
236       if (Entry == BF.getFunctionEndLabel() ||
237           Entry == BF.getFunctionColdEndLabel())
238         continue;
239       const Location To(Entry);
240       const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(Entry);
241       Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
242                            I - Range.first);
243     }
244 
245     // Sort by symbol then addr.
246     std::sort(Targets.begin(), Targets.end(),
247               [](const Callsite &A, const Callsite &B) {
248                 if (A.To.Sym && B.To.Sym)
249                   return A.To.Sym < B.To.Sym;
250                 else if (A.To.Sym && !B.To.Sym)
251                   return true;
252                 else if (!A.To.Sym && B.To.Sym)
253                   return false;
254                 else
255                   return A.To.Addr < B.To.Addr;
256               });
257 
258     // Targets may contain multiple entries to the same target, but using
259     // different indices. Their profile will report the same number of branches
260     // for different indices if the target is the same. That's because we don't
261     // profile the index value, but only the target via LBR.
262     auto First = Targets.begin();
263     auto Last = Targets.end();
264     auto Result = First;
265     while (++First != Last) {
266       Callsite &A = *Result;
267       const Callsite &B = *First;
268       if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
269         A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
270                            B.JTIndices.end());
271       else
272         *(++Result) = *First;
273     }
274     ++Result;
275 
276     LLVM_DEBUG(if (Targets.end() - Result > 0) {
277       dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
278              << " duplicate targets removed\n";
279     });
280 
281     Targets.erase(Result, Targets.end());
282   } else {
283     // Don't try to optimize PC relative indirect calls.
284     if (Inst.getOperand(0).isReg() &&
285         Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
286       return Targets;
287 
288     const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
289         Inst, "CallProfile");
290     if (ICSP) {
291       for (const IndirectCallProfile &CSP : ICSP.get()) {
292         Callsite Site(BF, CSP);
293         if (Site.isValid())
294           Targets.emplace_back(std::move(Site));
295       }
296     }
297   }
298 
299   // Sort by target count, number of indices in case of jump table, and
300   // mispredicts. We prioritize targets with high count, small number of indices
301   // and high mispredicts. Break ties by selecting targets with lower addresses.
302   std::stable_sort(Targets.begin(), Targets.end(),
303                    [](const Callsite &A, const Callsite &B) {
304                      if (A.Branches != B.Branches)
305                        return A.Branches > B.Branches;
306                      if (A.JTIndices.size() != B.JTIndices.size())
307                        return A.JTIndices.size() < B.JTIndices.size();
308                      if (A.Mispreds != B.Mispreds)
309                        return A.Mispreds > B.Mispreds;
310                      return A.To.Addr < B.To.Addr;
311                    });
312 
313   // Remove non-symbol targets
314   auto Last = std::remove_if(Targets.begin(), Targets.end(),
315                              [](const Callsite &CS) { return !CS.To.Sym; });
316   Targets.erase(Last, Targets.end());
317 
318   LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
319     uint64_t TotalCount = 0;
320     uint64_t TotalMispreds = 0;
321     for (const Callsite &S : Targets) {
322       TotalCount += S.Branches;
323       TotalMispreds += S.Mispreds;
324     }
325     if (!TotalCount)
326       TotalCount = 1;
327     if (!TotalMispreds)
328       TotalMispreds = 1;
329 
330     dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
331            << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
332            << "\n";
333 
334     size_t I = 0;
335     for (const Callsite &S : Targets) {
336       dbgs() << "Count[" << I << "] = " << S.Branches << ", "
337              << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
338              << "Mispreds[" << I << "] = " << S.Mispreds << ", "
339              << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
340       ++I;
341     }
342   });
343 
344   return Targets;
345 }
346 
347 IndirectCallPromotion::JumpTableInfoType
348 IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
349                                                    MCInst &CallInst,
350                                                    MCInst *&TargetFetchInst,
351                                                    const JumpTable *JT) const {
352   assert(JT && "Can't get jump table addrs for non-jump tables.");
353 
354   BinaryFunction &Function = *BB.getFunction();
355   BinaryContext &BC = Function.getBinaryContext();
356 
357   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
358     return JumpTableInfoType();
359 
360   JumpTableInfoType HotTargets;
361   MCInst *MemLocInstr;
362   MCInst *PCRelBaseOut;
363   unsigned BaseReg, IndexReg;
364   int64_t DispValue;
365   const MCExpr *DispExpr;
366   MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
367   const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
368       CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
369       MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut);
370 
371   assert(MemLocInstr && "There should always be a load for jump tables");
372   if (!MemLocInstr)
373     return JumpTableInfoType();
374 
375   LLVM_DEBUG({
376     dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
377            << "jump table in " << Function << " at @ "
378            << (&CallInst - &BB.front()) << "\n"
379            << "BOLT-INFO: ICP target fetch instructions:\n";
380     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
381     if (MemLocInstr != &CallInst)
382       BC.printInstruction(dbgs(), CallInst, 0, &Function);
383   });
384 
385   DEBUG_VERBOSE(1, {
386     dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
387            << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
388            << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
389            << "DispValue = " << Twine::utohexstr(DispValue) << ", "
390            << "DispExpr = " << DispExpr << ", "
391            << "MemLocInstr = ";
392     BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
393     dbgs() << "\n";
394   });
395 
396   ++TotalIndexBasedCandidates;
397 
398   auto ErrorOrMemAccesssProfile =
399       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
400                                                       "MemoryAccessProfile");
401   if (!ErrorOrMemAccesssProfile) {
402     DEBUG_VERBOSE(1, dbgs()
403                          << "BOLT-INFO: ICP no memory profiling data found\n");
404     return JumpTableInfoType();
405   }
406   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();
407 
408   uint64_t ArrayStart;
409   if (DispExpr) {
410     ErrorOr<uint64_t> DispValueOrError =
411         BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
412     assert(DispValueOrError && "global symbol needs a value");
413     ArrayStart = *DispValueOrError;
414   } else {
415     ArrayStart = static_cast<uint64_t>(DispValue);
416   }
417 
418   if (BaseReg == BC.MRI->getProgramCounter())
419     ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;
420 
421   // This is a map of [symbol] -> [count, index] and is used to combine indices
422   // into the jump table since there may be multiple addresses that all have the
423   // same entry.
424   std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
425   const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);
426 
427   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
428     size_t Index;
429     // Mem data occasionally includes nullprs, ignore them.
430     if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
431       continue;
432 
433     if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
434       return JumpTableInfoType();
435 
436     if (AccessInfo.MemoryObject) {
437       // Deal with bad/stale data
438       if (!AccessInfo.MemoryObject->getName().startswith(
439               "JUMP_TABLE/" + Function.getOneName().str()))
440         return JumpTableInfoType();
441       Index =
442           (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
443     } else {
444       Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
445     }
446 
447     // If Index is out of range it probably means the memory profiling data is
448     // wrong for this instruction, bail out.
449     if (Index >= Range.second) {
450       LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
451                         << ", " << Range.second << "\n");
452       return JumpTableInfoType();
453     }
454 
455     // Make sure the hot index points at a legal label corresponding to a BB,
456     // e.g. not the end of function (unreachable) label.
457     if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
458       LLVM_DEBUG({
459         dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
460                << "label " << JT->Entries[Index + Range.first]->getName()
461                << " in jump table:\n";
462         JT->print(dbgs());
463         dbgs() << "HotTargetMap:\n";
464         for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
465              HotTargetMap)
466           dbgs() << "BOLT-INFO: " << HT.first->getName()
467                  << " = (count=" << HT.second.first
468                  << ", index=" << HT.second.second << ")\n";
469       });
470       return JumpTableInfoType();
471     }
472 
473     std::pair<uint64_t, uint64_t> &HotTarget =
474         HotTargetMap[JT->Entries[Index + Range.first]];
475     HotTarget.first += AccessInfo.Count;
476     HotTarget.second = Index;
477   }
478 
479   std::transform(
480       HotTargetMap.begin(), HotTargetMap.end(), std::back_inserter(HotTargets),
481       [](const std::pair<MCSymbol *, std::pair<uint64_t, uint64_t>> &A) {
482         return A.second;
483       });
484 
485   // Sort with highest counts first.
486   std::sort(HotTargets.rbegin(), HotTargets.rend());
487 
488   LLVM_DEBUG({
489     dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
490     for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
491       dbgs() << "BOLT-INFO:  Idx = " << Target.second << ", "
492              << "Count = " << Target.first << "\n";
493   });
494 
495   BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;
496 
497   TargetFetchInst = MemLocInstr;
498 
499   return HotTargets;
500 }
501 
502 IndirectCallPromotion::SymTargetsType
503 IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
504                                              size_t &N, BinaryBasicBlock &BB,
505                                              MCInst &CallInst,
506                                              MCInst *&TargetFetchInst) const {
507   const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
508   SymTargetsType SymTargets;
509 
510   if (!JT) {
511     for (size_t I = 0; I < N; ++I) {
512       assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
513       assert(Targets[I].JTIndices.empty() &&
514              "Can't have jump table indices for non-jump tables");
515       SymTargets.emplace_back(Targets[I].To.Sym, 0);
516     }
517     return SymTargets;
518   }
519 
520   // Use memory profile to select hot targets.
521   JumpTableInfoType HotTargets =
522       maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);
523 
524   auto findTargetsIndex = [&](uint64_t JTIndex) {
525     for (size_t I = 0; I < Targets.size(); ++I)
526       if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
527         return I;
528     LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
529                       << " table entry in " << *BB.getFunction() << "\n");
530     llvm_unreachable("Hot indices must be referred to by at least one "
531                      "callsite");
532   };
533 
534   if (!HotTargets.empty()) {
535     if (opts::Verbosity >= 1)
536       for (size_t I = 0; I < HotTargets.size(); ++I)
537         outs() << "BOLT-INFO: HotTarget[" << I << "] = (" << HotTargets[I].first
538                << ", " << HotTargets[I].second << ")\n";
539 
540     // Recompute hottest targets, now discriminating which index is hot
541     // NOTE: This is a tradeoff. On one hand, we get index information. On the
542     // other hand, info coming from the memory profile is much less accurate
543     // than LBRs. So we may actually end up working with more coarse
544     // profile granularity in exchange for information about indices.
545     std::vector<Callsite> NewTargets;
546     std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
547     uint64_t TotalMemAccesses = 0;
548     for (size_t I = 0; I < HotTargets.size(); ++I) {
549       const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
550       ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
551       TotalMemAccesses += HotTargets[I].first;
552     }
553     uint64_t RemainingMemAccesses = TotalMemAccesses;
554     const size_t TopN =
555         opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
556     size_t I = 0;
557     for (; I < HotTargets.size(); ++I) {
558       const uint64_t MemAccesses = HotTargets[I].first;
559       if (100 * MemAccesses <
560           TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
561         break;
562       if (100 * MemAccesses <
563           RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
564         break;
565       if (TopN && I >= TopN)
566         break;
567       RemainingMemAccesses -= MemAccesses;
568 
569       const uint64_t JTIndex = HotTargets[I].second;
570       Callsite &Target = Targets[findTargetsIndex(JTIndex)];
571 
572       NewTargets.push_back(Target);
573       std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
574       Target.JTIndices.erase(std::remove(Target.JTIndices.begin(),
575                                          Target.JTIndices.end(), JTIndex),
576                              Target.JTIndices.end());
577 
578       // Keep fixCFG counts sane if more indices use this same target later
579       assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
580       NewTargets.back().Branches =
581           Target.Branches / IndicesPerTarget[Target.To.Sym];
582       NewTargets.back().Mispreds =
583           Target.Mispreds / IndicesPerTarget[Target.To.Sym];
584       assert(Target.Branches >= NewTargets.back().Branches);
585       assert(Target.Mispreds >= NewTargets.back().Mispreds);
586       Target.Branches -= NewTargets.back().Branches;
587       Target.Mispreds -= NewTargets.back().Mispreds;
588     }
589     std::copy(Targets.begin(), Targets.end(), std::back_inserter(NewTargets));
590     std::swap(NewTargets, Targets);
591     N = I;
592 
593     if (N == 0 && opts::Verbosity >= 1) {
594       outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
595              << BB.getName() << ": failed to meet thresholds after memory "
596              << "profile data was loaded.\n";
597       return SymTargets;
598     }
599   }
600 
601   for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
602     Callsite &Target = Targets[TgtIdx];
603     assert(Target.To.Sym && "All ICP targets must be to known symbols");
604     assert(!Target.JTIndices.empty() && "Jump tables must have indices");
605     for (uint64_t Idx : Target.JTIndices) {
606       SymTargets.emplace_back(Target.To.Sym, Idx);
607       ++I;
608     }
609   }
610 
611   return SymTargets;
612 }
613 
614 IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
615     BinaryBasicBlock &BB, MCInst &Inst,
616     const SymTargetsType &SymTargets) const {
617   BinaryFunction &Function = *BB.getFunction();
618   BinaryContext &BC = Function.getBinaryContext();
619   std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
620   std::vector<MCInst *> MethodFetchInsns;
621   unsigned VtableReg, MethodReg;
622   uint64_t MethodOffset;
623 
624   assert(!Function.getJumpTable(Inst) &&
625          "Can't get vtable addrs for jump tables.");
626 
627   if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
628     return MethodInfoType();
629 
630   MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
631   if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
632                                         MethodFetchInsns, VtableReg, MethodReg,
633                                         MethodOffset)) {
634     DEBUG_VERBOSE(
635         1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
636                   << Function << " at @ " << (&Inst - &BB.front()) << "\n");
637     return MethodInfoType();
638   }
639 
640   ++TotalMethodLoadEliminationCandidates;
641 
642   DEBUG_VERBOSE(1, {
643     dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
644            << " at @ " << (&Inst - &BB.front()) << "\n";
645     dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
646     for (MCInst *Inst : MethodFetchInsns)
647       BC.printInstruction(dbgs(), *Inst, 0, &Function);
648 
649     if (MethodFetchInsns.back() != &Inst)
650       BC.printInstruction(dbgs(), Inst, 0, &Function);
651   });
652 
653   // Try to get value profiling data for the method load instruction.
654   auto ErrorOrMemAccesssProfile =
655       BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
656                                                       "MemoryAccessProfile");
657   if (!ErrorOrMemAccesssProfile) {
658     DEBUG_VERBOSE(1, dbgs()
659                          << "BOLT-INFO: ICP no memory profiling data found\n");
660     return MethodInfoType();
661   }
662   MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();
663 
664   // Find the vtable that each method belongs to.
665   std::map<const MCSymbol *, uint64_t> MethodToVtable;
666 
667   for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
668     uint64_t Address = AccessInfo.Offset;
669     if (AccessInfo.MemoryObject)
670       Address += AccessInfo.MemoryObject->getAddress();
671 
672     // Ignore bogus data.
673     if (!Address)
674       continue;
675 
676     const uint64_t VtableBase = Address - MethodOffset;
677 
678     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
679                             << Twine::utohexstr(VtableBase) << "+"
680                             << MethodOffset << "/" << AccessInfo.Count << "\n");
681 
682     if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
683       BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
684       if (!MethodBD) // skip unknown methods
685         continue;
686       MCSymbol *MethodSym = MethodBD->getSymbol();
687       MethodToVtable[MethodSym] = VtableBase;
688       DEBUG_VERBOSE(1, {
689         const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
690         dbgs() << "BOLT-INFO: ICP found method = "
691                << Twine::utohexstr(MethodAddr.get()) << "/"
692                << (Method ? Method->getPrintName() : "") << "\n";
693       });
694     }
695   }
696 
697   // Find the vtable for each target symbol.
698   for (size_t I = 0; I < SymTargets.size(); ++I) {
699     auto Itr = MethodToVtable.find(SymTargets[I].first);
700     if (Itr != MethodToVtable.end()) {
701       if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
702         const uint64_t Addend = Itr->second - BD->getAddress();
703         VtableSyms.emplace_back(BD->getSymbol(), Addend);
704         continue;
705       }
706     }
707     // Give up if we can't find the vtable for a method.
708     DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
709                             << SymTargets[I].first->getName() << "\n");
710     return MethodInfoType();
711   }
712 
713   // Make sure the vtable reg is not clobbered by the argument passing code
714   if (VtableReg != MethodReg) {
715     for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
716          ++CurInst) {
717       const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
718       if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
719         return MethodInfoType();
720     }
721   }
722 
723   return MethodInfoType(VtableSyms, MethodFetchInsns);
724 }
725 
726 std::vector<std::unique_ptr<BinaryBasicBlock>>
727 IndirectCallPromotion::rewriteCall(
728     BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
729     MCPlusBuilder::BlocksVectorTy &&ICPcode,
730     const std::vector<MCInst *> &MethodFetchInsns) const {
731   BinaryFunction &Function = *IndCallBlock.getFunction();
732   MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();
733 
734   // Create new basic blocks with correct code in each one first.
735   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
736   const bool IsTailCallOrJT =
737       (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));
738 
739   // Move instructions from the tail of the original call block
740   // to the merge block.
741 
742   // Remember any pseudo instructions following a tail call.  These
743   // must be preserved and moved to the original block.
744   InstructionListType TailInsts;
745   const MCInst *TailInst = &CallInst;
746   if (IsTailCallOrJT)
747     while (TailInst + 1 < &(*IndCallBlock.end()) &&
748            MIB->isPseudo(*(TailInst + 1)))
749       TailInsts.push_back(*++TailInst);
750 
751   InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
752   // Link new BBs to the original input offset of the BB where the indirect
753   // call site is, so we can map samples recorded in new BBs back to the
754   // original BB seen in the input binary (if using BAT)
755   const uint32_t OrigOffset = IndCallBlock.getInputOffset();
756 
757   IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
758                                  MethodFetchInsns.end());
759   if (IndCallBlock.empty() ||
760       (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
761     IndCallBlock.addInstructions(ICPcode.front().second.begin(),
762                                  ICPcode.front().second.end());
763   else
764     IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
765                                     ICPcode.front().second);
766   IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());
767 
768   for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
769     MCSymbol *&Sym = Itr->first;
770     InstructionListType &Insts = Itr->second;
771     assert(Sym);
772     std::unique_ptr<BinaryBasicBlock> TBB =
773         Function.createBasicBlock(OrigOffset, Sym);
774     for (MCInst &Inst : Insts) // sanitize new instructions.
775       if (MIB->isCall(Inst))
776         MIB->removeAnnotation(Inst, "CallProfile");
777     TBB->addInstructions(Insts.begin(), Insts.end());
778     NewBBs.emplace_back(std::move(TBB));
779   }
780 
781   // Move tail of instructions from after the original call to
782   // the merge block.
783   if (!IsTailCallOrJT)
784     NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());
785 
786   return NewBBs;
787 }
788 
/// Connect the blocks created by rewriteCall() to the CFG and distribute
/// profile counts over the new edges.
///
/// \p NewBBs are the blocks returned by rewriteCall() for \p IndCallBlock.
/// \p Targets supplies the per-target branch and misprediction counts, which
/// are rescaled to \p IndCallBlock's known execution count. Ownership of the
/// blocks is handed to the function through insertBasicBlocks().
///
/// \returns the merge block that takes over \p IndCallBlock's original
/// successors when promoting a non-tail call, or nullptr for tail calls and
/// jump tables.
BinaryBasicBlock *
IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
                              const bool IsTailCall, const bool IsJumpTable,
                              IndirectCallPromotion::BasicBlocksVector &&NewBBs,
                              const std::vector<Callsite> &Targets) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
  BinaryBasicBlock *MergeBlock = nullptr;

  // Scale indirect call counts to the execution count of the original
  // basic block containing the indirect call.
  uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
  uint64_t TotalIndirectBranches = 0;
  for (const Callsite &Target : Targets)
    TotalIndirectBranches += Target.Branches;
  // Avoid dividing by zero when the targets carry no branch counts.
  if (TotalIndirectBranches == 0)
    TotalIndirectBranches = 1;
  // One entry per promoted edge: a target spanning several jump-table
  // indices contributes one entry per index.
  //  - BBI holds the raw profile counts, split per index with ceiling
  //    division.
  //  - ScaledBBI holds the counts rescaled to TotalCount.
  BinaryBasicBlock::BranchInfoType BBI;
  BinaryBasicBlock::BranchInfoType ScaledBBI;
  for (const Callsite &Target : Targets) {
    const size_t NumEntries =
        std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
    for (size_t I = 0; I < NumEntries; ++I) {
      BBI.push_back(
          BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
                           (Target.Mispreds + NumEntries - 1) / NumEntries});
      ScaledBBI.push_back(
          BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
                                    (NumEntries * TotalIndirectBranches)),
                           uint64_t(TotalCount * Target.Mispreds /
                                    (NumEntries * TotalIndirectBranches))});
    }
  }

  if (IsJumpTable) {
    // The last new block takes over all of IndCallBlock's successors. The
    // remaining indirect jump's branch info is updated on it below.
    BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
    IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);

    // Target symbols in the same per-index order as BBI/ScaledBBI.
    std::vector<MCSymbol *> SymTargets;
    for (const Callsite &Target : Targets) {
      const size_t NumEntries =
          std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
      for (size_t I = 0; I < NumEntries; ++I)
        SymTargets.push_back(Target.To.Sym);
    }
    assert(SymTargets.size() > NewBBs.size() - 1 &&
           "There must be a target symbol associated with each new BB.");

    // Build a chain: IndCallBlock, NewBBs[0], NewBBs[1], ... Each link gets
    // a taken edge to its promoted target and a fall-through edge to the
    // next new block. TotalCount is the count still flowing down the chain.
    for (uint64_t I = 0; I < NewBBs.size(); ++I) {
      BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
      SourceBB->setExecutionCount(TotalCount);

      BinaryBasicBlock *TargetBB =
          Function.getBasicBlockForLabel(SymTargets[I]);
      SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken

      TotalCount -= ScaledBBI[I].Count;
      SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through

      // Update branch info for the indirect jump.
      // Subtract the promoted (unscaled) counts from the remaining jump's
      // edge to the same target, saturating at zero.
      BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
          NewIndCallBlock->getBranchInfo(*TargetBB);
      if (BranchInfo.Count > BBI[I].Count)
        BranchInfo.Count -= BBI[I].Count;
      else
        BranchInfo.Count = 0;

      if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
        BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
      else
        BranchInfo.MispredictedCount = 0;
    }
  } else {
    // Layout assumed by the indexing below:
    //  - IndCallBlock branches to NewBBs[0] (taken) or falls through to
    //    NewBBs[1].
    //  - Each even-indexed new block gets a single edge to MergeBlock (if
    //    any).
    //  - Each odd-indexed block branches to NewBBs[I + 1] or falls through
    //    to NewBBs[I + 2].
    //  - For non-tail calls, the last block is MergeBlock.
    assert(NewBBs.size() >= 2);
    assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
    assert(NewBBs.size() % 2 == 1 || IsTailCall);

    // Consume the next scaled count: subtract it from TotalCount and advance.
    auto ScaledBI = ScaledBBI.begin();
    auto updateCurrentBranchInfo = [&] {
      assert(ScaledBI != ScaledBBI.end());
      TotalCount -= ScaledBI->Count;
      ++ScaledBI;
    };

    if (!IsTailCall) {
      MergeBlock = NewBBs.back().get();
      IndCallBlock.moveAllSuccessorsTo(MergeBlock);
    }

    // Fix up successors and execution counts.
    updateCurrentBranchInfo();
    IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
    IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);

    // Skip the trailing MergeBlock (non-tail) or the final block; these are
    // handled after the loop.
    const size_t Adj = IsTailCall ? 1 : 2;
    for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
      assert(TotalCount <= IndCallBlock.getExecutionCount() ||
             TotalCount <= uint64_t(TotalIndirectBranches));
      // Even blocks carry their target's scaled count. Odd blocks also
      // carry the count that falls through to the next comparison.
      uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
      if (I % 2 == 0) {
        if (MergeBlock)
          NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
      } else {
        assert(I + 2 < NewBBs.size());
        updateCurrentBranchInfo();
        NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
        NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
        ExecCount += TotalCount;
      }
      NewBBs[I]->setExecutionCount(ExecCount);
    }

    if (MergeBlock) {
      // Arrange for the MergeBlock to be the fallthrough for the first
      // promoted call block.
      std::unique_ptr<BinaryBasicBlock> MBPtr;
      std::swap(MBPtr, NewBBs.back());
      NewBBs.pop_back();
      NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
      // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
      NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
    }
  }

  // Update the execution count.
  // The last block receives whatever count was not attributed to a promoted
  // target.
  NewBBs.back()->setExecutionCount(TotalCount);

  // Update BB and BB layout.
  Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
  assert(Function.validateCFG());

  return MergeBlock;
}
922 
923 size_t IndirectCallPromotion::canPromoteCallsite(
924     const BinaryBasicBlock &BB, const MCInst &Inst,
925     const std::vector<Callsite> &Targets, uint64_t NumCalls) {
926   BinaryFunction *BF = BB.getFunction();
927   const BinaryContext &BC = BF->getBinaryContext();
928 
929   if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
930     return 0;
931 
932   const bool IsJumpTable = BF->getJumpTable(Inst);
933 
934   auto computeStats = [&](size_t N) {
935     for (size_t I = 0; I < N; ++I)
936       if (IsJumpTable)
937         TotalNumFrequentJmps += Targets[I].Branches;
938       else
939         TotalNumFrequentCalls += Targets[I].Branches;
940   };
941 
942   // If we have no targets (or no calls), skip this callsite.
943   if (Targets.empty() || !NumCalls) {
944     if (opts::Verbosity >= 1) {
945       const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
946       outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx << " in "
947              << BB.getName() << ", calls = " << NumCalls
948              << ", targets empty or NumCalls == 0.\n";
949     }
950     return 0;
951   }
952 
953   size_t TopN = opts::ICPTopN;
954   if (IsJumpTable)
955     TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
956   else
957     TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;
958 
959   const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();
960 
961   if (opts::ICPTopCallsites > 0) {
962     if (!BC.MIB->hasAnnotation(Inst, "DoICP"))
963       return 0;
964   }
965 
966   // Pick the top N targets.
967   uint64_t TotalMispredictsTopN = 0;
968   size_t N = 0;
969 
970   if (opts::ICPUseMispredicts &&
971       (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
972     // Count total number of mispredictions for (at most) the top N targets.
973     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
974     // is exceeded by fewer targets.
975     double Threshold = double(opts::ICPMispredictThreshold);
976     for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
977       Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
978       TotalMispredictsTopN += Targets[I].Mispreds;
979     }
980     computeStats(N);
981 
982     // Compute the misprediction frequency of the top N call targets.  If this
983     // frequency is greater than the threshold, we should try ICP on this
984     // callsite.
985     const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
986     if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
987       if (opts::Verbosity >= 1) {
988         const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
989         outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
990                << " in " << BB.getName() << ", calls = " << NumCalls
991                << ", top N mis. frequency " << format("%.1f", TopNFrequency)
992                << "% < " << opts::ICPMispredictThreshold << "%\n";
993       }
994       return 0;
995     }
996   } else {
997     size_t MaxTargets = 0;
998 
999     // Count total number of calls for (at most) the top N targets.
1000     // We may choose a smaller N (TrialN vs. N) if the frequency threshold
1001     // is exceeded by fewer targets.
1002     const unsigned TotalThreshold = IsJumpTable
1003                                         ? opts::ICPJTTotalPercentThreshold
1004                                         : opts::ICPCallsTotalPercentThreshold;
1005     const unsigned RemainingThreshold =
1006         IsJumpTable ? opts::ICPJTRemainingPercentThreshold
1007                     : opts::ICPCallsRemainingPercentThreshold;
1008     uint64_t NumRemainingCalls = NumCalls;
1009     for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
1010       if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
1011         break;
1012       if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
1013         break;
1014       if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
1015           TrialN)
1016         break;
1017       TotalMispredictsTopN += Targets[I].Mispreds;
1018       NumRemainingCalls -= Targets[I].Branches;
1019       N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
1020     }
1021     computeStats(MaxTargets);
1022 
1023     // Don't check misprediction frequency for jump tables -- we don't really
1024     // care as long as we are saving loads from the jump table.
1025     if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
1026       // Compute the misprediction frequency of the top N call targets.  If
1027       // this frequency is less than the threshold, we should skip ICP at
1028       // this callsite.
1029       const double TopNMispredictFrequency =
1030           (100.0 * TotalMispredictsTopN) / NumCalls;
1031 
1032       if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
1033         if (opts::Verbosity >= 1) {
1034           const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1035           outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1036                  << " in " << BB.getName() << ", calls = " << NumCalls
1037                  << ", top N mispredict frequency "
1038                  << format("%.1f", TopNMispredictFrequency) << "% < "
1039                  << opts::ICPMispredictThreshold << "%\n";
1040         }
1041         return 0;
1042       }
1043     }
1044   }
1045 
1046   // Filter by inline-ability of target functions, stop at first target that
1047   // can't be inlined.
1048   if (opts::ICPPeelForInline) {
1049     for (size_t I = 0; I < N; ++I) {
1050       const MCSymbol *TargetSym = Targets[I].To.Sym;
1051       const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
1052       if (!BinaryFunctionPass::shouldOptimize(*TargetBF) ||
1053           getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
1054         N = I;
1055         break;
1056       }
1057     }
1058   }
1059 
1060   // Filter functions that can have ICP applied (for debugging)
1061   if (!opts::ICPFuncsList.empty()) {
1062     for (std::string &Name : opts::ICPFuncsList)
1063       if (BF->hasName(Name))
1064         return N;
1065     return 0;
1066   }
1067 
1068   return N;
1069 }
1070 
1071 void IndirectCallPromotion::printCallsiteInfo(
1072     const BinaryBasicBlock &BB, const MCInst &Inst,
1073     const std::vector<Callsite> &Targets, const size_t N,
1074     uint64_t NumCalls) const {
1075   BinaryContext &BC = BB.getFunction()->getBinaryContext();
1076   const bool IsTailCall = BC.MIB->isTailCall(Inst);
1077   const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
1078   const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1079 
1080   outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
1081          << " @ " << InstIdx << " in " << BB.getName()
1082          << " -> calls = " << NumCalls
1083          << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
1084          << "\n";
1085   for (size_t I = 0; I < N; I++) {
1086     const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
1087     const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
1088     outs() << "BOLT-INFO:   ";
1089     if (Targets[I].To.Sym)
1090       outs() << Targets[I].To.Sym->getName();
1091     else
1092       outs() << Targets[I].To.Addr;
1093     outs() << ", calls = " << Targets[I].Branches
1094            << ", mispreds = " << Targets[I].Mispreds
1095            << ", taken freq = " << format("%.1f", Frequency) << "%"
1096            << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
1097     bool First = true;
1098     for (uint64_t JTIndex : Targets[I].JTIndices) {
1099       outs() << (First ? ", indices = " : ", ") << JTIndex;
1100       First = false;
1101     }
1102     outs() << "\n";
1103   }
1104 
1105   LLVM_DEBUG({
1106     dbgs() << "BOLT-INFO: ICP original call instruction:";
1107     BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
1108   });
1109 }
1110 
1111 void IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
1112   if (opts::ICP == ICP_NONE)
1113     return;
1114 
1115   auto &BFs = BC.getBinaryFunctions();
1116 
1117   const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
1118   const bool OptimizeJumpTables =
1119       (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);
1120 
1121   std::unique_ptr<RegAnalysis> RA;
1122   std::unique_ptr<BinaryFunctionCallGraph> CG;
1123   if (OptimizeJumpTables) {
1124     CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
1125     RA.reset(new RegAnalysis(BC, &BFs, &*CG));
1126   }
1127 
1128   // If icp-top-callsites is enabled, compute the total number of indirect
1129   // calls and then optimize the hottest callsites that contribute to that
1130   // total.
1131   SetVector<BinaryFunction *> Functions;
1132   if (opts::ICPTopCallsites == 0) {
1133     for (auto &KV : BFs)
1134       Functions.insert(&KV.second);
1135   } else {
1136     using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
1137     std::vector<IndirectCallsite> IndirectCalls;
1138     size_t TotalIndirectCalls = 0;
1139 
1140     // Find all the indirect callsites.
1141     for (auto &BFIt : BFs) {
1142       BinaryFunction &Function = BFIt.second;
1143 
1144       if (!Function.isSimple() || Function.isIgnored() ||
1145           !Function.hasProfile())
1146         continue;
1147 
1148       const bool HasLayout = !Function.layout_empty();
1149 
1150       for (BinaryBasicBlock &BB : Function) {
1151         if (HasLayout && Function.isSplit() && BB.isCold())
1152           continue;
1153 
1154         for (MCInst &Inst : BB) {
1155           const bool IsJumpTable = Function.getJumpTable(Inst);
1156           const bool HasIndirectCallProfile =
1157               BC.MIB->hasAnnotation(Inst, "CallProfile");
1158           const bool IsDirectCall =
1159               (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));
1160 
1161           if (!IsDirectCall &&
1162               ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
1163                (IsJumpTable && OptimizeJumpTables))) {
1164             uint64_t NumCalls = 0;
1165             for (const Callsite &BInfo : getCallTargets(BB, Inst))
1166               NumCalls += BInfo.Branches;
1167             IndirectCalls.push_back(
1168                 std::make_tuple(NumCalls, &Inst, &Function));
1169             TotalIndirectCalls += NumCalls;
1170           }
1171         }
1172       }
1173     }
1174 
1175     // Sort callsites by execution count.
1176     std::sort(IndirectCalls.rbegin(), IndirectCalls.rend());
1177 
1178     // Find callsites that contribute to the top "opts::ICPTopCallsites"%
1179     // number of calls.
1180     const float TopPerc = opts::ICPTopCallsites / 100.0f;
1181     int64_t MaxCalls = TotalIndirectCalls * TopPerc;
1182     uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
1183     size_t Num = 0;
1184     for (const IndirectCallsite &IC : IndirectCalls) {
1185       const uint64_t CurFreq = std::get<0>(IC);
1186       // Once we decide to stop, include at least all branches that share the
1187       // same frequency of the last one to avoid non-deterministic behavior
1188       // (e.g. turning on/off ICP depending on the order of functions)
1189       if (MaxCalls <= 0 && CurFreq != LastFreq)
1190         break;
1191       MaxCalls -= CurFreq;
1192       LastFreq = CurFreq;
1193       BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
1194       Functions.insert(std::get<2>(IC));
1195       ++Num;
1196     }
1197     outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
1198            << ", " << Num << " callsites cover " << opts::ICPTopCallsites
1199            << "% of all indirect calls\n";
1200   }
1201 
1202   for (BinaryFunction *FuncPtr : Functions) {
1203     BinaryFunction &Function = *FuncPtr;
1204 
1205     if (!Function.isSimple() || Function.isIgnored() || !Function.hasProfile())
1206       continue;
1207 
1208     const bool HasLayout = !Function.layout_empty();
1209 
1210     // Total number of indirect calls issued from the current Function.
1211     // (a fraction of TotalIndirectCalls)
1212     uint64_t FuncTotalIndirectCalls = 0;
1213     uint64_t FuncTotalIndirectJmps = 0;
1214 
1215     std::vector<BinaryBasicBlock *> BBs;
1216     for (BinaryBasicBlock &BB : Function) {
1217       // Skip indirect calls in cold blocks.
1218       if (!HasLayout || !Function.isSplit() || !BB.isCold())
1219         BBs.push_back(&BB);
1220     }
1221     if (BBs.empty())
1222       continue;
1223 
1224     DataflowInfoManager Info(Function, RA.get(), nullptr);
1225     while (!BBs.empty()) {
1226       BinaryBasicBlock *BB = BBs.back();
1227       BBs.pop_back();
1228 
1229       for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
1230         MCInst &Inst = BB->getInstructionAtIndex(Idx);
1231         const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
1232         const bool IsTailCall = BC.MIB->isTailCall(Inst);
1233         const bool HasIndirectCallProfile =
1234             BC.MIB->hasAnnotation(Inst, "CallProfile");
1235         const bool IsJumpTable = Function.getJumpTable(Inst);
1236 
1237         if (BC.MIB->isCall(Inst))
1238           TotalCalls += BB->getKnownExecutionCount();
1239 
1240         if (IsJumpTable && !OptimizeJumpTables)
1241           continue;
1242 
1243         if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
1244           continue;
1245 
1246         // Ignore direct calls.
1247         if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
1248           continue;
1249 
1250         assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
1251                "expected a call or an indirect jump instruction");
1252 
1253         if (IsJumpTable)
1254           ++TotalJumpTableCallsites;
1255         else
1256           ++TotalIndirectCallsites;
1257 
1258         std::vector<Callsite> Targets = getCallTargets(*BB, Inst);
1259 
1260         // Compute the total number of calls from this particular callsite.
1261         uint64_t NumCalls = 0;
1262         for (const Callsite &BInfo : Targets)
1263           NumCalls += BInfo.Branches;
1264         if (!IsJumpTable)
1265           FuncTotalIndirectCalls += NumCalls;
1266         else
1267           FuncTotalIndirectJmps += NumCalls;
1268 
1269         // If FLAGS regs is alive after this jmp site, do not try
1270         // promoting because we will clobber FLAGS.
1271         if (IsJumpTable) {
1272           ErrorOr<const BitVector &> State =
1273               Info.getLivenessAnalysis().getStateBefore(Inst);
1274           if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
1275             if (opts::Verbosity >= 1)
1276               outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1277                      << InstIdx << " in " << BB->getName()
1278                      << ", calls = " << NumCalls
1279                      << (State ? ", cannot clobber flags reg.\n"
1280                                : ", no liveness data available.\n");
1281             continue;
1282           }
1283         }
1284 
1285         // Should this callsite be optimized?  Return the number of targets
1286         // to use when promoting this call.  A value of zero means to skip
1287         // this callsite.
1288         size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);
1289 
1290         // If it is a jump table and it failed to meet our initial threshold,
1291         // proceed to findCallTargetSymbols -- it may reevaluate N if
1292         // memory profile is present
1293         if (!N && !IsJumpTable)
1294           continue;
1295 
1296         if (opts::Verbosity >= 1)
1297           printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);
1298 
1299         // Find MCSymbols or absolute addresses for each call target.
1300         MCInst *TargetFetchInst = nullptr;
1301         const SymTargetsType SymTargets =
1302             findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);
1303 
1304         // findCallTargetSymbols may have changed N if mem profile is available
1305         // for jump tables
1306         if (!N)
1307           continue;
1308 
1309         LLVM_DEBUG(printDecision(dbgs(), Targets, N));
1310 
1311         // If we can't resolve any of the target symbols, punt on this callsite.
1312         // TODO: can this ever happen?
1313         if (SymTargets.size() < N) {
1314           const size_t LastTarget = SymTargets.size();
1315           if (opts::Verbosity >= 1)
1316             outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1317                    << InstIdx << " in " << BB->getName()
1318                    << ", calls = " << NumCalls
1319                    << ", ICP failed to find target symbol for "
1320                    << Targets[LastTarget].To.Sym->getName() << "\n";
1321           continue;
1322         }
1323 
1324         MethodInfoType MethodInfo;
1325 
1326         if (!IsJumpTable) {
1327           MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
1328           TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
1329           LLVM_DEBUG(dbgs()
1330                      << "BOLT-INFO: ICP "
1331                      << (!MethodInfo.first.empty() ? "found" : "did not find")
1332                      << " vtables for all methods.\n");
1333         } else if (TargetFetchInst) {
1334           ++TotalIndexBasedJumps;
1335           MethodInfo.second.push_back(TargetFetchInst);
1336         }
1337 
1338         // Generate new promoted call code for this callsite.
1339         MCPlusBuilder::BlocksVectorTy ICPcode =
1340             (IsJumpTable && !opts::ICPJumpTablesByTarget)
1341                 ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
1342                                              MethodInfo.second, BC.Ctx.get())
1343                 : BC.MIB->indirectCallPromotion(
1344                       Inst, SymTargets, MethodInfo.first, MethodInfo.second,
1345                       opts::ICPOldCodeSequence, BC.Ctx.get());
1346 
1347         if (ICPcode.empty()) {
1348           if (opts::Verbosity >= 1)
1349             outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1350                    << InstIdx << " in " << BB->getName()
1351                    << ", calls = " << NumCalls
1352                    << ", unable to generate promoted call code.\n";
1353           continue;
1354         }
1355 
1356         LLVM_DEBUG({
1357           uint64_t Offset = Targets[0].From.Addr;
1358           dbgs() << "BOLT-INFO: ICP indirect call code:\n";
1359           for (const auto &entry : ICPcode) {
1360             const MCSymbol *const &Sym = entry.first;
1361             const InstructionListType &Insts = entry.second;
1362             if (Sym)
1363               dbgs() << Sym->getName() << ":\n";
1364             Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
1365                                           Offset);
1366           }
1367           dbgs() << "---------------------------------------------------\n";
1368         });
1369 
1370         // Rewrite the CFG with the newly generated ICP code.
1371         std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
1372             rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);
1373 
1374         // Fix the CFG after inserting the new basic blocks.
1375         BinaryBasicBlock *MergeBlock =
1376             fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);
1377 
1378         // Since the tail of the original block was split off and it may contain
1379         // additional indirect calls, we must add the merge block to the set of
1380         // blocks to process.
1381         if (MergeBlock)
1382           BBs.push_back(MergeBlock);
1383 
1384         if (opts::Verbosity >= 1)
1385           outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
1386                  << InstIdx << " in " << BB->getName()
1387                  << " -> calls = " << NumCalls << "\n";
1388 
1389         if (IsJumpTable)
1390           ++TotalOptimizedJumpTableCallsites;
1391         else
1392           ++TotalOptimizedIndirectCallsites;
1393 
1394         Modified.insert(&Function);
1395       }
1396     }
1397     TotalIndirectCalls += FuncTotalIndirectCalls;
1398     TotalIndirectJmps += FuncTotalIndirectJmps;
1399   }
1400 
1401   outs() << "BOLT-INFO: ICP total indirect callsites with profile = "
1402          << TotalIndirectCallsites << "\n"
1403          << "BOLT-INFO: ICP total jump table callsites = "
1404          << TotalJumpTableCallsites << "\n"
1405          << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
1406          << "BOLT-INFO: ICP percentage of calls that are indirect = "
1407          << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
1408          << "BOLT-INFO: ICP percentage of indirect calls that can be "
1409             "optimized = "
1410          << format("%.1f", (100.0 * TotalNumFrequentCalls) /
1411                                std::max<size_t>(TotalIndirectCalls, 1))
1412          << "%\n"
1413          << "BOLT-INFO: ICP percentage of indirect callsites that are "
1414             "optimized = "
1415          << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
1416                                std::max<uint64_t>(TotalIndirectCallsites, 1))
1417          << "%\n"
1418          << "BOLT-INFO: ICP number of method load elimination candidates = "
1419          << TotalMethodLoadEliminationCandidates << "\n"
1420          << "BOLT-INFO: ICP percentage of method calls candidates that have "
1421             "loads eliminated = "
1422          << format("%.1f", (100.0 * TotalMethodLoadsEliminated) /
1423                                std::max<uint64_t>(
1424                                    TotalMethodLoadEliminationCandidates, 1))
1425          << "%\n"
1426          << "BOLT-INFO: ICP percentage of indirect branches that are "
1427             "optimized = "
1428          << format("%.1f", (100.0 * TotalNumFrequentJmps) /
1429                                std::max<uint64_t>(TotalIndirectJmps, 1))
1430          << "%\n"
1431          << "BOLT-INFO: ICP percentage of jump table callsites that are "
1432          << "optimized = "
1433          << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
1434                                std::max<uint64_t>(TotalJumpTableCallsites, 1))
1435          << "%\n"
1436          << "BOLT-INFO: ICP number of jump table callsites that can use hot "
1437          << "indices = " << TotalIndexBasedCandidates << "\n"
1438          << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
1439             "indices = "
1440          << format("%.1f", (100.0 * TotalIndexBasedJumps) /
1441                                std::max<uint64_t>(TotalIndexBasedCandidates, 1))
1442          << "%\n";
1443 
1444   (void)verifyProfile;
1445 #ifndef NDEBUG
1446   verifyProfile(BFs);
1447 #endif
1448 }
1449 
1450 } // namespace bolt
1451 } // namespace llvm
1452