1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24
25 static llvm::cl::opt<bool>
26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27 llvm::cl::desc("Enable value profiling"),
28 llvm::cl::Hidden, llvm::cl::init(false));
29
30 using namespace clang;
31 using namespace CodeGen;
32
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)33 void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44
setFuncName(llvm::Function * Fn)45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName meta data.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50
/// Versions of the PGO function-hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59
60 namespace {
61 /// Stable hasher for PGO region counters.
62 ///
63 /// PGOHash produces a stable hash of a given function's control flow.
64 ///
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
67 ///
68 /// \note When this hash does eventually change (years?), we still need to
69 /// support old hashes. We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
71 class PGOHash {
72 uint64_t Working;
73 unsigned Count;
74 PGOHashVersion HashVersion;
75 llvm::MD5 MD5;
76
77 static const int NumBitsPerType = 6;
78 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79 static const unsigned TooBig = 1u << NumBitsPerType;
80
81 public:
82 /// Hash values for AST nodes.
83 ///
84 /// Distinct values for AST nodes that have region counters attached.
85 ///
86 /// These values must be stable. All new members must be added at the end,
87 /// and no members should be removed. Changing the enumeration value for an
88 /// AST node will affect the hash of every function that contains that node.
89 enum HashType : unsigned char {
90 None = 0,
91 LabelStmt = 1,
92 WhileStmt,
93 DoStmt,
94 ForStmt,
95 CXXForRangeStmt,
96 ObjCForCollectionStmt,
97 SwitchStmt,
98 CaseStmt,
99 DefaultStmt,
100 IfStmt,
101 CXXTryStmt,
102 CXXCatchStmt,
103 ConditionalOperator,
104 BinaryOperatorLAnd,
105 BinaryOperatorLOr,
106 BinaryConditionalOperator,
107 // The preceding values are available with PGO_HASH_V1.
108
109 EndOfScope,
110 IfThenBranch,
111 IfElseBranch,
112 GotoStmt,
113 IndirectGotoStmt,
114 BreakStmt,
115 ContinueStmt,
116 ReturnStmt,
117 ThrowExpr,
118 UnaryOperatorLNot,
119 BinaryOperatorLT,
120 BinaryOperatorGT,
121 BinaryOperatorLE,
122 BinaryOperatorGE,
123 BinaryOperatorEQ,
124 BinaryOperatorNE,
125 // The preceding values are available with PGO_HASH_V2.
126
127 // Keep this last. It's for the static assert that follows.
128 LastHashType
129 };
130 static_assert(LastHashType <= TooBig, "Too many types in HashType");
131
PGOHash(PGOHashVersion HashVersion)132 PGOHash(PGOHashVersion HashVersion)
133 : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134 void combine(HashType Type);
135 uint64_t finalize();
getHashVersion() const136 PGOHashVersion getHashVersion() const { return HashVersion; }
137 };
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
141
142 /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144 CodeGenModule &CGM) {
145 if (PGOReader->getVersion() <= 4)
146 return PGO_HASH_V1;
147 return PGO_HASH_V2;
148 }
149
150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152 using Base = RecursiveASTVisitor<MapRegionCounters>;
153
154 /// The next counter value to assign.
155 unsigned NextCounter;
156 /// The function hash.
157 PGOHash Hash;
158 /// The map of statements to counters.
159 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160
MapRegionCounters__anon67ae45280111::MapRegionCounters161 MapRegionCounters(PGOHashVersion HashVersion,
162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164
165 // Blocks and lambdas are handled as separate functions, so we need not
166 // traverse them in the parent context.
TraverseBlockExpr__anon67ae45280111::MapRegionCounters167 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon67ae45280111::MapRegionCounters168 bool TraverseLambdaExpr(LambdaExpr *LE) {
169 // Traverse the captures, but not the body.
170 for (const auto &C : zip(LE->captures(), LE->capture_inits()))
171 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
172 return true;
173 }
TraverseCapturedStmt__anon67ae45280111::MapRegionCounters174 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
175
VisitDecl__anon67ae45280111::MapRegionCounters176 bool VisitDecl(const Decl *D) {
177 switch (D->getKind()) {
178 default:
179 break;
180 case Decl::Function:
181 case Decl::CXXMethod:
182 case Decl::CXXConstructor:
183 case Decl::CXXDestructor:
184 case Decl::CXXConversion:
185 case Decl::ObjCMethod:
186 case Decl::Block:
187 case Decl::Captured:
188 CounterMap[D->getBody()] = NextCounter++;
189 break;
190 }
191 return true;
192 }
193
194 /// If \p S gets a fresh counter, update the counter mappings. Return the
195 /// V1 hash of \p S.
updateCounterMappings__anon67ae45280111::MapRegionCounters196 PGOHash::HashType updateCounterMappings(Stmt *S) {
197 auto Type = getHashType(PGO_HASH_V1, S);
198 if (Type != PGOHash::None)
199 CounterMap[S] = NextCounter++;
200 return Type;
201 }
202
203 /// Include \p S in the function hash.
VisitStmt__anon67ae45280111::MapRegionCounters204 bool VisitStmt(Stmt *S) {
205 auto Type = updateCounterMappings(S);
206 if (Hash.getHashVersion() != PGO_HASH_V1)
207 Type = getHashType(Hash.getHashVersion(), S);
208 if (Type != PGOHash::None)
209 Hash.combine(Type);
210 return true;
211 }
212
TraverseIfStmt__anon67ae45280111::MapRegionCounters213 bool TraverseIfStmt(IfStmt *If) {
214 // If we used the V1 hash, use the default traversal.
215 if (Hash.getHashVersion() == PGO_HASH_V1)
216 return Base::TraverseIfStmt(If);
217
218 // Otherwise, keep track of which branch we're in while traversing.
219 VisitStmt(If);
220 for (Stmt *CS : If->children()) {
221 if (!CS)
222 continue;
223 if (CS == If->getThen())
224 Hash.combine(PGOHash::IfThenBranch);
225 else if (CS == If->getElse())
226 Hash.combine(PGOHash::IfElseBranch);
227 TraverseStmt(CS);
228 }
229 Hash.combine(PGOHash::EndOfScope);
230 return true;
231 }
232
233 // If the statement type \p N is nestable, and its nesting impacts profile
234 // stability, define a custom traversal which tracks the end of the statement
235 // in the hash (provided we're not using the V1 hash).
236 #define DEFINE_NESTABLE_TRAVERSAL(N) \
237 bool Traverse##N(N *S) { \
238 Base::Traverse##N(S); \
239 if (Hash.getHashVersion() != PGO_HASH_V1) \
240 Hash.combine(PGOHash::EndOfScope); \
241 return true; \
242 }
243
244 DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anon67ae45280111::MapRegionCounters245 DEFINE_NESTABLE_TRAVERSAL(DoStmt)
246 DEFINE_NESTABLE_TRAVERSAL(ForStmt)
247 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
248 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
249 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
250 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
251
252 /// Get version \p HashVersion of the PGO hash for \p S.
253 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
254 switch (S->getStmtClass()) {
255 default:
256 break;
257 case Stmt::LabelStmtClass:
258 return PGOHash::LabelStmt;
259 case Stmt::WhileStmtClass:
260 return PGOHash::WhileStmt;
261 case Stmt::DoStmtClass:
262 return PGOHash::DoStmt;
263 case Stmt::ForStmtClass:
264 return PGOHash::ForStmt;
265 case Stmt::CXXForRangeStmtClass:
266 return PGOHash::CXXForRangeStmt;
267 case Stmt::ObjCForCollectionStmtClass:
268 return PGOHash::ObjCForCollectionStmt;
269 case Stmt::SwitchStmtClass:
270 return PGOHash::SwitchStmt;
271 case Stmt::CaseStmtClass:
272 return PGOHash::CaseStmt;
273 case Stmt::DefaultStmtClass:
274 return PGOHash::DefaultStmt;
275 case Stmt::IfStmtClass:
276 return PGOHash::IfStmt;
277 case Stmt::CXXTryStmtClass:
278 return PGOHash::CXXTryStmt;
279 case Stmt::CXXCatchStmtClass:
280 return PGOHash::CXXCatchStmt;
281 case Stmt::ConditionalOperatorClass:
282 return PGOHash::ConditionalOperator;
283 case Stmt::BinaryConditionalOperatorClass:
284 return PGOHash::BinaryConditionalOperator;
285 case Stmt::BinaryOperatorClass: {
286 const BinaryOperator *BO = cast<BinaryOperator>(S);
287 if (BO->getOpcode() == BO_LAnd)
288 return PGOHash::BinaryOperatorLAnd;
289 if (BO->getOpcode() == BO_LOr)
290 return PGOHash::BinaryOperatorLOr;
291 if (HashVersion == PGO_HASH_V2) {
292 switch (BO->getOpcode()) {
293 default:
294 break;
295 case BO_LT:
296 return PGOHash::BinaryOperatorLT;
297 case BO_GT:
298 return PGOHash::BinaryOperatorGT;
299 case BO_LE:
300 return PGOHash::BinaryOperatorLE;
301 case BO_GE:
302 return PGOHash::BinaryOperatorGE;
303 case BO_EQ:
304 return PGOHash::BinaryOperatorEQ;
305 case BO_NE:
306 return PGOHash::BinaryOperatorNE;
307 }
308 }
309 break;
310 }
311 }
312
313 if (HashVersion == PGO_HASH_V2) {
314 switch (S->getStmtClass()) {
315 default:
316 break;
317 case Stmt::GotoStmtClass:
318 return PGOHash::GotoStmt;
319 case Stmt::IndirectGotoStmtClass:
320 return PGOHash::IndirectGotoStmt;
321 case Stmt::BreakStmtClass:
322 return PGOHash::BreakStmt;
323 case Stmt::ContinueStmtClass:
324 return PGOHash::ContinueStmt;
325 case Stmt::ReturnStmtClass:
326 return PGOHash::ReturnStmt;
327 case Stmt::CXXThrowExprClass:
328 return PGOHash::ThrowExpr;
329 case Stmt::UnaryOperatorClass: {
330 const UnaryOperator *UO = cast<UnaryOperator>(S);
331 if (UO->getOpcode() == UO_LNot)
332 return PGOHash::UnaryOperatorLNot;
333 break;
334 }
335 }
336 }
337
338 return PGOHash::None;
339 }
340 };
341
342 /// A StmtVisitor that propagates the raw counts through the AST and
343 /// records the count at statements where the value may change.
344 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
345 /// PGO state.
346 CodeGenPGO &PGO;
347
348 /// A flag that is set when the current count should be recorded on the
349 /// next statement, such as at the exit of a loop.
350 bool RecordNextStmtCount;
351
352 /// The count at the current location in the traversal.
353 uint64_t CurrentCount;
354
355 /// The map of statements to count values.
356 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357
358 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359 struct BreakContinue {
360 uint64_t BreakCount;
361 uint64_t ContinueCount;
BreakContinue__anon67ae45280111::ComputeRegionCounts::BreakContinue362 BreakContinue() : BreakCount(0), ContinueCount(0) {}
363 };
364 SmallVector<BreakContinue, 8> BreakContinueStack;
365
ComputeRegionCounts__anon67ae45280111::ComputeRegionCounts366 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367 CodeGenPGO &PGO)
368 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
369
RecordStmtCount__anon67ae45280111::ComputeRegionCounts370 void RecordStmtCount(const Stmt *S) {
371 if (RecordNextStmtCount) {
372 CountMap[S] = CurrentCount;
373 RecordNextStmtCount = false;
374 }
375 }
376
377 /// Set and return the current count.
setCount__anon67ae45280111::ComputeRegionCounts378 uint64_t setCount(uint64_t Count) {
379 CurrentCount = Count;
380 return Count;
381 }
382
VisitStmt__anon67ae45280111::ComputeRegionCounts383 void VisitStmt(const Stmt *S) {
384 RecordStmtCount(S);
385 for (const Stmt *Child : S->children())
386 if (Child)
387 this->Visit(Child);
388 }
389
VisitFunctionDecl__anon67ae45280111::ComputeRegionCounts390 void VisitFunctionDecl(const FunctionDecl *D) {
391 // Counter tracks entry to the function body.
392 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393 CountMap[D->getBody()] = BodyCount;
394 Visit(D->getBody());
395 }
396
397 // Skip lambda expressions. We visit these as FunctionDecls when we're
398 // generating them and aren't interested in the body when generating a
399 // parent context.
VisitLambdaExpr__anon67ae45280111::ComputeRegionCounts400 void VisitLambdaExpr(const LambdaExpr *LE) {}
401
VisitCapturedDecl__anon67ae45280111::ComputeRegionCounts402 void VisitCapturedDecl(const CapturedDecl *D) {
403 // Counter tracks entry to the capture body.
404 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405 CountMap[D->getBody()] = BodyCount;
406 Visit(D->getBody());
407 }
408
VisitObjCMethodDecl__anon67ae45280111::ComputeRegionCounts409 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410 // Counter tracks entry to the method body.
411 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412 CountMap[D->getBody()] = BodyCount;
413 Visit(D->getBody());
414 }
415
VisitBlockDecl__anon67ae45280111::ComputeRegionCounts416 void VisitBlockDecl(const BlockDecl *D) {
417 // Counter tracks entry to the block body.
418 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419 CountMap[D->getBody()] = BodyCount;
420 Visit(D->getBody());
421 }
422
VisitReturnStmt__anon67ae45280111::ComputeRegionCounts423 void VisitReturnStmt(const ReturnStmt *S) {
424 RecordStmtCount(S);
425 if (S->getRetValue())
426 Visit(S->getRetValue());
427 CurrentCount = 0;
428 RecordNextStmtCount = true;
429 }
430
VisitCXXThrowExpr__anon67ae45280111::ComputeRegionCounts431 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
432 RecordStmtCount(E);
433 if (E->getSubExpr())
434 Visit(E->getSubExpr());
435 CurrentCount = 0;
436 RecordNextStmtCount = true;
437 }
438
VisitGotoStmt__anon67ae45280111::ComputeRegionCounts439 void VisitGotoStmt(const GotoStmt *S) {
440 RecordStmtCount(S);
441 CurrentCount = 0;
442 RecordNextStmtCount = true;
443 }
444
VisitLabelStmt__anon67ae45280111::ComputeRegionCounts445 void VisitLabelStmt(const LabelStmt *S) {
446 RecordNextStmtCount = false;
447 // Counter tracks the block following the label.
448 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449 CountMap[S] = BlockCount;
450 Visit(S->getSubStmt());
451 }
452
VisitBreakStmt__anon67ae45280111::ComputeRegionCounts453 void VisitBreakStmt(const BreakStmt *S) {
454 RecordStmtCount(S);
455 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456 BreakContinueStack.back().BreakCount += CurrentCount;
457 CurrentCount = 0;
458 RecordNextStmtCount = true;
459 }
460
VisitContinueStmt__anon67ae45280111::ComputeRegionCounts461 void VisitContinueStmt(const ContinueStmt *S) {
462 RecordStmtCount(S);
463 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464 BreakContinueStack.back().ContinueCount += CurrentCount;
465 CurrentCount = 0;
466 RecordNextStmtCount = true;
467 }
468
VisitWhileStmt__anon67ae45280111::ComputeRegionCounts469 void VisitWhileStmt(const WhileStmt *S) {
470 RecordStmtCount(S);
471 uint64_t ParentCount = CurrentCount;
472
473 BreakContinueStack.push_back(BreakContinue());
474 // Visit the body region first so the break/continue adjustments can be
475 // included when visiting the condition.
476 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477 CountMap[S->getBody()] = CurrentCount;
478 Visit(S->getBody());
479 uint64_t BackedgeCount = CurrentCount;
480
481 // ...then go back and propagate counts through the condition. The count
482 // at the start of the condition is the sum of the incoming edges,
483 // the backedge from the end of the loop body, and the edges from
484 // continue statements.
485 BreakContinue BC = BreakContinueStack.pop_back_val();
486 uint64_t CondCount =
487 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488 CountMap[S->getCond()] = CondCount;
489 Visit(S->getCond());
490 setCount(BC.BreakCount + CondCount - BodyCount);
491 RecordNextStmtCount = true;
492 }
493
VisitDoStmt__anon67ae45280111::ComputeRegionCounts494 void VisitDoStmt(const DoStmt *S) {
495 RecordStmtCount(S);
496 uint64_t LoopCount = PGO.getRegionCount(S);
497
498 BreakContinueStack.push_back(BreakContinue());
499 // The count doesn't include the fallthrough from the parent scope. Add it.
500 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501 CountMap[S->getBody()] = BodyCount;
502 Visit(S->getBody());
503 uint64_t BackedgeCount = CurrentCount;
504
505 BreakContinue BC = BreakContinueStack.pop_back_val();
506 // The count at the start of the condition is equal to the count at the
507 // end of the body, plus any continues.
508 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509 CountMap[S->getCond()] = CondCount;
510 Visit(S->getCond());
511 setCount(BC.BreakCount + CondCount - LoopCount);
512 RecordNextStmtCount = true;
513 }
514
VisitForStmt__anon67ae45280111::ComputeRegionCounts515 void VisitForStmt(const ForStmt *S) {
516 RecordStmtCount(S);
517 if (S->getInit())
518 Visit(S->getInit());
519
520 uint64_t ParentCount = CurrentCount;
521
522 BreakContinueStack.push_back(BreakContinue());
523 // Visit the body region first. (This is basically the same as a while
524 // loop; see further comments in VisitWhileStmt.)
525 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526 CountMap[S->getBody()] = BodyCount;
527 Visit(S->getBody());
528 uint64_t BackedgeCount = CurrentCount;
529 BreakContinue BC = BreakContinueStack.pop_back_val();
530
531 // The increment is essentially part of the body but it needs to include
532 // the count for all the continue statements.
533 if (S->getInc()) {
534 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535 CountMap[S->getInc()] = IncCount;
536 Visit(S->getInc());
537 }
538
539 // ...then go back and propagate counts through the condition.
540 uint64_t CondCount =
541 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542 if (S->getCond()) {
543 CountMap[S->getCond()] = CondCount;
544 Visit(S->getCond());
545 }
546 setCount(BC.BreakCount + CondCount - BodyCount);
547 RecordNextStmtCount = true;
548 }
549
VisitCXXForRangeStmt__anon67ae45280111::ComputeRegionCounts550 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
551 RecordStmtCount(S);
552 if (S->getInit())
553 Visit(S->getInit());
554 Visit(S->getLoopVarStmt());
555 Visit(S->getRangeStmt());
556 Visit(S->getBeginStmt());
557 Visit(S->getEndStmt());
558
559 uint64_t ParentCount = CurrentCount;
560 BreakContinueStack.push_back(BreakContinue());
561 // Visit the body region first. (This is basically the same as a while
562 // loop; see further comments in VisitWhileStmt.)
563 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564 CountMap[S->getBody()] = BodyCount;
565 Visit(S->getBody());
566 uint64_t BackedgeCount = CurrentCount;
567 BreakContinue BC = BreakContinueStack.pop_back_val();
568
569 // The increment is essentially part of the body but it needs to include
570 // the count for all the continue statements.
571 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572 CountMap[S->getInc()] = IncCount;
573 Visit(S->getInc());
574
575 // ...then go back and propagate counts through the condition.
576 uint64_t CondCount =
577 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578 CountMap[S->getCond()] = CondCount;
579 Visit(S->getCond());
580 setCount(BC.BreakCount + CondCount - BodyCount);
581 RecordNextStmtCount = true;
582 }
583
VisitObjCForCollectionStmt__anon67ae45280111::ComputeRegionCounts584 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585 RecordStmtCount(S);
586 Visit(S->getElement());
587 uint64_t ParentCount = CurrentCount;
588 BreakContinueStack.push_back(BreakContinue());
589 // Counter tracks the body of the loop.
590 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591 CountMap[S->getBody()] = BodyCount;
592 Visit(S->getBody());
593 uint64_t BackedgeCount = CurrentCount;
594 BreakContinue BC = BreakContinueStack.pop_back_val();
595
596 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597 BodyCount);
598 RecordNextStmtCount = true;
599 }
600
VisitSwitchStmt__anon67ae45280111::ComputeRegionCounts601 void VisitSwitchStmt(const SwitchStmt *S) {
602 RecordStmtCount(S);
603 if (S->getInit())
604 Visit(S->getInit());
605 Visit(S->getCond());
606 CurrentCount = 0;
607 BreakContinueStack.push_back(BreakContinue());
608 Visit(S->getBody());
609 // If the switch is inside a loop, add the continue counts.
610 BreakContinue BC = BreakContinueStack.pop_back_val();
611 if (!BreakContinueStack.empty())
612 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613 // Counter tracks the exit block of the switch.
614 setCount(PGO.getRegionCount(S));
615 RecordNextStmtCount = true;
616 }
617
VisitSwitchCase__anon67ae45280111::ComputeRegionCounts618 void VisitSwitchCase(const SwitchCase *S) {
619 RecordNextStmtCount = false;
620 // Counter for this particular case. This counts only jumps from the
621 // switch header and does not include fallthrough from the case before
622 // this one.
623 uint64_t CaseCount = PGO.getRegionCount(S);
624 setCount(CurrentCount + CaseCount);
625 // We need the count without fallthrough in the mapping, so it's more useful
626 // for branch probabilities.
627 CountMap[S] = CaseCount;
628 RecordNextStmtCount = true;
629 Visit(S->getSubStmt());
630 }
631
VisitIfStmt__anon67ae45280111::ComputeRegionCounts632 void VisitIfStmt(const IfStmt *S) {
633 RecordStmtCount(S);
634 uint64_t ParentCount = CurrentCount;
635 if (S->getInit())
636 Visit(S->getInit());
637 Visit(S->getCond());
638
639 // Counter tracks the "then" part of an if statement. The count for
640 // the "else" part, if it exists, will be calculated from this counter.
641 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642 CountMap[S->getThen()] = ThenCount;
643 Visit(S->getThen());
644 uint64_t OutCount = CurrentCount;
645
646 uint64_t ElseCount = ParentCount - ThenCount;
647 if (S->getElse()) {
648 setCount(ElseCount);
649 CountMap[S->getElse()] = ElseCount;
650 Visit(S->getElse());
651 OutCount += CurrentCount;
652 } else
653 OutCount += ElseCount;
654 setCount(OutCount);
655 RecordNextStmtCount = true;
656 }
657
VisitCXXTryStmt__anon67ae45280111::ComputeRegionCounts658 void VisitCXXTryStmt(const CXXTryStmt *S) {
659 RecordStmtCount(S);
660 Visit(S->getTryBlock());
661 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
662 Visit(S->getHandler(I));
663 // Counter tracks the continuation block of the try statement.
664 setCount(PGO.getRegionCount(S));
665 RecordNextStmtCount = true;
666 }
667
VisitCXXCatchStmt__anon67ae45280111::ComputeRegionCounts668 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669 RecordNextStmtCount = false;
670 // Counter tracks the catch statement's handler block.
671 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672 CountMap[S] = CatchCount;
673 Visit(S->getHandlerBlock());
674 }
675
VisitAbstractConditionalOperator__anon67ae45280111::ComputeRegionCounts676 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677 RecordStmtCount(E);
678 uint64_t ParentCount = CurrentCount;
679 Visit(E->getCond());
680
681 // Counter tracks the "true" part of a conditional operator. The
682 // count in the "false" part will be calculated from this counter.
683 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684 CountMap[E->getTrueExpr()] = TrueCount;
685 Visit(E->getTrueExpr());
686 uint64_t OutCount = CurrentCount;
687
688 uint64_t FalseCount = setCount(ParentCount - TrueCount);
689 CountMap[E->getFalseExpr()] = FalseCount;
690 Visit(E->getFalseExpr());
691 OutCount += CurrentCount;
692
693 setCount(OutCount);
694 RecordNextStmtCount = true;
695 }
696
VisitBinLAnd__anon67ae45280111::ComputeRegionCounts697 void VisitBinLAnd(const BinaryOperator *E) {
698 RecordStmtCount(E);
699 uint64_t ParentCount = CurrentCount;
700 Visit(E->getLHS());
701 // Counter tracks the right hand side of a logical and operator.
702 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703 CountMap[E->getRHS()] = RHSCount;
704 Visit(E->getRHS());
705 setCount(ParentCount + RHSCount - CurrentCount);
706 RecordNextStmtCount = true;
707 }
708
VisitBinLOr__anon67ae45280111::ComputeRegionCounts709 void VisitBinLOr(const BinaryOperator *E) {
710 RecordStmtCount(E);
711 uint64_t ParentCount = CurrentCount;
712 Visit(E->getLHS());
713 // Counter tracks the right hand side of a logical or operator.
714 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715 CountMap[E->getRHS()] = RHSCount;
716 Visit(E->getRHS());
717 setCount(ParentCount + RHSCount - CurrentCount);
718 RecordNextStmtCount = true;
719 }
720 };
721 } // end anonymous namespace
722
combine(HashType Type)723 void PGOHash::combine(HashType Type) {
724 // Check that we never combine 0 and only have six bits.
725 assert(Type && "Hash is invalid: unexpected type 0");
726 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
727
728 // Pass through MD5 if enough work has built up.
729 if (Count && Count % NumTypesPerWord == 0) {
730 using namespace llvm::support;
731 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
732 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
733 Working = 0;
734 }
735
736 // Accumulate the current type.
737 ++Count;
738 Working = Working << NumBitsPerType | Type;
739 }
740
finalize()741 uint64_t PGOHash::finalize() {
742 // Use Working as the hash directly if we never used MD5.
743 if (Count <= NumTypesPerWord)
744 // No need to byte swap here, since none of the math was endian-dependent.
745 // This number will be byte-swapped as required on endianness transitions,
746 // so we will see the same value on the other side.
747 return Working;
748
749 // Check for remaining work in Working.
750 if (Working)
751 MD5.update(Working);
752
753 // Finalize the MD5 and return the hash.
754 llvm::MD5::MD5Result Result;
755 MD5.final(Result);
756 using namespace llvm::support;
757 return Result.low();
758 }
759
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)760 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
761 const Decl *D = GD.getDecl();
762 if (!D->hasBody())
763 return;
764
765 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
766 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
767 if (!InstrumentRegions && !PGOReader)
768 return;
769 if (D->isImplicit())
770 return;
771 // Constructors and destructors may be represented by several functions in IR.
772 // If so, instrument only base variant, others are implemented by delegation
773 // to the base one, it would be counted twice otherwise.
774 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
775 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
776 return;
777
778 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
779 if (GD.getCtorType() != Ctor_Base &&
780 CodeGenFunction::IsConstructorDelegationValid(CCD))
781 return;
782 }
783 CGM.ClearUnusedCoverageMapping(D);
784 setFuncName(Fn);
785
786 mapRegionCounters(D);
787 if (CGM.getCodeGenOpts().CoverageMapping)
788 emitCounterRegionMapping(D);
789 if (PGOReader) {
790 SourceManager &SM = CGM.getContext().getSourceManager();
791 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
792 computeRegionCounts(D);
793 applyFunctionAttributes(PGOReader, Fn);
794 }
795 }
796
mapRegionCounters(const Decl * D)797 void CodeGenPGO::mapRegionCounters(const Decl *D) {
798 // Use the latest hash version when inserting instrumentation, but use the
799 // version in the indexed profile if we're reading PGO data.
800 PGOHashVersion HashVersion = PGO_HASH_LATEST;
801 if (auto *PGOReader = CGM.getPGOReader())
802 HashVersion = getPGOHashVersion(PGOReader, CGM);
803
804 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
805 MapRegionCounters Walker(HashVersion, *RegionCounterMap);
806 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
807 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
808 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
809 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
810 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
811 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
812 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
813 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
814 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
815 NumRegionCounters = Walker.NextCounter;
816 FunctionHash = Walker.Hash.finalize();
817 }
818
skipRegionMappingForDecl(const Decl * D)819 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
820 if (!D->getBody())
821 return true;
822
823 // Don't map the functions in system headers.
824 const auto &SM = CGM.getContext().getSourceManager();
825 auto Loc = D->getBody()->getBeginLoc();
826 return SM.isInSystemHeader(Loc);
827 }
828
emitCounterRegionMapping(const Decl * D)829 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
830 if (skipRegionMappingForDecl(D))
831 return;
832
833 std::string CoverageMapping;
834 llvm::raw_string_ostream OS(CoverageMapping);
835 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
836 CGM.getContext().getSourceManager(),
837 CGM.getLangOpts(), RegionCounterMap.get());
838 MappingGen.emitCounterMapping(D, OS);
839 OS.flush();
840
841 if (CoverageMapping.empty())
842 return;
843
844 CGM.getCoverageMapping()->addFunctionMappingRecord(
845 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
846 }
847
848 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)849 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
850 llvm::GlobalValue::LinkageTypes Linkage) {
851 if (skipRegionMappingForDecl(D))
852 return;
853
854 std::string CoverageMapping;
855 llvm::raw_string_ostream OS(CoverageMapping);
856 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
857 CGM.getContext().getSourceManager(),
858 CGM.getLangOpts());
859 MappingGen.emitEmptyMapping(D, OS);
860 OS.flush();
861
862 if (CoverageMapping.empty())
863 return;
864
865 setFuncName(Name, Linkage);
866 CGM.getCoverageMapping()->addFunctionMappingRecord(
867 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
868 }
869
computeRegionCounts(const Decl * D)870 void CodeGenPGO::computeRegionCounts(const Decl *D) {
871 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
872 ComputeRegionCounts Walker(*StmtCountMap, *this);
873 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
874 Walker.VisitFunctionDecl(FD);
875 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
876 Walker.VisitObjCMethodDecl(MD);
877 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
878 Walker.VisitBlockDecl(BD);
879 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
880 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
881 }
882
883 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)884 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
885 llvm::Function *Fn) {
886 if (!haveRegionCounts())
887 return;
888
889 uint64_t FunctionCount = getRegionCount(nullptr);
890 Fn->setEntryCount(FunctionCount);
891 }
892
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)893 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
894 llvm::Value *StepV) {
895 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
896 return;
897 if (!Builder.GetInsertBlock())
898 return;
899
900 unsigned Counter = (*RegionCounterMap)[S];
901 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
902
903 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
904 Builder.getInt64(FunctionHash),
905 Builder.getInt32(NumRegionCounters),
906 Builder.getInt32(Counter), StepV};
907 if (!StepV)
908 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
909 makeArrayRef(Args, 4));
910 else
911 Builder.CreateCall(
912 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
913 makeArrayRef(Args));
914 }
915
916 // This method either inserts a call to the profile run-time during
917 // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)918 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
919 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
920
921 if (!EnableValueProfiling)
922 return;
923
924 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
925 return;
926
927 if (isa<llvm::Constant>(ValuePtr))
928 return;
929
930 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
931 if (InstrumentValueSites && RegionCounterMap) {
932 auto BuilderInsertPoint = Builder.saveIP();
933 Builder.SetInsertPoint(ValueSite);
934 llvm::Value *Args[5] = {
935 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
936 Builder.getInt64(FunctionHash),
937 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
938 Builder.getInt32(ValueKind),
939 Builder.getInt32(NumValueSites[ValueKind]++)
940 };
941 Builder.CreateCall(
942 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
943 Builder.restoreIP(BuilderInsertPoint);
944 return;
945 }
946
947 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
948 if (PGOReader && haveRegionCounts()) {
949 // We record the top most called three functions at each call site.
950 // Profile metadata contains "VP" string identifying this metadata
951 // as value profiling data, then a uint32_t value for the value profiling
952 // kind, a uint64_t value for the total number of times the call is
953 // executed, followed by the function hash and execution count (uint64_t)
954 // pairs for each function.
955 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
956 return;
957
958 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
959 (llvm::InstrProfValueKind)ValueKind,
960 NumValueSites[ValueKind]);
961
962 NumValueSites[ValueKind]++;
963 }
964 }
965
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)966 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
967 bool IsInMainFile) {
968 CGM.getPGOStats().addVisited(IsInMainFile);
969 RegionCounts.clear();
970 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
971 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
972 if (auto E = RecordExpected.takeError()) {
973 auto IPE = llvm::InstrProfError::take(std::move(E));
974 if (IPE == llvm::instrprof_error::unknown_function)
975 CGM.getPGOStats().addMissing(IsInMainFile);
976 else if (IPE == llvm::instrprof_error::hash_mismatch)
977 CGM.getPGOStats().addMismatched(IsInMainFile);
978 else if (IPE == llvm::instrprof_error::malformed)
979 // TODO: Consider a more specific warning for this case.
980 CGM.getPGOStats().addMismatched(IsInMainFile);
981 return;
982 }
983 ProfRecord =
984 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
985 RegionCounts = ProfRecord->Counts;
986 }
987
/// Compute the divisor used to scale branch weights.
///
/// \param MaxWeight the largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise a
/// divisor large enough that every weight up to \p MaxWeight, once divided by
/// it, is strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Weights that already fit need no scaling.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
995
/// Scale a single branch weight down to 32 bits and add 1.
///
/// The 64-bit \p Weight is divided by \p Scale. Following Laplace's Rule of
/// Succession, the weight is better computed from the count plus 1, so 1 is
/// added unconditionally to the scaled value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, so that the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1011
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount)1012 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1013 uint64_t FalseCount) {
1014 // Check for empty weights.
1015 if (!TrueCount && !FalseCount)
1016 return nullptr;
1017
1018 // Calculate how to scale down to 32-bits.
1019 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1020
1021 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1022 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1023 scaleBranchWeight(FalseCount, Scale));
1024 }
1025
1026 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights)1027 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1028 // We need at least two elements to create meaningful weights.
1029 if (Weights.size() < 2)
1030 return nullptr;
1031
1032 // Check for empty weights.
1033 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1034 if (MaxWeight == 0)
1035 return nullptr;
1036
1037 // Calculate how to scale down to 32-bits.
1038 uint64_t Scale = calculateWeightScale(MaxWeight);
1039
1040 SmallVector<uint32_t, 16> ScaledWeights;
1041 ScaledWeights.reserve(Weights.size());
1042 for (uint64_t W : Weights)
1043 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1044
1045 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1046 return MDHelper.createBranchWeights(ScaledWeights);
1047 }
1048
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount)1049 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1050 uint64_t LoopCount) {
1051 if (!PGO.haveRegionCounts())
1052 return nullptr;
1053 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1054 assert(CondCount.hasValue() && "missing expected loop condition count");
1055 if (*CondCount == 0)
1056 return nullptr;
1057 return createProfileWeights(LoopCount,
1058 std::max(*CondCount, LoopCount) - LoopCount);
1059 }
1060