1 //===- bolt/Passes/IndirectCallPromotion.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the IndirectCallPromotion class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "bolt/Passes/IndirectCallPromotion.h"
14 #include "bolt/Passes/BinaryFunctionCallGraph.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "bolt/Passes/Inliner.h"
17 #include "llvm/Support/CommandLine.h"
18
19 #define DEBUG_TYPE "ICP"
20 #define DEBUG_VERBOSE(Level, X) \
21 if (opts::Verbosity >= (Level)) { \
22 X; \
23 }
24
25 using namespace llvm;
26 using namespace bolt;
27
namespace opts {

extern cl::OptionCategory BoltOptCategory;

extern cl::opt<IndirectCallPromotionType> ICP;
extern cl::opt<unsigned> Verbosity;
extern cl::opt<unsigned> ExecutionCountThreshold;

// Hotness thresholds for jump-table promotion: an entry is promoted only if
// its count clears both the remaining-unpromoted and the total-count
// percentage bars.
static cl::opt<unsigned> ICPJTRemainingPercentThreshold(
    "icp-jt-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for jump tables"),
    cl::init(30), cl::ZeroOrMore, cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPJTTotalPercentThreshold(
    "icp-jt-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "jump tables"),
    cl::init(5), cl::Hidden, cl::cat(BoltOptCategory));

// Analogous hotness thresholds for promoting indirect calls (non-jump-table).
static cl::opt<unsigned> ICPCallsRemainingPercentThreshold(
    "icp-calls-remaining-percent-threshold",
    cl::desc("The percentage threshold against remaining unpromoted indirect "
             "call count for the promotion for calls"),
    cl::init(50), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPCallsTotalPercentThreshold(
    "icp-calls-total-percent-threshold",
    cl::desc(
        "The percentage threshold against total count for the promotion for "
        "calls"),
    cl::init(30), cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPMispredictThreshold(
    "indirect-call-promotion-mispredict-threshold",
    cl::desc("misprediction threshold for skipping ICP on an "
             "indirect call"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPMispredictThresholdAlias(
    "icp-mp-threshold",
    cl::desc("alias for --indirect-call-promotion-mispredict-threshold"),
    cl::aliasopt(ICPMispredictThreshold));

static cl::opt<bool> ICPUseMispredicts(
    "indirect-call-promotion-use-mispredicts",
    cl::desc("use misprediction frequency for determining whether or not ICP "
             "should be applied at a callsite. The "
             "-indirect-call-promotion-mispredict-threshold value will be used "
             "by this heuristic"),
    cl::cat(BoltOptCategory));

static cl::alias ICPUseMispredictsAlias(
    "icp-use-mp",
    cl::desc("alias for --indirect-call-promotion-use-mispredicts"),
    cl::aliasopt(ICPUseMispredicts));

// Caps on the number of promoted targets per call site. The generic -topn
// value is the fallback when the call- or jump-table-specific cap is unset.
static cl::opt<unsigned>
    ICPTopN("indirect-call-promotion-topn",
            cl::desc("limit number of targets to consider when doing indirect "
                     "call promotion. 0 = no limit"),
            cl::init(3), cl::cat(BoltOptCategory));

static cl::alias
    ICPTopNAlias("icp-topn",
                 cl::desc("alias for --indirect-call-promotion-topn"),
                 cl::aliasopt(ICPTopN));

static cl::opt<unsigned> ICPCallsTopN(
    "indirect-call-promotion-calls-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on calls. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPCallsTopNAlias(
    "icp-calls-topn",
    cl::desc("alias for --indirect-call-promotion-calls-topn"),
    cl::aliasopt(ICPCallsTopN));

static cl::opt<unsigned> ICPJumpTablesTopN(
    "indirect-call-promotion-jump-tables-topn",
    cl::desc("limit number of targets to consider when doing indirect "
             "call promotion on jump tables. 0 = no limit"),
    cl::init(0), cl::cat(BoltOptCategory));

static cl::alias ICPJumpTablesTopNAlias(
    "icp-jt-topn",
    cl::desc("alias for --indirect-call-promotion-jump-tables-topn"),
    cl::aliasopt(ICPJumpTablesTopN));

// When enabled, memory profiling data is used to skip the jump-table /
// vtable load for promoted targets (see maybeGetHotJumpTableTargets and
// maybeGetVtableSyms below).
static cl::opt<bool> EliminateLoads(
    "icp-eliminate-loads",
    cl::desc("enable load elimination using memory profiling data when "
             "performing ICP"),
    cl::init(true), cl::cat(BoltOptCategory));

static cl::opt<unsigned> ICPTopCallsites(
    "icp-top-callsites",
    cl::desc("optimize hottest calls until at least this percentage of all "
             "indirect calls frequency is covered. 0 = all callsites"),
    cl::init(99), cl::Hidden, cl::cat(BoltOptCategory));

static cl::list<std::string>
    ICPFuncsList("icp-funcs", cl::CommaSeparated,
                 cl::desc("list of functions to enable ICP for"),
                 cl::value_desc("func1,func2,func3,..."), cl::Hidden,
                 cl::cat(BoltOptCategory));

static cl::opt<bool>
    ICPOldCodeSequence("icp-old-code-sequence",
                       cl::desc("use old code sequence for promoted calls"),
                       cl::Hidden, cl::cat(BoltOptCategory));

static cl::opt<bool> ICPJumpTablesByTarget(
    "icp-jump-tables-targets",
    cl::desc(
        "for jump tables, optimize indirect jmp targets instead of indices"),
    cl::Hidden, cl::cat(BoltOptCategory));

static cl::alias
    ICPJumpTablesByTargetAlias("icp-jt-targets",
                               cl::desc("alias for --icp-jump-tables-targets"),
                               cl::aliasopt(ICPJumpTablesByTarget));

static cl::opt<bool> ICPPeelForInline(
    "icp-inline", cl::desc("only promote call targets eligible for inlining"),
    cl::Hidden, cl::cat(BoltOptCategory));

} // namespace opts
158
verifyProfile(std::map<uint64_t,BinaryFunction> & BFs)159 static bool verifyProfile(std::map<uint64_t, BinaryFunction> &BFs) {
160 bool IsValid = true;
161 for (auto &BFI : BFs) {
162 BinaryFunction &BF = BFI.second;
163 if (!BF.isSimple())
164 continue;
165 for (const BinaryBasicBlock &BB : BF) {
166 auto BI = BB.branch_info_begin();
167 for (BinaryBasicBlock *SuccBB : BB.successors()) {
168 if (BI->Count != BinaryBasicBlock::COUNT_NO_PROFILE && BI->Count > 0) {
169 if (BB.getKnownExecutionCount() == 0 ||
170 SuccBB->getKnownExecutionCount() == 0) {
171 errs() << "BOLT-WARNING: profile verification failed after ICP for "
172 "function "
173 << BF << '\n';
174 IsValid = false;
175 }
176 }
177 ++BI;
178 }
179 }
180 }
181 return IsValid;
182 }
183
184 namespace llvm {
185 namespace bolt {
186
Callsite(BinaryFunction & BF,const IndirectCallProfile & ICP)187 IndirectCallPromotion::Callsite::Callsite(BinaryFunction &BF,
188 const IndirectCallProfile &ICP)
189 : From(BF.getSymbol()), To(ICP.Offset), Mispreds(ICP.Mispreds),
190 Branches(ICP.Count) {
191 if (ICP.Symbol) {
192 To.Sym = ICP.Symbol;
193 To.Addr = 0;
194 }
195 }
196
printDecision(llvm::raw_ostream & OS,std::vector<IndirectCallPromotion::Callsite> & Targets,unsigned N) const197 void IndirectCallPromotion::printDecision(
198 llvm::raw_ostream &OS,
199 std::vector<IndirectCallPromotion::Callsite> &Targets, unsigned N) const {
200 uint64_t TotalCount = 0;
201 uint64_t TotalMispreds = 0;
202 for (const Callsite &S : Targets) {
203 TotalCount += S.Branches;
204 TotalMispreds += S.Mispreds;
205 }
206 if (!TotalCount)
207 TotalCount = 1;
208 if (!TotalMispreds)
209 TotalMispreds = 1;
210
211 OS << "BOLT-INFO: ICP decision for call site with " << Targets.size()
212 << " targets, Count = " << TotalCount << ", Mispreds = " << TotalMispreds
213 << "\n";
214
215 size_t I = 0;
216 for (const Callsite &S : Targets) {
217 OS << "Count = " << S.Branches << ", "
218 << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
219 << "Mispreds = " << S.Mispreds << ", "
220 << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds);
221 if (I < N)
222 OS << " * to be optimized *";
223 if (!S.JTIndices.empty()) {
224 OS << " Indices:";
225 for (const uint64_t Idx : S.JTIndices)
226 OS << " " << Idx;
227 }
228 OS << "\n";
229 I += S.JTIndices.empty() ? 1 : S.JTIndices.size();
230 }
231 }
232
233 // Get list of targets for a given call sorted by most frequently
234 // called first.
235 std::vector<IndirectCallPromotion::Callsite>
getCallTargets(BinaryBasicBlock & BB,const MCInst & Inst) const236 IndirectCallPromotion::getCallTargets(BinaryBasicBlock &BB,
237 const MCInst &Inst) const {
238 BinaryFunction &BF = *BB.getFunction();
239 const BinaryContext &BC = BF.getBinaryContext();
240 std::vector<Callsite> Targets;
241
242 if (const JumpTable *JT = BF.getJumpTable(Inst)) {
243 // Don't support PIC jump tables for now
244 if (!opts::ICPJumpTablesByTarget && JT->Type == JumpTable::JTT_PIC)
245 return Targets;
246 const Location From(BF.getSymbol());
247 const std::pair<size_t, size_t> Range =
248 JT->getEntriesForAddress(BC.MIB->getJumpTable(Inst));
249 assert(JT->Counts.empty() || JT->Counts.size() >= Range.second);
250 JumpTable::JumpInfo DefaultJI;
251 const JumpTable::JumpInfo *JI =
252 JT->Counts.empty() ? &DefaultJI : &JT->Counts[Range.first];
253 const size_t JIAdj = JT->Counts.empty() ? 0 : 1;
254 assert(JT->Type == JumpTable::JTT_PIC ||
255 JT->EntrySize == BC.AsmInfo->getCodePointerSize());
256 for (size_t I = Range.first; I < Range.second; ++I, JI += JIAdj) {
257 MCSymbol *Entry = JT->Entries[I];
258 assert(BF.getBasicBlockForLabel(Entry) ||
259 Entry == BF.getFunctionEndLabel() ||
260 Entry == BF.getFunctionColdEndLabel());
261 if (Entry == BF.getFunctionEndLabel() ||
262 Entry == BF.getFunctionColdEndLabel())
263 continue;
264 const Location To(Entry);
265 const BinaryBasicBlock::BinaryBranchInfo &BI = BB.getBranchInfo(Entry);
266 Targets.emplace_back(From, To, BI.MispredictedCount, BI.Count,
267 I - Range.first);
268 }
269
270 // Sort by symbol then addr.
271 llvm::sort(Targets, [](const Callsite &A, const Callsite &B) {
272 if (A.To.Sym && B.To.Sym)
273 return A.To.Sym < B.To.Sym;
274 else if (A.To.Sym && !B.To.Sym)
275 return true;
276 else if (!A.To.Sym && B.To.Sym)
277 return false;
278 else
279 return A.To.Addr < B.To.Addr;
280 });
281
282 // Targets may contain multiple entries to the same target, but using
283 // different indices. Their profile will report the same number of branches
284 // for different indices if the target is the same. That's because we don't
285 // profile the index value, but only the target via LBR.
286 auto First = Targets.begin();
287 auto Last = Targets.end();
288 auto Result = First;
289 while (++First != Last) {
290 Callsite &A = *Result;
291 const Callsite &B = *First;
292 if (A.To.Sym && B.To.Sym && A.To.Sym == B.To.Sym)
293 A.JTIndices.insert(A.JTIndices.end(), B.JTIndices.begin(),
294 B.JTIndices.end());
295 else
296 *(++Result) = *First;
297 }
298 ++Result;
299
300 LLVM_DEBUG(if (Targets.end() - Result > 0) {
301 dbgs() << "BOLT-INFO: ICP: " << (Targets.end() - Result)
302 << " duplicate targets removed\n";
303 });
304
305 Targets.erase(Result, Targets.end());
306 } else {
307 // Don't try to optimize PC relative indirect calls.
308 if (Inst.getOperand(0).isReg() &&
309 Inst.getOperand(0).getReg() == BC.MRI->getProgramCounter())
310 return Targets;
311
312 const auto ICSP = BC.MIB->tryGetAnnotationAs<IndirectCallSiteProfile>(
313 Inst, "CallProfile");
314 if (ICSP) {
315 for (const IndirectCallProfile &CSP : ICSP.get()) {
316 Callsite Site(BF, CSP);
317 if (Site.isValid())
318 Targets.emplace_back(std::move(Site));
319 }
320 }
321 }
322
323 // Sort by target count, number of indices in case of jump table, and
324 // mispredicts. We prioritize targets with high count, small number of indices
325 // and high mispredicts. Break ties by selecting targets with lower addresses.
326 llvm::stable_sort(Targets, [](const Callsite &A, const Callsite &B) {
327 if (A.Branches != B.Branches)
328 return A.Branches > B.Branches;
329 if (A.JTIndices.size() != B.JTIndices.size())
330 return A.JTIndices.size() < B.JTIndices.size();
331 if (A.Mispreds != B.Mispreds)
332 return A.Mispreds > B.Mispreds;
333 return A.To.Addr < B.To.Addr;
334 });
335
336 // Remove non-symbol targets
337 llvm::erase_if(Targets, [](const Callsite &CS) { return !CS.To.Sym; });
338
339 LLVM_DEBUG(if (BF.getJumpTable(Inst)) {
340 uint64_t TotalCount = 0;
341 uint64_t TotalMispreds = 0;
342 for (const Callsite &S : Targets) {
343 TotalCount += S.Branches;
344 TotalMispreds += S.Mispreds;
345 }
346 if (!TotalCount)
347 TotalCount = 1;
348 if (!TotalMispreds)
349 TotalMispreds = 1;
350
351 dbgs() << "BOLT-INFO: ICP: jump table size = " << Targets.size()
352 << ", Count = " << TotalCount << ", Mispreds = " << TotalMispreds
353 << "\n";
354
355 size_t I = 0;
356 for (const Callsite &S : Targets) {
357 dbgs() << "Count[" << I << "] = " << S.Branches << ", "
358 << format("%.1f", (100.0 * S.Branches) / TotalCount) << ", "
359 << "Mispreds[" << I << "] = " << S.Mispreds << ", "
360 << format("%.1f", (100.0 * S.Mispreds) / TotalMispreds) << "\n";
361 ++I;
362 }
363 });
364
365 return Targets;
366 }
367
// Use memory profiling data attached to the jump table load instruction to
// compute the hottest jump table entries as (count, index) pairs, sorted by
// decreasing count. Returns an empty vector whenever the data is missing,
// bogus, or stale. On success, TargetFetchInst is set to the load that
// fetches the branch target from the table.
IndirectCallPromotion::JumpTableInfoType
IndirectCallPromotion::maybeGetHotJumpTableTargets(BinaryBasicBlock &BB,
                                                   MCInst &CallInst,
                                                   MCInst *&TargetFetchInst,
                                                   const JumpTable *JT) const {
  assert(JT && "Can't get jump table addrs for non-jump tables.");

  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();

  // Load elimination is only possible with memory profile data and when the
  // option is enabled.
  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return JumpTableInfoType();

  JumpTableInfoType HotTargets;
  MCInst *MemLocInstr;
  MCInst *PCRelBaseOut;
  unsigned BaseReg, IndexReg;
  int64_t DispValue;
  const MCExpr *DispExpr;
  // Analyze the instructions preceding the branch to locate the load that
  // fetches the target and to recover its addressing components.
  MutableArrayRef<MCInst> Insts(&BB.front(), &CallInst);
  const IndirectBranchType Type = BC.MIB->analyzeIndirectBranch(
      CallInst, Insts.begin(), Insts.end(), BC.AsmInfo->getCodePointerSize(),
      MemLocInstr, BaseReg, IndexReg, DispValue, DispExpr, PCRelBaseOut);

  assert(MemLocInstr && "There should always be a load for jump tables");
  if (!MemLocInstr)
    return JumpTableInfoType();

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP attempting to find memory profiling data for "
           << "jump table in " << Function << " at @ "
           << (&CallInst - &BB.front()) << "\n"
           << "BOLT-INFO: ICP target fetch instructions:\n";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    if (MemLocInstr != &CallInst)
      BC.printInstruction(dbgs(), CallInst, 0, &Function);
  });

  DEBUG_VERBOSE(1, {
    dbgs() << "Jmp info: Type = " << (unsigned)Type << ", "
           << "BaseReg = " << BC.MRI->getName(BaseReg) << ", "
           << "IndexReg = " << BC.MRI->getName(IndexReg) << ", "
           << "DispValue = " << Twine::utohexstr(DispValue) << ", "
           << "DispExpr = " << DispExpr << ", "
           << "MemLocInstr = ";
    BC.printInstruction(dbgs(), *MemLocInstr, 0, &Function);
    dbgs() << "\n";
  });

  ++TotalIndexBasedCandidates;

  auto ErrorOrMemAccesssProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MemLocInstr,
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccesssProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return JumpTableInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();

  // Resolve the start address of the jump table array from either the
  // symbolic or the immediate displacement of the load.
  uint64_t ArrayStart;
  if (DispExpr) {
    ErrorOr<uint64_t> DispValueOrError =
        BC.getSymbolValue(*BC.MIB->getTargetSymbol(DispExpr));
    assert(DispValueOrError && "global symbol needs a value");
    ArrayStart = *DispValueOrError;
  } else {
    ArrayStart = static_cast<uint64_t>(DispValue);
  }

  // RIP-relative addressing: the displacement is relative to the address of
  // the next instruction.
  if (BaseReg == BC.MRI->getProgramCounter())
    ArrayStart += Function.getAddress() + MemAccessProfile.NextInstrOffset;

  // This is a map of [symbol] -> [count, index] and is used to combine indices
  // into the jump table since there may be multiple addresses that all have the
  // same entry.
  std::map<MCSymbol *, std::pair<uint64_t, uint64_t>> HotTargetMap;
  const std::pair<size_t, size_t> Range = JT->getEntriesForAddress(ArrayStart);

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    size_t Index;
    // Mem data occasionally includes null pointers; ignore them.
    if (!AccessInfo.MemoryObject && !AccessInfo.Offset)
      continue;

    if (AccessInfo.Offset % JT->EntrySize != 0) // ignore bogus data
      return JumpTableInfoType();

    if (AccessInfo.MemoryObject) {
      // Deal with bad/stale data
      if (!AccessInfo.MemoryObject->getName().startswith(
              "JUMP_TABLE/" + Function.getOneName().str()))
        return JumpTableInfoType();
      // Offset is relative to the memory object; translate to a table index.
      Index =
          (AccessInfo.Offset - (ArrayStart - JT->getAddress())) / JT->EntrySize;
    } else {
      // Offset is an absolute address.
      Index = (AccessInfo.Offset - ArrayStart) / JT->EntrySize;
    }

    // If Index is out of range it probably means the memory profiling data is
    // wrong for this instruction, bail out.
    if (Index >= Range.second) {
      LLVM_DEBUG(dbgs() << "BOLT-INFO: Index out of range of " << Range.first
                        << ", " << Range.second << "\n");
      return JumpTableInfoType();
    }

    // Make sure the hot index points at a legal label corresponding to a BB,
    // e.g. not the end of function (unreachable) label.
    if (!Function.getBasicBlockForLabel(JT->Entries[Index + Range.first])) {
      LLVM_DEBUG({
        dbgs() << "BOLT-INFO: hot index " << Index << " pointing at bogus "
               << "label " << JT->Entries[Index + Range.first]->getName()
               << " in jump table:\n";
        JT->print(dbgs());
        dbgs() << "HotTargetMap:\n";
        for (std::pair<MCSymbol *const, std::pair<uint64_t, uint64_t>> &HT :
             HotTargetMap)
          dbgs() << "BOLT-INFO: " << HT.first->getName()
                 << " = (count=" << HT.second.first
                 << ", index=" << HT.second.second << ")\n";
      });
      return JumpTableInfoType();
    }

    // Accumulate the access count per target symbol; the recorded index is
    // the last one seen for that symbol.
    std::pair<uint64_t, uint64_t> &HotTarget =
        HotTargetMap[JT->Entries[Index + Range.first]];
    HotTarget.first += AccessInfo.Count;
    HotTarget.second = Index;
  }

  llvm::transform(
      HotTargetMap, std::back_inserter(HotTargets),
      [](const std::pair<MCSymbol *, std::pair<uint64_t, uint64_t>> &A) {
        return A.second;
      });

  // Sort with highest counts first.
  llvm::sort(reverse(HotTargets));

  LLVM_DEBUG({
    dbgs() << "BOLT-INFO: ICP jump table hot targets:\n";
    for (const std::pair<uint64_t, uint64_t> &Target : HotTargets)
      dbgs() << "BOLT-INFO: Idx = " << Target.second << ", "
             << "Count = " << Target.first << "\n";
  });

  // Record the index register for later use by the promoted-code generation.
  BC.MIB->getOrCreateAnnotationAs<uint16_t>(CallInst, "JTIndexReg") = IndexReg;

  TargetFetchInst = MemLocInstr;

  return HotTargets;
}
522
// Compute the (symbol, jump-table index) pairs to promote for a call site.
// For plain calls this is simply the first N target symbols. For jump tables,
// memory profile data (when available) is used to re-rank targets by hot
// index: Targets is rewritten so that the selected per-index entries come
// first, and N is updated to the number of selected entries.
IndirectCallPromotion::SymTargetsType
IndirectCallPromotion::findCallTargetSymbols(std::vector<Callsite> &Targets,
                                             size_t &N, BinaryBasicBlock &BB,
                                             MCInst &CallInst,
                                             MCInst *&TargetFetchInst) const {
  const JumpTable *JT = BB.getFunction()->getJumpTable(CallInst);
  SymTargetsType SymTargets;

  if (!JT) {
    // Plain indirect call: take the first N target symbols as-is.
    for (size_t I = 0; I < N; ++I) {
      assert(Targets[I].To.Sym && "All ICP targets must be to known symbols");
      assert(Targets[I].JTIndices.empty() &&
             "Can't have jump table indices for non-jump tables");
      SymTargets.emplace_back(Targets[I].To.Sym, 0);
    }
    return SymTargets;
  }

  // Use memory profile to select hot targets.
  JumpTableInfoType HotTargets =
      maybeGetHotJumpTableTargets(BB, CallInst, TargetFetchInst, JT);

  // Map a jump table index back to the Targets entry that covers it.
  auto findTargetsIndex = [&](uint64_t JTIndex) {
    for (size_t I = 0; I < Targets.size(); ++I)
      if (llvm::is_contained(Targets[I].JTIndices, JTIndex))
        return I;
    LLVM_DEBUG(dbgs() << "BOLT-ERROR: Unable to find target index for hot jump "
                      << " table entry in " << *BB.getFunction() << "\n");
    llvm_unreachable("Hot indices must be referred to by at least one "
                     "callsite");
  };

  if (!HotTargets.empty()) {
    if (opts::Verbosity >= 1)
      for (size_t I = 0; I < HotTargets.size(); ++I)
        outs() << "BOLT-INFO: HotTarget[" << I << "] = (" << HotTargets[I].first
               << ", " << HotTargets[I].second << ")\n";

    // Recompute hottest targets, now discriminating which index is hot
    // NOTE: This is a tradeoff. On one hand, we get index information. On the
    // other hand, info coming from the memory profile is much less accurate
    // than LBRs. So we may actually end up working with more coarse
    // profile granularity in exchange for information about indices.
    std::vector<Callsite> NewTargets;
    std::map<const MCSymbol *, uint32_t> IndicesPerTarget;
    uint64_t TotalMemAccesses = 0;
    // Count how many hot indices resolve to each target symbol; used below
    // to split the per-target LBR counts evenly across its indices.
    for (size_t I = 0; I < HotTargets.size(); ++I) {
      const uint64_t TargetIndex = findTargetsIndex(HotTargets[I].second);
      ++IndicesPerTarget[Targets[TargetIndex].To.Sym];
      TotalMemAccesses += HotTargets[I].first;
    }
    uint64_t RemainingMemAccesses = TotalMemAccesses;
    const size_t TopN =
        opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : opts::ICPTopN;
    size_t I = 0;
    // Select hot indices in decreasing-count order until a threshold or the
    // TopN cap stops the selection.
    for (; I < HotTargets.size(); ++I) {
      const uint64_t MemAccesses = HotTargets[I].first;
      if (100 * MemAccesses <
          TotalMemAccesses * opts::ICPJTTotalPercentThreshold)
        break;
      if (100 * MemAccesses <
          RemainingMemAccesses * opts::ICPJTRemainingPercentThreshold)
        break;
      if (TopN && I >= TopN)
        break;
      RemainingMemAccesses -= MemAccesses;

      const uint64_t JTIndex = HotTargets[I].second;
      Callsite &Target = Targets[findTargetsIndex(JTIndex)];

      // Split this index off into its own Callsite entry and remove it from
      // the original target's index list.
      NewTargets.push_back(Target);
      std::vector<uint64_t>({JTIndex}).swap(NewTargets.back().JTIndices);
      llvm::erase_value(Target.JTIndices, JTIndex);

      // Keep fixCFG counts sane if more indices use this same target later
      assert(IndicesPerTarget[Target.To.Sym] > 0 && "wrong map");
      NewTargets.back().Branches =
          Target.Branches / IndicesPerTarget[Target.To.Sym];
      NewTargets.back().Mispreds =
          Target.Mispreds / IndicesPerTarget[Target.To.Sym];
      assert(Target.Branches >= NewTargets.back().Branches);
      assert(Target.Mispreds >= NewTargets.back().Mispreds);
      Target.Branches -= NewTargets.back().Branches;
      Target.Mispreds -= NewTargets.back().Mispreds;
    }
    // Selected per-index entries first, then the (adjusted) originals.
    llvm::copy(Targets, std::back_inserter(NewTargets));
    std::swap(NewTargets, Targets);
    N = I;

    if (N == 0 && opts::Verbosity >= 1) {
      outs() << "BOLT-INFO: ICP failed in " << *BB.getFunction() << " in "
             << BB.getName() << ": failed to meet thresholds after memory "
             << "profile data was loaded.\n";
      return SymTargets;
    }
  }

  // Emit one (symbol, index) pair per jump-table index of each selected
  // target; I counts emitted pairs while TgtIdx walks the Targets entries.
  for (size_t I = 0, TgtIdx = 0; I < N; ++TgtIdx) {
    Callsite &Target = Targets[TgtIdx];
    assert(Target.To.Sym && "All ICP targets must be to known symbols");
    assert(!Target.JTIndices.empty() && "Jump tables must have indices");
    for (uint64_t Idx : Target.JTIndices) {
      SymTargets.emplace_back(Target.To.Sym, Idx);
      ++I;
    }
  }

  return SymTargets;
}
632
// For a virtual method call, use memory profiling data to find the vtable
// (symbol + addend) each promoted target method belongs to, along with the
// instructions that fetch the method pointer. Returns an empty MethodInfoType
// when the call cannot be analyzed, profile data is missing, a vtable cannot
// be found for some target, or the vtable register may be clobbered before
// the call.
IndirectCallPromotion::MethodInfoType IndirectCallPromotion::maybeGetVtableSyms(
    BinaryBasicBlock &BB, MCInst &Inst,
    const SymTargetsType &SymTargets) const {
  BinaryFunction &Function = *BB.getFunction();
  BinaryContext &BC = Function.getBinaryContext();
  std::vector<std::pair<MCSymbol *, uint64_t>> VtableSyms;
  std::vector<MCInst *> MethodFetchInsns;
  unsigned VtableReg, MethodReg;
  uint64_t MethodOffset;

  assert(!Function.getJumpTable(Inst) &&
         "Can't get vtable addrs for jump tables.");

  // Load elimination requires memory profile data and the option enabled.
  if (!Function.hasMemoryProfile() || !opts::EliminateLoads)
    return MethodInfoType();

  // Match the vtable-load / method-load sequence feeding the call.
  MutableArrayRef<MCInst> Insts(&BB.front(), &Inst + 1);
  if (!BC.MIB->analyzeVirtualMethodCall(Insts.begin(), Insts.end(),
                                        MethodFetchInsns, VtableReg, MethodReg,
                                        MethodOffset)) {
    DEBUG_VERBOSE(
        1, dbgs() << "BOLT-INFO: ICP unable to analyze method call in "
                  << Function << " at @ " << (&Inst - &BB.front()) << "\n");
    return MethodInfoType();
  }

  ++TotalMethodLoadEliminationCandidates;

  DEBUG_VERBOSE(1, {
    dbgs() << "BOLT-INFO: ICP found virtual method call in " << Function
           << " at @ " << (&Inst - &BB.front()) << "\n";
    dbgs() << "BOLT-INFO: ICP method fetch instructions:\n";
    for (MCInst *Inst : MethodFetchInsns)
      BC.printInstruction(dbgs(), *Inst, 0, &Function);

    if (MethodFetchInsns.back() != &Inst)
      BC.printInstruction(dbgs(), Inst, 0, &Function);
  });

  // Try to get value profiling data for the method load instruction.
  auto ErrorOrMemAccesssProfile =
      BC.MIB->tryGetAnnotationAs<MemoryAccessProfile>(*MethodFetchInsns.back(),
                                                      "MemoryAccessProfile");
  if (!ErrorOrMemAccesssProfile) {
    DEBUG_VERBOSE(1, dbgs()
                         << "BOLT-INFO: ICP no memory profiling data found\n");
    return MethodInfoType();
  }
  MemoryAccessProfile &MemAccessProfile = ErrorOrMemAccesssProfile.get();

  // Find the vtable that each method belongs to.
  std::map<const MCSymbol *, uint64_t> MethodToVtable;

  for (const AddressAccess &AccessInfo : MemAccessProfile.AddressAccessInfo) {
    uint64_t Address = AccessInfo.Offset;
    if (AccessInfo.MemoryObject)
      Address += AccessInfo.MemoryObject->getAddress();

    // Ignore bogus data.
    if (!Address)
      continue;

    // The accessed slot is vtable base + method offset; recover the base.
    const uint64_t VtableBase = Address - MethodOffset;

    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP vtable = "
                            << Twine::utohexstr(VtableBase) << "+"
                            << MethodOffset << "/" << AccessInfo.Count << "\n");

    // Read the method pointer out of the vtable slot and map the method
    // symbol back to this vtable base.
    if (ErrorOr<uint64_t> MethodAddr = BC.getPointerAtAddress(Address)) {
      BinaryData *MethodBD = BC.getBinaryDataAtAddress(MethodAddr.get());
      if (!MethodBD) // skip unknown methods
        continue;
      MCSymbol *MethodSym = MethodBD->getSymbol();
      MethodToVtable[MethodSym] = VtableBase;
      DEBUG_VERBOSE(1, {
        const BinaryFunction *Method = BC.getFunctionForSymbol(MethodSym);
        dbgs() << "BOLT-INFO: ICP found method = "
               << Twine::utohexstr(MethodAddr.get()) << "/"
               << (Method ? Method->getPrintName() : "") << "\n";
      });
    }
  }

  // Find the vtable for each target symbol.
  for (size_t I = 0; I < SymTargets.size(); ++I) {
    auto Itr = MethodToVtable.find(SymTargets[I].first);
    if (Itr != MethodToVtable.end()) {
      if (BinaryData *BD = BC.getBinaryDataContainingAddress(Itr->second)) {
        const uint64_t Addend = Itr->second - BD->getAddress();
        VtableSyms.emplace_back(BD->getSymbol(), Addend);
        continue;
      }
    }
    // Give up if we can't find the vtable for a method.
    DEBUG_VERBOSE(1, dbgs() << "BOLT-INFO: ICP can't find vtable for "
                            << SymTargets[I].first->getName() << "\n");
    return MethodInfoType();
  }

  // Make sure the vtable reg is not clobbered by the argument passing code
  if (VtableReg != MethodReg) {
    for (MCInst *CurInst = MethodFetchInsns.front(); CurInst < &Inst;
         ++CurInst) {
      const MCInstrDesc &InstrInfo = BC.MII->get(CurInst->getOpcode());
      if (InstrInfo.hasDefOfPhysReg(*CurInst, VtableReg, *BC.MRI))
        return MethodInfoType();
    }
  }

  return MethodInfoType(VtableSyms, MethodFetchInsns);
}
744
// Replace the indirect call/jump in IndCallBlock with the promoted code
// sequence in ICPcode: the first entry of ICPcode goes into IndCallBlock
// itself, and each subsequent (symbol, instructions) pair becomes a new basic
// block. For non-tail calls, the instructions that followed the original call
// are moved into the final (merge) block. Returns the newly created blocks;
// CFG edges are wired up later by fixCFG.
std::vector<std::unique_ptr<BinaryBasicBlock>>
IndirectCallPromotion::rewriteCall(
    BinaryBasicBlock &IndCallBlock, const MCInst &CallInst,
    MCPlusBuilder::BlocksVectorTy &&ICPcode,
    const std::vector<MCInst *> &MethodFetchInsns) const {
  BinaryFunction &Function = *IndCallBlock.getFunction();
  MCPlusBuilder *MIB = Function.getBinaryContext().MIB.get();

  // Create new basic blocks with correct code in each one first.
  std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
  const bool IsTailCallOrJT =
      (MIB->isTailCall(CallInst) || Function.getJumpTable(CallInst));

  // Move instructions from the tail of the original call block
  // to the merge block.

  // Remember any pseudo instructions following a tail call. These
  // must be preserved and moved to the original block.
  InstructionListType TailInsts;
  const MCInst *TailInst = &CallInst;
  if (IsTailCallOrJT)
    while (TailInst + 1 < &(*IndCallBlock.end()) &&
           MIB->isPseudo(*(TailInst + 1)))
      TailInsts.push_back(*++TailInst);

  // Everything after the call leaves the block; for non-tail calls it is
  // re-attached to the merge block at the bottom of this function.
  InstructionListType MovedInst = IndCallBlock.splitInstructions(&CallInst);
  // Link new BBs to the original input offset of the BB where the indirect
  // call site is, so we can map samples recorded in new BBs back to the
  // original BB seen in the input binary (if using BAT)
  const uint32_t OrigOffset = IndCallBlock.getInputOffset();

  // The promoted sequence re-fetches the method itself, so the original
  // fetch instructions are dropped.
  IndCallBlock.eraseInstructions(MethodFetchInsns.begin(),
                                 MethodFetchInsns.end());
  // If the call instruction itself was erased above (it was the last method
  // fetch), append the new sequence; otherwise replace the call in place.
  if (IndCallBlock.empty() ||
      (!MethodFetchInsns.empty() && MethodFetchInsns.back() == &CallInst))
    IndCallBlock.addInstructions(ICPcode.front().second.begin(),
                                 ICPcode.front().second.end());
  else
    IndCallBlock.replaceInstruction(std::prev(IndCallBlock.end()),
                                    ICPcode.front().second);
  IndCallBlock.addInstructions(TailInsts.begin(), TailInsts.end());

  // Materialize one new basic block per remaining (label, code) pair.
  for (auto Itr = ICPcode.begin() + 1; Itr != ICPcode.end(); ++Itr) {
    MCSymbol *&Sym = Itr->first;
    InstructionListType &Insts = Itr->second;
    assert(Sym);
    std::unique_ptr<BinaryBasicBlock> TBB = Function.createBasicBlock(Sym);
    TBB->setOffset(OrigOffset);
    for (MCInst &Inst : Insts) // sanitize new instructions.
      if (MIB->isCall(Inst))
        MIB->removeAnnotation(Inst, "CallProfile");
    TBB->addInstructions(Insts.begin(), Insts.end());
    NewBBs.emplace_back(std::move(TBB));
  }

  // Move tail of instructions from after the original call to
  // the merge block.
  if (!IsTailCallOrJT)
    NewBBs.back()->addInstructions(MovedInst.begin(), MovedInst.end());

  return NewBBs;
}
807
// Splice the basic blocks produced by rewriteCall() into the function's CFG
// around IndCallBlock, and redistribute the original indirect call's profile
// counts across the new blocks.  For promoted non-tail calls a "merge" block
// (holding the instructions that followed the original call) is wired up and
// returned; for tail calls and jump tables nullptr is returned.
808 BinaryBasicBlock *
// NOTE(review): the next line fuses an extraction artifact ("... const809 ...")
// from the tool that numbered this dump with the real out-of-line definition
// header of IndirectCallPromotion::fixCFG.
fixCFG(BinaryBasicBlock & IndCallBlock,const bool IsTailCall,const bool IsJumpTable,IndirectCallPromotion::BasicBlocksVector && NewBBs,const std::vector<Callsite> & Targets) const809 IndirectCallPromotion::fixCFG(BinaryBasicBlock &IndCallBlock,
810 const bool IsTailCall, const bool IsJumpTable,
811 IndirectCallPromotion::BasicBlocksVector &&NewBBs,
812 const std::vector<Callsite> &Targets) const {
813 BinaryFunction &Function = *IndCallBlock.getFunction();
814 using BinaryBranchInfo = BinaryBasicBlock::BinaryBranchInfo;
815 BinaryBasicBlock *MergeBlock = nullptr;
816
817 // Scale indirect call counts to the execution count of the original
818 // basic block containing the indirect call.
819 uint64_t TotalCount = IndCallBlock.getKnownExecutionCount();
820 uint64_t TotalIndirectBranches = 0;
821 for (const Callsite &Target : Targets)
822 TotalIndirectBranches += Target.Branches;
// Guard against division by zero in the scaling below when the profile
// recorded no branches for any target.
823 if (TotalIndirectBranches == 0)
824 TotalIndirectBranches = 1;
825 BinaryBasicBlock::BranchInfoType BBI;
826 BinaryBasicBlock::BranchInfoType ScaledBBI;
// BBI holds raw per-entry counts (divided with round-up across a target's
// jump-table indices); ScaledBBI holds the same counts rescaled so that the
// per-target shares are expressed relative to the block's execution count.
// A target with K jump-table indices contributes K entries to both vectors.
827 for (const Callsite &Target : Targets) {
828 const size_t NumEntries =
829 std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
830 for (size_t I = 0; I < NumEntries; ++I) {
831 BBI.push_back(
832 BinaryBranchInfo{(Target.Branches + NumEntries - 1) / NumEntries,
833 (Target.Mispreds + NumEntries - 1) / NumEntries});
834 ScaledBBI.push_back(
835 BinaryBranchInfo{uint64_t(TotalCount * Target.Branches /
836 (NumEntries * TotalIndirectBranches)),
837 uint64_t(TotalCount * Target.Mispreds /
838 (NumEntries * TotalIndirectBranches))});
839 }
840 }
841
842 if (IsJumpTable) {
// The last new block keeps the remaining (unpromoted) indirect jump; it
// takes over all successors of the original block.
843 BinaryBasicBlock *NewIndCallBlock = NewBBs.back().get();
844 IndCallBlock.moveAllSuccessorsTo(NewIndCallBlock);
845
// One symbol per promoted entry, in the same order as BBI/ScaledBBI.
846 std::vector<MCSymbol *> SymTargets;
847 for (const Callsite &Target : Targets) {
848 const size_t NumEntries =
849 std::max(static_cast<std::size_t>(1UL), Target.JTIndices.size());
850 for (size_t I = 0; I < NumEntries; ++I)
851 SymTargets.push_back(Target.To.Sym);
852 }
853 assert(SymTargets.size() > NewBBs.size() - 1 &&
854 "There must be a target symbol associated with each new BB.");
855
// Wire up the chain of comparison blocks: block I-1 (or the original block
// for I == 0) branches to the promoted target when taken and falls through
// to the next comparison block NewBBs[I] otherwise.  TotalCount shrinks as
// each taken edge consumes its share, so the order of these statements
// matters.
856 for (uint64_t I = 0; I < NewBBs.size(); ++I) {
857 BinaryBasicBlock *SourceBB = I ? NewBBs[I - 1].get() : &IndCallBlock;
858 SourceBB->setExecutionCount(TotalCount);
859
860 BinaryBasicBlock *TargetBB =
861 Function.getBasicBlockForLabel(SymTargets[I]);
862 SourceBB->addSuccessor(TargetBB, ScaledBBI[I]); // taken
863
864 TotalCount -= ScaledBBI[I].Count;
865 SourceBB->addSuccessor(NewBBs[I].get(), TotalCount); // fall-through
866
867 // Update branch info for the indirect jump.
868 BinaryBasicBlock::BinaryBranchInfo &BranchInfo =
869 NewIndCallBlock->getBranchInfo(*TargetBB);
// Subtract the promoted share from the remaining indirect jump's edge,
// saturating at zero (rounding can make BBI[I] exceed the recorded count).
870 if (BranchInfo.Count > BBI[I].Count)
871 BranchInfo.Count -= BBI[I].Count;
872 else
873 BranchInfo.Count = 0;
874
875 if (BranchInfo.MispredictedCount > BBI[I].MispredictedCount)
876 BranchInfo.MispredictedCount -= BBI[I].MispredictedCount;
877 else
878 BranchInfo.MispredictedCount = 0;
879 }
880 } else {
// Promoted calls arrive as (comparison, call-body) block pairs, plus an
// extra merge block for non-tail calls — hence the parity assertions.
881 assert(NewBBs.size() >= 2);
882 assert(NewBBs.size() % 2 == 1 || IndCallBlock.succ_empty());
883 assert(NewBBs.size() % 2 == 1 || IsTailCall);
884
// Consumes the next scaled count from ScaledBBI and deducts it from the
// running fall-through count.
885 auto ScaledBI = ScaledBBI.begin();
886 auto updateCurrentBranchInfo = [&] {
887 assert(ScaledBI != ScaledBBI.end());
888 TotalCount -= ScaledBI->Count;
889 ++ScaledBI;
890 };
891
892 if (!IsTailCall) {
893 MergeBlock = NewBBs.back().get();
894 IndCallBlock.moveAllSuccessorsTo(MergeBlock);
895 }
896
897 // Fix up successors and execution counts.
898 updateCurrentBranchInfo();
899 IndCallBlock.addSuccessor(NewBBs[1].get(), TotalCount);
900 IndCallBlock.addSuccessor(NewBBs[0].get(), ScaledBBI[0]);
901
// Walk the remaining new blocks: even indices are promoted call bodies
// (they branch to the merge block when one exists), odd indices are the
// next comparison in the chain.  The merge block / last block is handled
// after the loop, hence the Adj offset.
902 const size_t Adj = IsTailCall ? 1 : 2;
903 for (size_t I = 0; I < NewBBs.size() - Adj; ++I) {
904 assert(TotalCount <= IndCallBlock.getExecutionCount() ||
905 TotalCount <= uint64_t(TotalIndirectBranches));
906 uint64_t ExecCount = ScaledBBI[(I + 1) / 2].Count;
907 if (I % 2 == 0) {
908 if (MergeBlock)
909 NewBBs[I]->addSuccessor(MergeBlock, ScaledBBI[(I + 1) / 2].Count);
910 } else {
911 assert(I + 2 < NewBBs.size());
912 updateCurrentBranchInfo();
913 NewBBs[I]->addSuccessor(NewBBs[I + 2].get(), TotalCount);
914 NewBBs[I]->addSuccessor(NewBBs[I + 1].get(), ScaledBBI[(I + 1) / 2]);
// A comparison block executes for both its taken and fall-through shares.
915 ExecCount += TotalCount;
916 }
917 NewBBs[I]->setExecutionCount(ExecCount);
918 }
919
920 if (MergeBlock) {
921 // Arrange for the MergeBlock to be the fallthrough for the first
922 // promoted call block.
923 std::unique_ptr<BinaryBasicBlock> MBPtr;
924 std::swap(MBPtr, NewBBs.back());
925 NewBBs.pop_back();
926 NewBBs.emplace(NewBBs.begin() + 1, std::move(MBPtr));
927 // TODO: is COUNT_FALLTHROUGH_EDGE the right thing here?
928 NewBBs.back()->addSuccessor(MergeBlock, TotalCount); // uncond branch
929 }
930 }
931
932 // Update the execution count.
933 NewBBs.back()->setExecutionCount(TotalCount);
934
935 // Update BB and BB layout.
936 Function.insertBasicBlocks(&IndCallBlock, std::move(NewBBs));
937 assert(Function.validateCFG());
938
939 return MergeBlock;
940 }
941
canPromoteCallsite(const BinaryBasicBlock & BB,const MCInst & Inst,const std::vector<Callsite> & Targets,uint64_t NumCalls)942 size_t IndirectCallPromotion::canPromoteCallsite(
943 const BinaryBasicBlock &BB, const MCInst &Inst,
944 const std::vector<Callsite> &Targets, uint64_t NumCalls) {
945 BinaryFunction *BF = BB.getFunction();
946 const BinaryContext &BC = BF->getBinaryContext();
947
948 if (BB.getKnownExecutionCount() < opts::ExecutionCountThreshold)
949 return 0;
950
951 const bool IsJumpTable = BF->getJumpTable(Inst);
952
953 auto computeStats = [&](size_t N) {
954 for (size_t I = 0; I < N; ++I)
955 if (IsJumpTable)
956 TotalNumFrequentJmps += Targets[I].Branches;
957 else
958 TotalNumFrequentCalls += Targets[I].Branches;
959 };
960
961 // If we have no targets (or no calls), skip this callsite.
962 if (Targets.empty() || !NumCalls) {
963 if (opts::Verbosity >= 1) {
964 const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
965 outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx << " in "
966 << BB.getName() << ", calls = " << NumCalls
967 << ", targets empty or NumCalls == 0.\n";
968 }
969 return 0;
970 }
971
972 size_t TopN = opts::ICPTopN;
973 if (IsJumpTable)
974 TopN = opts::ICPJumpTablesTopN ? opts::ICPJumpTablesTopN : TopN;
975 else
976 TopN = opts::ICPCallsTopN ? opts::ICPCallsTopN : TopN;
977
978 const size_t TrialN = TopN ? std::min(TopN, Targets.size()) : Targets.size();
979
980 if (opts::ICPTopCallsites && !BC.MIB->hasAnnotation(Inst, "DoICP"))
981 return 0;
982
983 // Pick the top N targets.
984 uint64_t TotalMispredictsTopN = 0;
985 size_t N = 0;
986
987 if (opts::ICPUseMispredicts &&
988 (!IsJumpTable || opts::ICPJumpTablesByTarget)) {
989 // Count total number of mispredictions for (at most) the top N targets.
990 // We may choose a smaller N (TrialN vs. N) if the frequency threshold
991 // is exceeded by fewer targets.
992 double Threshold = double(opts::ICPMispredictThreshold);
993 for (size_t I = 0; I < TrialN && Threshold > 0; ++I, ++N) {
994 Threshold -= (100.0 * Targets[I].Mispreds) / NumCalls;
995 TotalMispredictsTopN += Targets[I].Mispreds;
996 }
997 computeStats(N);
998
999 // Compute the misprediction frequency of the top N call targets. If this
1000 // frequency is greater than the threshold, we should try ICP on this
1001 // callsite.
1002 const double TopNFrequency = (100.0 * TotalMispredictsTopN) / NumCalls;
1003 if (TopNFrequency == 0 || TopNFrequency < opts::ICPMispredictThreshold) {
1004 if (opts::Verbosity >= 1) {
1005 const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1006 outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1007 << " in " << BB.getName() << ", calls = " << NumCalls
1008 << ", top N mis. frequency " << format("%.1f", TopNFrequency)
1009 << "% < " << opts::ICPMispredictThreshold << "%\n";
1010 }
1011 return 0;
1012 }
1013 } else {
1014 size_t MaxTargets = 0;
1015
1016 // Count total number of calls for (at most) the top N targets.
1017 // We may choose a smaller N (TrialN vs. N) if the frequency threshold
1018 // is exceeded by fewer targets.
1019 const unsigned TotalThreshold = IsJumpTable
1020 ? opts::ICPJTTotalPercentThreshold
1021 : opts::ICPCallsTotalPercentThreshold;
1022 const unsigned RemainingThreshold =
1023 IsJumpTable ? opts::ICPJTRemainingPercentThreshold
1024 : opts::ICPCallsRemainingPercentThreshold;
1025 uint64_t NumRemainingCalls = NumCalls;
1026 for (size_t I = 0; I < TrialN; ++I, ++MaxTargets) {
1027 if (100 * Targets[I].Branches < NumCalls * TotalThreshold)
1028 break;
1029 if (100 * Targets[I].Branches < NumRemainingCalls * RemainingThreshold)
1030 break;
1031 if (N + (Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size()) >
1032 TrialN)
1033 break;
1034 TotalMispredictsTopN += Targets[I].Mispreds;
1035 NumRemainingCalls -= Targets[I].Branches;
1036 N += Targets[I].JTIndices.empty() ? 1 : Targets[I].JTIndices.size();
1037 }
1038 computeStats(MaxTargets);
1039
1040 // Don't check misprediction frequency for jump tables -- we don't really
1041 // care as long as we are saving loads from the jump table.
1042 if (!IsJumpTable || opts::ICPJumpTablesByTarget) {
1043 // Compute the misprediction frequency of the top N call targets. If
1044 // this frequency is less than the threshold, we should skip ICP at
1045 // this callsite.
1046 const double TopNMispredictFrequency =
1047 (100.0 * TotalMispredictsTopN) / NumCalls;
1048
1049 if (TopNMispredictFrequency < opts::ICPMispredictThreshold) {
1050 if (opts::Verbosity >= 1) {
1051 const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1052 outs() << "BOLT-INFO: ICP failed in " << *BF << " @ " << InstIdx
1053 << " in " << BB.getName() << ", calls = " << NumCalls
1054 << ", top N mispredict frequency "
1055 << format("%.1f", TopNMispredictFrequency) << "% < "
1056 << opts::ICPMispredictThreshold << "%\n";
1057 }
1058 return 0;
1059 }
1060 }
1061 }
1062
1063 // Filter by inline-ability of target functions, stop at first target that
1064 // can't be inlined.
1065 if (!IsJumpTable && opts::ICPPeelForInline) {
1066 for (size_t I = 0; I < N; ++I) {
1067 const MCSymbol *TargetSym = Targets[I].To.Sym;
1068 const BinaryFunction *TargetBF = BC.getFunctionForSymbol(TargetSym);
1069 if (!TargetBF || !BinaryFunctionPass::shouldOptimize(*TargetBF) ||
1070 getInliningInfo(*TargetBF).Type == InliningType::INL_NONE) {
1071 N = I;
1072 break;
1073 }
1074 }
1075 }
1076
1077 // Filter functions that can have ICP applied (for debugging)
1078 if (!opts::ICPFuncsList.empty()) {
1079 for (std::string &Name : opts::ICPFuncsList)
1080 if (BF->hasName(Name))
1081 return N;
1082 return 0;
1083 }
1084
1085 return N;
1086 }
1087
printCallsiteInfo(const BinaryBasicBlock & BB,const MCInst & Inst,const std::vector<Callsite> & Targets,const size_t N,uint64_t NumCalls) const1088 void IndirectCallPromotion::printCallsiteInfo(
1089 const BinaryBasicBlock &BB, const MCInst &Inst,
1090 const std::vector<Callsite> &Targets, const size_t N,
1091 uint64_t NumCalls) const {
1092 BinaryContext &BC = BB.getFunction()->getBinaryContext();
1093 const bool IsTailCall = BC.MIB->isTailCall(Inst);
1094 const bool IsJumpTable = BB.getFunction()->getJumpTable(Inst);
1095 const ptrdiff_t InstIdx = &Inst - &(*BB.begin());
1096
1097 outs() << "BOLT-INFO: ICP candidate branch info: " << *BB.getFunction()
1098 << " @ " << InstIdx << " in " << BB.getName()
1099 << " -> calls = " << NumCalls
1100 << (IsTailCall ? " (tail)" : (IsJumpTable ? " (jump table)" : ""))
1101 << "\n";
1102 for (size_t I = 0; I < N; I++) {
1103 const double Frequency = 100.0 * Targets[I].Branches / NumCalls;
1104 const double MisFrequency = 100.0 * Targets[I].Mispreds / NumCalls;
1105 outs() << "BOLT-INFO: ";
1106 if (Targets[I].To.Sym)
1107 outs() << Targets[I].To.Sym->getName();
1108 else
1109 outs() << Targets[I].To.Addr;
1110 outs() << ", calls = " << Targets[I].Branches
1111 << ", mispreds = " << Targets[I].Mispreds
1112 << ", taken freq = " << format("%.1f", Frequency) << "%"
1113 << ", mis. freq = " << format("%.1f", MisFrequency) << "%";
1114 bool First = true;
1115 for (uint64_t JTIndex : Targets[I].JTIndices) {
1116 outs() << (First ? ", indices = " : ", ") << JTIndex;
1117 First = false;
1118 }
1119 outs() << "\n";
1120 }
1121
1122 LLVM_DEBUG({
1123 dbgs() << "BOLT-INFO: ICP original call instruction:";
1124 BC.printInstruction(dbgs(), Inst, Targets[0].From.Addr, nullptr, true);
1125 });
1126 }
1127
// Pass driver: walk every eligible function, find indirect call / jump-table
// callsites, decide per callsite whether promotion pays off, generate the
// promoted code, splice it into the CFG, and finally print aggregate
// statistics.
// NOTE(review): the next line carries a fused extraction artifact
// ("runOnFunctions(BinaryContext & BC)1128 ...") from the tool that numbered
// this dump; the real code is just the definition header of runOnFunctions.
runOnFunctions(BinaryContext & BC)1128 void IndirectCallPromotion::runOnFunctions(BinaryContext &BC) {
1129 if (opts::ICP == ICP_NONE)
1130 return;
1131
1132 auto &BFs = BC.getBinaryFunctions();
1133
1134 const bool OptimizeCalls = (opts::ICP == ICP_CALLS || opts::ICP == ICP_ALL);
1135 const bool OptimizeJumpTables =
1136 (opts::ICP == ICP_JUMP_TABLES || opts::ICP == ICP_ALL);
1137
// Register/liveness analysis is only needed for the jump-table path (used
// below to check whether the flags register is live across the jump).
1138 std::unique_ptr<RegAnalysis> RA;
1139 std::unique_ptr<BinaryFunctionCallGraph> CG;
1140 if (OptimizeJumpTables) {
1141 CG.reset(new BinaryFunctionCallGraph(buildCallGraph(BC)));
1142 RA.reset(new RegAnalysis(BC, &BFs, &*CG));
1143 }
1144
1145 // If icp-top-callsites is enabled, compute the total number of indirect
1146 // calls and then optimize the hottest callsites that contribute to that
1147 // total.
1148 SetVector<BinaryFunction *> Functions;
1149 if (opts::ICPTopCallsites == 0) {
// Default mode: every function is a candidate.
1150 for (auto &KV : BFs)
1151 Functions.insert(&KV.second);
1152 } else {
// Top-callsites mode: collect (count, instruction, function) for every
// indirect callsite, then annotate only the hottest ones with "DoICP".
1153 using IndirectCallsite = std::tuple<uint64_t, MCInst *, BinaryFunction *>;
1154 std::vector<IndirectCallsite> IndirectCalls;
1155 size_t TotalIndirectCalls = 0;
1156
1157 // Find all the indirect callsites.
1158 for (auto &BFIt : BFs) {
1159 BinaryFunction &Function = BFIt.second;
1160
1161 if (!Function.isSimple() || Function.isIgnored() ||
1162 !Function.hasProfile())
1163 continue;
1164
1165 const bool HasLayout = !Function.getLayout().block_empty();
1166
1167 for (BinaryBasicBlock &BB : Function) {
// Skip cold blocks of split functions.
1168 if (HasLayout && Function.isSplit() && BB.isCold())
1169 continue;
1170
1171 for (MCInst &Inst : BB) {
1172 const bool IsJumpTable = Function.getJumpTable(Inst);
1173 const bool HasIndirectCallProfile =
1174 BC.MIB->hasAnnotation(Inst, "CallProfile");
1175 const bool IsDirectCall =
1176 (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0));
1177
// Candidate = indirect call with profile (when optimizing calls) or a
// jump-table jump (when optimizing jump tables).
1178 if (!IsDirectCall &&
1179 ((HasIndirectCallProfile && !IsJumpTable && OptimizeCalls) ||
1180 (IsJumpTable && OptimizeJumpTables))) {
1181 uint64_t NumCalls = 0;
1182 for (const Callsite &BInfo : getCallTargets(BB, Inst))
1183 NumCalls += BInfo.Branches;
1184 IndirectCalls.push_back(
1185 std::make_tuple(NumCalls, &Inst, &Function));
1186 TotalIndirectCalls += NumCalls;
1187 }
1188 }
1189 }
1190 }
1191
1192 // Sort callsites by execution count.
// Sorting a reversed view yields descending order by count.
1193 llvm::sort(reverse(IndirectCalls));
1194
1195 // Find callsites that contribute to the top "opts::ICPTopCallsites"%
1196 // number of calls.
1197 const float TopPerc = opts::ICPTopCallsites / 100.0f;
1198 int64_t MaxCalls = TotalIndirectCalls * TopPerc;
1199 uint64_t LastFreq = std::numeric_limits<uint64_t>::max();
1200 size_t Num = 0;
1201 for (const IndirectCallsite &IC : IndirectCalls) {
1202 const uint64_t CurFreq = std::get<0>(IC);
1203 // Once we decide to stop, include at least all branches that share the
1204 // same frequency of the last one to avoid non-deterministic behavior
1205 // (e.g. turning on/off ICP depending on the order of functions)
1206 if (MaxCalls <= 0 && CurFreq != LastFreq)
1207 break;
1208 MaxCalls -= CurFreq;
1209 LastFreq = CurFreq;
// The "DoICP" annotation is what canPromoteCallsite() checks later.
1210 BC.MIB->addAnnotation(*std::get<1>(IC), "DoICP", true);
1211 Functions.insert(std::get<2>(IC));
1212 ++Num;
1213 }
1214 outs() << "BOLT-INFO: ICP Total indirect calls = " << TotalIndirectCalls
1215 << ", " << Num << " callsites cover " << opts::ICPTopCallsites
1216 << "% of all indirect calls\n";
1217 }
1218
// Main optimization loop over the selected functions.
1219 for (BinaryFunction *FuncPtr : Functions) {
1220 BinaryFunction &Function = *FuncPtr;
1221
1222 if (!Function.isSimple() || Function.isIgnored() || !Function.hasProfile())
1223 continue;
1224
1225 const bool HasLayout = !Function.getLayout().block_empty();
1226
1227 // Total number of indirect calls issued from the current Function.
1228 // (a fraction of TotalIndirectCalls)
1229 uint64_t FuncTotalIndirectCalls = 0;
1230 uint64_t FuncTotalIndirectJmps = 0;
1231
1232 std::vector<BinaryBasicBlock *> BBs;
1233 for (BinaryBasicBlock &BB : Function) {
1234 // Skip indirect calls in cold blocks.
1235 if (!HasLayout || !Function.isSplit() || !BB.isCold())
1236 BBs.push_back(&BB);
1237 }
1238 if (BBs.empty())
1239 continue;
1240
// BBs is a worklist: fixCFG() may produce a merge block that is pushed
// back onto it below, so newly split tails get processed too.
1241 DataflowInfoManager Info(Function, RA.get(), nullptr);
1242 while (!BBs.empty()) {
1243 BinaryBasicBlock *BB = BBs.back();
1244 BBs.pop_back();
1245
// BB->size() is re-evaluated each iteration on purpose: the block's
// instruction list can change while we rewrite callsites in it.
1246 for (unsigned Idx = 0; Idx < BB->size(); ++Idx) {
1247 MCInst &Inst = BB->getInstructionAtIndex(Idx);
1248 const ptrdiff_t InstIdx = &Inst - &(*BB->begin());
1249 const bool IsTailCall = BC.MIB->isTailCall(Inst);
1250 const bool HasIndirectCallProfile =
1251 BC.MIB->hasAnnotation(Inst, "CallProfile");
1252 const bool IsJumpTable = Function.getJumpTable(Inst);
1253
// Every call (direct or not) contributes to the TotalCalls statistic.
1254 if (BC.MIB->isCall(Inst))
1255 TotalCalls += BB->getKnownExecutionCount();
1256
1257 if (IsJumpTable && !OptimizeJumpTables)
1258 continue;
1259
1260 if (!IsJumpTable && (!HasIndirectCallProfile || !OptimizeCalls))
1261 continue;
1262
1263 // Ignore direct calls.
1264 if (BC.MIB->isCall(Inst) && BC.MIB->getTargetSymbol(Inst, 0))
1265 continue;
1266
1267 assert((BC.MIB->isCall(Inst) || BC.MIB->isIndirectBranch(Inst)) &&
1268 "expected a call or an indirect jump instruction");
1269
1270 if (IsJumpTable)
1271 ++TotalJumpTableCallsites;
1272 else
1273 ++TotalIndirectCallsites;
1274
1275 std::vector<Callsite> Targets = getCallTargets(*BB, Inst);
1276
1277 // Compute the total number of calls from this particular callsite.
1278 uint64_t NumCalls = 0;
1279 for (const Callsite &BInfo : Targets)
1280 NumCalls += BInfo.Branches;
1281 if (!IsJumpTable)
1282 FuncTotalIndirectCalls += NumCalls;
1283 else
1284 FuncTotalIndirectJmps += NumCalls;
1285
1286 // If FLAGS regs is alive after this jmp site, do not try
1287 // promoting because we will clobber FLAGS.
1288 if (IsJumpTable) {
1289 ErrorOr<const BitVector &> State =
1290 Info.getLivenessAnalysis().getStateBefore(Inst);
1291 if (!State || (State && (*State)[BC.MIB->getFlagsReg()])) {
1292 if (opts::Verbosity >= 1)
1293 outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1294 << InstIdx << " in " << BB->getName()
1295 << ", calls = " << NumCalls
1296 << (State ? ", cannot clobber flags reg.\n"
1297 : ", no liveness data available.\n");
1298 continue;
1299 }
1300 }
1301
1302 // Should this callsite be optimized? Return the number of targets
1303 // to use when promoting this call. A value of zero means to skip
1304 // this callsite.
1305 size_t N = canPromoteCallsite(*BB, Inst, Targets, NumCalls);
1306
1307 // If it is a jump table and it failed to meet our initial threshold,
1308 // proceed to findCallTargetSymbols -- it may reevaluate N if
1309 // memory profile is present
1310 if (!N && !IsJumpTable)
1311 continue;
1312
1313 if (opts::Verbosity >= 1)
1314 printCallsiteInfo(*BB, Inst, Targets, N, NumCalls);
1315
1316 // Find MCSymbols or absolute addresses for each call target.
1317 MCInst *TargetFetchInst = nullptr;
1318 const SymTargetsType SymTargets =
1319 findCallTargetSymbols(Targets, N, *BB, Inst, TargetFetchInst);
1320
1321 // findCallTargetSymbols may have changed N if mem profile is available
1322 // for jump tables
1323 if (!N)
1324 continue;
1325
1326 LLVM_DEBUG(printDecision(dbgs(), Targets, N));
1327
1328 // If we can't resolve any of the target symbols, punt on this callsite.
1329 // TODO: can this ever happen?
1330 if (SymTargets.size() < N) {
1331 const size_t LastTarget = SymTargets.size();
1332 if (opts::Verbosity >= 1)
1333 outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1334 << InstIdx << " in " << BB->getName()
1335 << ", calls = " << NumCalls
1336 << ", ICP failed to find target symbol for "
1337 << Targets[LastTarget].To.Sym->getName() << "\n";
1338 continue;
1339 }
1340
1341 MethodInfoType MethodInfo;
1342
// For calls, try to eliminate the vtable load by comparing against the
// vtable address; for index-based jump tables, remember the instruction
// that fetches the target so the promoted code can reuse it.
1343 if (!IsJumpTable) {
1344 MethodInfo = maybeGetVtableSyms(*BB, Inst, SymTargets);
1345 TotalMethodLoadsEliminated += MethodInfo.first.empty() ? 0 : 1;
1346 LLVM_DEBUG(dbgs()
1347 << "BOLT-INFO: ICP "
1348 << (!MethodInfo.first.empty() ? "found" : "did not find")
1349 << " vtables for all methods.\n");
1350 } else if (TargetFetchInst) {
1351 ++TotalIndexBasedJumps;
1352 MethodInfo.second.push_back(TargetFetchInst);
1353 }
1354
1355 // Generate new promoted call code for this callsite.
1356 MCPlusBuilder::BlocksVectorTy ICPcode =
1357 (IsJumpTable && !opts::ICPJumpTablesByTarget)
1358 ? BC.MIB->jumpTablePromotion(Inst, SymTargets,
1359 MethodInfo.second, BC.Ctx.get())
1360 : BC.MIB->indirectCallPromotion(
1361 Inst, SymTargets, MethodInfo.first, MethodInfo.second,
1362 opts::ICPOldCodeSequence, BC.Ctx.get());
1363
1364 if (ICPcode.empty()) {
1365 if (opts::Verbosity >= 1)
1366 outs() << "BOLT-INFO: ICP failed in " << Function << " @ "
1367 << InstIdx << " in " << BB->getName()
1368 << ", calls = " << NumCalls
1369 << ", unable to generate promoted call code.\n";
1370 continue;
1371 }
1372
1373 LLVM_DEBUG({
1374 uint64_t Offset = Targets[0].From.Addr;
1375 dbgs() << "BOLT-INFO: ICP indirect call code:\n";
1376 for (const auto &entry : ICPcode) {
1377 const MCSymbol *const &Sym = entry.first;
1378 const InstructionListType &Insts = entry.second;
1379 if (Sym)
1380 dbgs() << Sym->getName() << ":\n";
1381 Offset = BC.printInstructions(dbgs(), Insts.begin(), Insts.end(),
1382 Offset);
1383 }
1384 dbgs() << "---------------------------------------------------\n";
1385 });
1386
1387 // Rewrite the CFG with the newly generated ICP code.
1388 std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs =
1389 rewriteCall(*BB, Inst, std::move(ICPcode), MethodInfo.second);
1390
1391 // Fix the CFG after inserting the new basic blocks.
1392 BinaryBasicBlock *MergeBlock =
1393 fixCFG(*BB, IsTailCall, IsJumpTable, std::move(NewBBs), Targets);
1394
1395 // Since the tail of the original block was split off and it may contain
1396 // additional indirect calls, we must add the merge block to the set of
1397 // blocks to process.
1398 if (MergeBlock)
1399 BBs.push_back(MergeBlock);
1400
1401 if (opts::Verbosity >= 1)
1402 outs() << "BOLT-INFO: ICP succeeded in " << Function << " @ "
1403 << InstIdx << " in " << BB->getName()
1404 << " -> calls = " << NumCalls << "\n";
1405
1406 if (IsJumpTable)
1407 ++TotalOptimizedJumpTableCallsites;
1408 else
1409 ++TotalOptimizedIndirectCallsites;
1410
1411 Modified.insert(&Function);
1412 }
1413 }
1414 TotalIndirectCalls += FuncTotalIndirectCalls;
1415 TotalIndirectJmps += FuncTotalIndirectJmps;
1416 }
1417
// Final statistics summary; the std::max(..., 1) guards avoid division by
// zero when a counter is empty.
1418 outs() << "BOLT-INFO: ICP total indirect callsites with profile = "
1419 << TotalIndirectCallsites << "\n"
1420 << "BOLT-INFO: ICP total jump table callsites = "
1421 << TotalJumpTableCallsites << "\n"
1422 << "BOLT-INFO: ICP total number of calls = " << TotalCalls << "\n"
1423 << "BOLT-INFO: ICP percentage of calls that are indirect = "
1424 << format("%.1f", (100.0 * TotalIndirectCalls) / TotalCalls) << "%\n"
1425 << "BOLT-INFO: ICP percentage of indirect calls that can be "
1426 "optimized = "
1427 << format("%.1f", (100.0 * TotalNumFrequentCalls) /
1428 std::max<size_t>(TotalIndirectCalls, 1))
1429 << "%\n"
1430 << "BOLT-INFO: ICP percentage of indirect callsites that are "
1431 "optimized = "
1432 << format("%.1f", (100.0 * TotalOptimizedIndirectCallsites) /
1433 std::max<uint64_t>(TotalIndirectCallsites, 1))
1434 << "%\n"
1435 << "BOLT-INFO: ICP number of method load elimination candidates = "
1436 << TotalMethodLoadEliminationCandidates << "\n"
1437 << "BOLT-INFO: ICP percentage of method calls candidates that have "
1438 "loads eliminated = "
1439 << format("%.1f", (100.0 * TotalMethodLoadsEliminated) /
1440 std::max<uint64_t>(
1441 TotalMethodLoadEliminationCandidates, 1))
1442 << "%\n"
1443 << "BOLT-INFO: ICP percentage of indirect branches that are "
1444 "optimized = "
1445 << format("%.1f", (100.0 * TotalNumFrequentJmps) /
1446 std::max<uint64_t>(TotalIndirectJmps, 1))
1447 << "%\n"
1448 << "BOLT-INFO: ICP percentage of jump table callsites that are "
1449 << "optimized = "
1450 << format("%.1f", (100.0 * TotalOptimizedJumpTableCallsites) /
1451 std::max<uint64_t>(TotalJumpTableCallsites, 1))
1452 << "%\n"
1453 << "BOLT-INFO: ICP number of jump table callsites that can use hot "
1454 << "indices = " << TotalIndexBasedCandidates << "\n"
1455 << "BOLT-INFO: ICP percentage of jump table callsites that use hot "
1456 "indices = "
1457 << format("%.1f", (100.0 * TotalIndexBasedJumps) /
1458 std::max<uint64_t>(TotalIndexBasedCandidates, 1))
1459 << "%\n";
1460
// Suppress unused-function warning in release builds where the call below
// is compiled out.
1461 (void)verifyProfile;
1462 #ifndef NDEBUG
1463 verifyProfile(BFs);
1464 #endif
1465 }
1466
1467 } // namespace bolt
1468 } // namespace llvm
1469