1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest revision; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59 
60 namespace {
61 /// Stable hasher for PGO region counters.
62 ///
63 /// PGOHash produces a stable hash of a given function's control flow.
64 ///
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
67 ///
68 /// \note  When this hash does eventually change (years?), we still need to
69 /// support old hashes.  We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
71 class PGOHash {
72   uint64_t Working;
73   unsigned Count;
74   PGOHashVersion HashVersion;
75   llvm::MD5 MD5;
76 
77   static const int NumBitsPerType = 6;
78   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79   static const unsigned TooBig = 1u << NumBitsPerType;
80 
81 public:
82   /// Hash values for AST nodes.
83   ///
84   /// Distinct values for AST nodes that have region counters attached.
85   ///
86   /// These values must be stable.  All new members must be added at the end,
87   /// and no members should be removed.  Changing the enumeration value for an
88   /// AST node will affect the hash of every function that contains that node.
89   enum HashType : unsigned char {
90     None = 0,
91     LabelStmt = 1,
92     WhileStmt,
93     DoStmt,
94     ForStmt,
95     CXXForRangeStmt,
96     ObjCForCollectionStmt,
97     SwitchStmt,
98     CaseStmt,
99     DefaultStmt,
100     IfStmt,
101     CXXTryStmt,
102     CXXCatchStmt,
103     ConditionalOperator,
104     BinaryOperatorLAnd,
105     BinaryOperatorLOr,
106     BinaryConditionalOperator,
107     // The preceding values are available with PGO_HASH_V1.
108 
109     EndOfScope,
110     IfThenBranch,
111     IfElseBranch,
112     GotoStmt,
113     IndirectGotoStmt,
114     BreakStmt,
115     ContinueStmt,
116     ReturnStmt,
117     ThrowExpr,
118     UnaryOperatorLNot,
119     BinaryOperatorLT,
120     BinaryOperatorGT,
121     BinaryOperatorLE,
122     BinaryOperatorGE,
123     BinaryOperatorEQ,
124     BinaryOperatorNE,
125     // The preceding values are available with PGO_HASH_V2.
126 
127     // Keep this last.  It's for the static assert that follows.
128     LastHashType
129   };
130   static_assert(LastHashType <= TooBig, "Too many types in HashType");
131 
132   PGOHash(PGOHashVersion HashVersion)
133       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134   void combine(HashType Type);
135   uint64_t finalize();
136   PGOHashVersion getHashVersion() const { return HashVersion; }
137 };
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
141 
142 /// Get the PGO hash version used in the given indexed profile.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144                                         CodeGenModule &CGM) {
145   if (PGOReader->getVersion() <= 4)
146     return PGO_HASH_V1;
147   return PGO_HASH_V2;
148 }
149 
150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152   using Base = RecursiveASTVisitor<MapRegionCounters>;
153 
154   /// The next counter value to assign.
155   unsigned NextCounter;
156   /// The function hash.
157   PGOHash Hash;
158   /// The map of statements to counters.
159   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160 
161   MapRegionCounters(PGOHashVersion HashVersion,
162                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164 
165   // Blocks and lambdas are handled as separate functions, so we need not
166   // traverse them in the parent context.
167   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
169   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
170 
171   bool VisitDecl(const Decl *D) {
172     switch (D->getKind()) {
173     default:
174       break;
175     case Decl::Function:
176     case Decl::CXXMethod:
177     case Decl::CXXConstructor:
178     case Decl::CXXDestructor:
179     case Decl::CXXConversion:
180     case Decl::ObjCMethod:
181     case Decl::Block:
182     case Decl::Captured:
183       CounterMap[D->getBody()] = NextCounter++;
184       break;
185     }
186     return true;
187   }
188 
189   /// If \p S gets a fresh counter, update the counter mappings. Return the
190   /// V1 hash of \p S.
191   PGOHash::HashType updateCounterMappings(Stmt *S) {
192     auto Type = getHashType(PGO_HASH_V1, S);
193     if (Type != PGOHash::None)
194       CounterMap[S] = NextCounter++;
195     return Type;
196   }
197 
198   /// Include \p S in the function hash.
199   bool VisitStmt(Stmt *S) {
200     auto Type = updateCounterMappings(S);
201     if (Hash.getHashVersion() != PGO_HASH_V1)
202       Type = getHashType(Hash.getHashVersion(), S);
203     if (Type != PGOHash::None)
204       Hash.combine(Type);
205     return true;
206   }
207 
208   bool TraverseIfStmt(IfStmt *If) {
209     // If we used the V1 hash, use the default traversal.
210     if (Hash.getHashVersion() == PGO_HASH_V1)
211       return Base::TraverseIfStmt(If);
212 
213     // Otherwise, keep track of which branch we're in while traversing.
214     VisitStmt(If);
215     for (Stmt *CS : If->children()) {
216       if (!CS)
217         continue;
218       if (CS == If->getThen())
219         Hash.combine(PGOHash::IfThenBranch);
220       else if (CS == If->getElse())
221         Hash.combine(PGOHash::IfElseBranch);
222       TraverseStmt(CS);
223     }
224     Hash.combine(PGOHash::EndOfScope);
225     return true;
226   }
227 
228 // If the statement type \p N is nestable, and its nesting impacts profile
229 // stability, define a custom traversal which tracks the end of the statement
230 // in the hash (provided we're not using the V1 hash).
231 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
232   bool Traverse##N(N *S) {                                                     \
233     Base::Traverse##N(S);                                                      \
234     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
235       Hash.combine(PGOHash::EndOfScope);                                       \
236     return true;                                                               \
237   }
238 
239   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
240   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
241   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
242   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
243   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
244   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
245   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
246 
247   /// Get version \p HashVersion of the PGO hash for \p S.
248   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
249     switch (S->getStmtClass()) {
250     default:
251       break;
252     case Stmt::LabelStmtClass:
253       return PGOHash::LabelStmt;
254     case Stmt::WhileStmtClass:
255       return PGOHash::WhileStmt;
256     case Stmt::DoStmtClass:
257       return PGOHash::DoStmt;
258     case Stmt::ForStmtClass:
259       return PGOHash::ForStmt;
260     case Stmt::CXXForRangeStmtClass:
261       return PGOHash::CXXForRangeStmt;
262     case Stmt::ObjCForCollectionStmtClass:
263       return PGOHash::ObjCForCollectionStmt;
264     case Stmt::SwitchStmtClass:
265       return PGOHash::SwitchStmt;
266     case Stmt::CaseStmtClass:
267       return PGOHash::CaseStmt;
268     case Stmt::DefaultStmtClass:
269       return PGOHash::DefaultStmt;
270     case Stmt::IfStmtClass:
271       return PGOHash::IfStmt;
272     case Stmt::CXXTryStmtClass:
273       return PGOHash::CXXTryStmt;
274     case Stmt::CXXCatchStmtClass:
275       return PGOHash::CXXCatchStmt;
276     case Stmt::ConditionalOperatorClass:
277       return PGOHash::ConditionalOperator;
278     case Stmt::BinaryConditionalOperatorClass:
279       return PGOHash::BinaryConditionalOperator;
280     case Stmt::BinaryOperatorClass: {
281       const BinaryOperator *BO = cast<BinaryOperator>(S);
282       if (BO->getOpcode() == BO_LAnd)
283         return PGOHash::BinaryOperatorLAnd;
284       if (BO->getOpcode() == BO_LOr)
285         return PGOHash::BinaryOperatorLOr;
286       if (HashVersion == PGO_HASH_V2) {
287         switch (BO->getOpcode()) {
288         default:
289           break;
290         case BO_LT:
291           return PGOHash::BinaryOperatorLT;
292         case BO_GT:
293           return PGOHash::BinaryOperatorGT;
294         case BO_LE:
295           return PGOHash::BinaryOperatorLE;
296         case BO_GE:
297           return PGOHash::BinaryOperatorGE;
298         case BO_EQ:
299           return PGOHash::BinaryOperatorEQ;
300         case BO_NE:
301           return PGOHash::BinaryOperatorNE;
302         }
303       }
304       break;
305     }
306     }
307 
308     if (HashVersion == PGO_HASH_V2) {
309       switch (S->getStmtClass()) {
310       default:
311         break;
312       case Stmt::GotoStmtClass:
313         return PGOHash::GotoStmt;
314       case Stmt::IndirectGotoStmtClass:
315         return PGOHash::IndirectGotoStmt;
316       case Stmt::BreakStmtClass:
317         return PGOHash::BreakStmt;
318       case Stmt::ContinueStmtClass:
319         return PGOHash::ContinueStmt;
320       case Stmt::ReturnStmtClass:
321         return PGOHash::ReturnStmt;
322       case Stmt::CXXThrowExprClass:
323         return PGOHash::ThrowExpr;
324       case Stmt::UnaryOperatorClass: {
325         const UnaryOperator *UO = cast<UnaryOperator>(S);
326         if (UO->getOpcode() == UO_LNot)
327           return PGOHash::UnaryOperatorLNot;
328         break;
329       }
330       }
331     }
332 
333     return PGOHash::None;
334   }
335 };
336 
337 /// A StmtVisitor that propagates the raw counts through the AST and
338 /// records the count at statements where the value may change.
339 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
340   /// PGO state.
341   CodeGenPGO &PGO;
342 
343   /// A flag that is set when the current count should be recorded on the
344   /// next statement, such as at the exit of a loop.
345   bool RecordNextStmtCount;
346 
347   /// The count at the current location in the traversal.
348   uint64_t CurrentCount;
349 
350   /// The map of statements to count values.
351   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
352 
353   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
354   struct BreakContinue {
355     uint64_t BreakCount;
356     uint64_t ContinueCount;
357     BreakContinue() : BreakCount(0), ContinueCount(0) {}
358   };
359   SmallVector<BreakContinue, 8> BreakContinueStack;
360 
361   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
362                       CodeGenPGO &PGO)
363       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
364 
365   void RecordStmtCount(const Stmt *S) {
366     if (RecordNextStmtCount) {
367       CountMap[S] = CurrentCount;
368       RecordNextStmtCount = false;
369     }
370   }
371 
372   /// Set and return the current count.
373   uint64_t setCount(uint64_t Count) {
374     CurrentCount = Count;
375     return Count;
376   }
377 
378   void VisitStmt(const Stmt *S) {
379     RecordStmtCount(S);
380     for (const Stmt *Child : S->children())
381       if (Child)
382         this->Visit(Child);
383   }
384 
385   void VisitFunctionDecl(const FunctionDecl *D) {
386     // Counter tracks entry to the function body.
387     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
388     CountMap[D->getBody()] = BodyCount;
389     Visit(D->getBody());
390   }
391 
392   // Skip lambda expressions. We visit these as FunctionDecls when we're
393   // generating them and aren't interested in the body when generating a
394   // parent context.
395   void VisitLambdaExpr(const LambdaExpr *LE) {}
396 
397   void VisitCapturedDecl(const CapturedDecl *D) {
398     // Counter tracks entry to the capture body.
399     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
400     CountMap[D->getBody()] = BodyCount;
401     Visit(D->getBody());
402   }
403 
404   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
405     // Counter tracks entry to the method body.
406     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
407     CountMap[D->getBody()] = BodyCount;
408     Visit(D->getBody());
409   }
410 
411   void VisitBlockDecl(const BlockDecl *D) {
412     // Counter tracks entry to the block body.
413     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
414     CountMap[D->getBody()] = BodyCount;
415     Visit(D->getBody());
416   }
417 
418   void VisitReturnStmt(const ReturnStmt *S) {
419     RecordStmtCount(S);
420     if (S->getRetValue())
421       Visit(S->getRetValue());
422     CurrentCount = 0;
423     RecordNextStmtCount = true;
424   }
425 
426   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
427     RecordStmtCount(E);
428     if (E->getSubExpr())
429       Visit(E->getSubExpr());
430     CurrentCount = 0;
431     RecordNextStmtCount = true;
432   }
433 
434   void VisitGotoStmt(const GotoStmt *S) {
435     RecordStmtCount(S);
436     CurrentCount = 0;
437     RecordNextStmtCount = true;
438   }
439 
440   void VisitLabelStmt(const LabelStmt *S) {
441     RecordNextStmtCount = false;
442     // Counter tracks the block following the label.
443     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
444     CountMap[S] = BlockCount;
445     Visit(S->getSubStmt());
446   }
447 
448   void VisitBreakStmt(const BreakStmt *S) {
449     RecordStmtCount(S);
450     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
451     BreakContinueStack.back().BreakCount += CurrentCount;
452     CurrentCount = 0;
453     RecordNextStmtCount = true;
454   }
455 
456   void VisitContinueStmt(const ContinueStmt *S) {
457     RecordStmtCount(S);
458     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
459     BreakContinueStack.back().ContinueCount += CurrentCount;
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463 
464   void VisitWhileStmt(const WhileStmt *S) {
465     RecordStmtCount(S);
466     uint64_t ParentCount = CurrentCount;
467 
468     BreakContinueStack.push_back(BreakContinue());
469     // Visit the body region first so the break/continue adjustments can be
470     // included when visiting the condition.
471     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
472     CountMap[S->getBody()] = CurrentCount;
473     Visit(S->getBody());
474     uint64_t BackedgeCount = CurrentCount;
475 
476     // ...then go back and propagate counts through the condition. The count
477     // at the start of the condition is the sum of the incoming edges,
478     // the backedge from the end of the loop body, and the edges from
479     // continue statements.
480     BreakContinue BC = BreakContinueStack.pop_back_val();
481     uint64_t CondCount =
482         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
483     CountMap[S->getCond()] = CondCount;
484     Visit(S->getCond());
485     setCount(BC.BreakCount + CondCount - BodyCount);
486     RecordNextStmtCount = true;
487   }
488 
489   void VisitDoStmt(const DoStmt *S) {
490     RecordStmtCount(S);
491     uint64_t LoopCount = PGO.getRegionCount(S);
492 
493     BreakContinueStack.push_back(BreakContinue());
494     // The count doesn't include the fallthrough from the parent scope. Add it.
495     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
496     CountMap[S->getBody()] = BodyCount;
497     Visit(S->getBody());
498     uint64_t BackedgeCount = CurrentCount;
499 
500     BreakContinue BC = BreakContinueStack.pop_back_val();
501     // The count at the start of the condition is equal to the count at the
502     // end of the body, plus any continues.
503     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
504     CountMap[S->getCond()] = CondCount;
505     Visit(S->getCond());
506     setCount(BC.BreakCount + CondCount - LoopCount);
507     RecordNextStmtCount = true;
508   }
509 
510   void VisitForStmt(const ForStmt *S) {
511     RecordStmtCount(S);
512     if (S->getInit())
513       Visit(S->getInit());
514 
515     uint64_t ParentCount = CurrentCount;
516 
517     BreakContinueStack.push_back(BreakContinue());
518     // Visit the body region first. (This is basically the same as a while
519     // loop; see further comments in VisitWhileStmt.)
520     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
521     CountMap[S->getBody()] = BodyCount;
522     Visit(S->getBody());
523     uint64_t BackedgeCount = CurrentCount;
524     BreakContinue BC = BreakContinueStack.pop_back_val();
525 
526     // The increment is essentially part of the body but it needs to include
527     // the count for all the continue statements.
528     if (S->getInc()) {
529       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
530       CountMap[S->getInc()] = IncCount;
531       Visit(S->getInc());
532     }
533 
534     // ...then go back and propagate counts through the condition.
535     uint64_t CondCount =
536         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
537     if (S->getCond()) {
538       CountMap[S->getCond()] = CondCount;
539       Visit(S->getCond());
540     }
541     setCount(BC.BreakCount + CondCount - BodyCount);
542     RecordNextStmtCount = true;
543   }
544 
545   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
546     RecordStmtCount(S);
547     if (S->getInit())
548       Visit(S->getInit());
549     Visit(S->getLoopVarStmt());
550     Visit(S->getRangeStmt());
551     Visit(S->getBeginStmt());
552     Visit(S->getEndStmt());
553 
554     uint64_t ParentCount = CurrentCount;
555     BreakContinueStack.push_back(BreakContinue());
556     // Visit the body region first. (This is basically the same as a while
557     // loop; see further comments in VisitWhileStmt.)
558     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
559     CountMap[S->getBody()] = BodyCount;
560     Visit(S->getBody());
561     uint64_t BackedgeCount = CurrentCount;
562     BreakContinue BC = BreakContinueStack.pop_back_val();
563 
564     // The increment is essentially part of the body but it needs to include
565     // the count for all the continue statements.
566     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
567     CountMap[S->getInc()] = IncCount;
568     Visit(S->getInc());
569 
570     // ...then go back and propagate counts through the condition.
571     uint64_t CondCount =
572         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
573     CountMap[S->getCond()] = CondCount;
574     Visit(S->getCond());
575     setCount(BC.BreakCount + CondCount - BodyCount);
576     RecordNextStmtCount = true;
577   }
578 
579   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
580     RecordStmtCount(S);
581     Visit(S->getElement());
582     uint64_t ParentCount = CurrentCount;
583     BreakContinueStack.push_back(BreakContinue());
584     // Counter tracks the body of the loop.
585     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
586     CountMap[S->getBody()] = BodyCount;
587     Visit(S->getBody());
588     uint64_t BackedgeCount = CurrentCount;
589     BreakContinue BC = BreakContinueStack.pop_back_val();
590 
591     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
592              BodyCount);
593     RecordNextStmtCount = true;
594   }
595 
596   void VisitSwitchStmt(const SwitchStmt *S) {
597     RecordStmtCount(S);
598     if (S->getInit())
599       Visit(S->getInit());
600     Visit(S->getCond());
601     CurrentCount = 0;
602     BreakContinueStack.push_back(BreakContinue());
603     Visit(S->getBody());
604     // If the switch is inside a loop, add the continue counts.
605     BreakContinue BC = BreakContinueStack.pop_back_val();
606     if (!BreakContinueStack.empty())
607       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
608     // Counter tracks the exit block of the switch.
609     setCount(PGO.getRegionCount(S));
610     RecordNextStmtCount = true;
611   }
612 
613   void VisitSwitchCase(const SwitchCase *S) {
614     RecordNextStmtCount = false;
615     // Counter for this particular case. This counts only jumps from the
616     // switch header and does not include fallthrough from the case before
617     // this one.
618     uint64_t CaseCount = PGO.getRegionCount(S);
619     setCount(CurrentCount + CaseCount);
620     // We need the count without fallthrough in the mapping, so it's more useful
621     // for branch probabilities.
622     CountMap[S] = CaseCount;
623     RecordNextStmtCount = true;
624     Visit(S->getSubStmt());
625   }
626 
627   void VisitIfStmt(const IfStmt *S) {
628     RecordStmtCount(S);
629     uint64_t ParentCount = CurrentCount;
630     if (S->getInit())
631       Visit(S->getInit());
632     Visit(S->getCond());
633 
634     // Counter tracks the "then" part of an if statement. The count for
635     // the "else" part, if it exists, will be calculated from this counter.
636     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
637     CountMap[S->getThen()] = ThenCount;
638     Visit(S->getThen());
639     uint64_t OutCount = CurrentCount;
640 
641     uint64_t ElseCount = ParentCount - ThenCount;
642     if (S->getElse()) {
643       setCount(ElseCount);
644       CountMap[S->getElse()] = ElseCount;
645       Visit(S->getElse());
646       OutCount += CurrentCount;
647     } else
648       OutCount += ElseCount;
649     setCount(OutCount);
650     RecordNextStmtCount = true;
651   }
652 
653   void VisitCXXTryStmt(const CXXTryStmt *S) {
654     RecordStmtCount(S);
655     Visit(S->getTryBlock());
656     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
657       Visit(S->getHandler(I));
658     // Counter tracks the continuation block of the try statement.
659     setCount(PGO.getRegionCount(S));
660     RecordNextStmtCount = true;
661   }
662 
663   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
664     RecordNextStmtCount = false;
665     // Counter tracks the catch statement's handler block.
666     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
667     CountMap[S] = CatchCount;
668     Visit(S->getHandlerBlock());
669   }
670 
671   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
672     RecordStmtCount(E);
673     uint64_t ParentCount = CurrentCount;
674     Visit(E->getCond());
675 
676     // Counter tracks the "true" part of a conditional operator. The
677     // count in the "false" part will be calculated from this counter.
678     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
679     CountMap[E->getTrueExpr()] = TrueCount;
680     Visit(E->getTrueExpr());
681     uint64_t OutCount = CurrentCount;
682 
683     uint64_t FalseCount = setCount(ParentCount - TrueCount);
684     CountMap[E->getFalseExpr()] = FalseCount;
685     Visit(E->getFalseExpr());
686     OutCount += CurrentCount;
687 
688     setCount(OutCount);
689     RecordNextStmtCount = true;
690   }
691 
692   void VisitBinLAnd(const BinaryOperator *E) {
693     RecordStmtCount(E);
694     uint64_t ParentCount = CurrentCount;
695     Visit(E->getLHS());
696     // Counter tracks the right hand side of a logical and operator.
697     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
698     CountMap[E->getRHS()] = RHSCount;
699     Visit(E->getRHS());
700     setCount(ParentCount + RHSCount - CurrentCount);
701     RecordNextStmtCount = true;
702   }
703 
704   void VisitBinLOr(const BinaryOperator *E) {
705     RecordStmtCount(E);
706     uint64_t ParentCount = CurrentCount;
707     Visit(E->getLHS());
708     // Counter tracks the right hand side of a logical or operator.
709     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
710     CountMap[E->getRHS()] = RHSCount;
711     Visit(E->getRHS());
712     setCount(ParentCount + RHSCount - CurrentCount);
713     RecordNextStmtCount = true;
714   }
715 };
716 } // end anonymous namespace
717 
718 void PGOHash::combine(HashType Type) {
719   // Check that we never combine 0 and only have six bits.
720   assert(Type && "Hash is invalid: unexpected type 0");
721   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
722 
723   // Pass through MD5 if enough work has built up.
724   if (Count && Count % NumTypesPerWord == 0) {
725     using namespace llvm::support;
726     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
727     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
728     Working = 0;
729   }
730 
731   // Accumulate the current type.
732   ++Count;
733   Working = Working << NumBitsPerType | Type;
734 }
735 
736 uint64_t PGOHash::finalize() {
737   // Use Working as the hash directly if we never used MD5.
738   if (Count <= NumTypesPerWord)
739     // No need to byte swap here, since none of the math was endian-dependent.
740     // This number will be byte-swapped as required on endianness transitions,
741     // so we will see the same value on the other side.
742     return Working;
743 
744   // Check for remaining work in Working.
745   if (Working)
746     MD5.update(Working);
747 
748   // Finalize the MD5 and return the hash.
749   llvm::MD5::MD5Result Result;
750   MD5.final(Result);
751   using namespace llvm::support;
752   return Result.low();
753 }
754 
755 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
756   const Decl *D = GD.getDecl();
757   if (!D->hasBody())
758     return;
759 
760   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
761   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
762   if (!InstrumentRegions && !PGOReader)
763     return;
764   if (D->isImplicit())
765     return;
766   // Constructors and destructors may be represented by several functions in IR.
767   // If so, instrument only base variant, others are implemented by delegation
768   // to the base one, it would be counted twice otherwise.
769   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
770     if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
771       return;
772 
773     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
774       if (GD.getCtorType() != Ctor_Base &&
775           CodeGenFunction::IsConstructorDelegationValid(CCD))
776         return;
777   }
778   CGM.ClearUnusedCoverageMapping(D);
779   setFuncName(Fn);
780 
781   mapRegionCounters(D);
782   if (CGM.getCodeGenOpts().CoverageMapping)
783     emitCounterRegionMapping(D);
784   if (PGOReader) {
785     SourceManager &SM = CGM.getContext().getSourceManager();
786     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
787     computeRegionCounts(D);
788     applyFunctionAttributes(PGOReader, Fn);
789   }
790 }
791 
792 void CodeGenPGO::mapRegionCounters(const Decl *D) {
793   // Use the latest hash version when inserting instrumentation, but use the
794   // version in the indexed profile if we're reading PGO data.
795   PGOHashVersion HashVersion = PGO_HASH_LATEST;
796   if (auto *PGOReader = CGM.getPGOReader())
797     HashVersion = getPGOHashVersion(PGOReader, CGM);
798 
799   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
800   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
801   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
802     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
803   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
804     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
805   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
806     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
807   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
808     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
809   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
810   NumRegionCounters = Walker.NextCounter;
811   FunctionHash = Walker.Hash.finalize();
812 }
813 
814 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
815   if (!D->getBody())
816     return true;
817 
818   // Don't map the functions in system headers.
819   const auto &SM = CGM.getContext().getSourceManager();
820   auto Loc = D->getBody()->getBeginLoc();
821   return SM.isInSystemHeader(Loc);
822 }
823 
824 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
825   if (skipRegionMappingForDecl(D))
826     return;
827 
828   std::string CoverageMapping;
829   llvm::raw_string_ostream OS(CoverageMapping);
830   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
831                                 CGM.getContext().getSourceManager(),
832                                 CGM.getLangOpts(), RegionCounterMap.get());
833   MappingGen.emitCounterMapping(D, OS);
834   OS.flush();
835 
836   if (CoverageMapping.empty())
837     return;
838 
839   CGM.getCoverageMapping()->addFunctionMappingRecord(
840       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
841 }
842 
843 void
844 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
845                                     llvm::GlobalValue::LinkageTypes Linkage) {
846   if (skipRegionMappingForDecl(D))
847     return;
848 
849   std::string CoverageMapping;
850   llvm::raw_string_ostream OS(CoverageMapping);
851   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
852                                 CGM.getContext().getSourceManager(),
853                                 CGM.getLangOpts());
854   MappingGen.emitEmptyMapping(D, OS);
855   OS.flush();
856 
857   if (CoverageMapping.empty())
858     return;
859 
860   setFuncName(Name, Linkage);
861   CGM.getCoverageMapping()->addFunctionMappingRecord(
862       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
863 }
864 
865 void CodeGenPGO::computeRegionCounts(const Decl *D) {
866   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
867   ComputeRegionCounts Walker(*StmtCountMap, *this);
868   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
869     Walker.VisitFunctionDecl(FD);
870   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
871     Walker.VisitObjCMethodDecl(MD);
872   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
873     Walker.VisitBlockDecl(BD);
874   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
875     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
876 }
877 
878 void
879 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
880                                     llvm::Function *Fn) {
881   if (!haveRegionCounts())
882     return;
883 
884   uint64_t FunctionCount = getRegionCount(nullptr);
885   Fn->setEntryCount(FunctionCount);
886 }
887 
888 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
889                                       llvm::Value *StepV) {
890   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
891     return;
892   if (!Builder.GetInsertBlock())
893     return;
894 
895   unsigned Counter = (*RegionCounterMap)[S];
896   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
897 
898   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
899                          Builder.getInt64(FunctionHash),
900                          Builder.getInt32(NumRegionCounters),
901                          Builder.getInt32(Counter), StepV};
902   if (!StepV)
903     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
904                        makeArrayRef(Args, 4));
905   else
906     Builder.CreateCall(
907         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
908         makeArrayRef(Args));
909 }
910 
911 // This method either inserts a call to the profile run-time during
912 // instrumentation or puts profile data into metadata for PGO use.
913 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
914     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
915 
916   if (!EnableValueProfiling)
917     return;
918 
919   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
920     return;
921 
922   if (isa<llvm::Constant>(ValuePtr))
923     return;
924 
925   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
926   if (InstrumentValueSites && RegionCounterMap) {
927     auto BuilderInsertPoint = Builder.saveIP();
928     Builder.SetInsertPoint(ValueSite);
929     llvm::Value *Args[5] = {
930         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
931         Builder.getInt64(FunctionHash),
932         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
933         Builder.getInt32(ValueKind),
934         Builder.getInt32(NumValueSites[ValueKind]++)
935     };
936     Builder.CreateCall(
937         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
938     Builder.restoreIP(BuilderInsertPoint);
939     return;
940   }
941 
942   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
943   if (PGOReader && haveRegionCounts()) {
944     // We record the top most called three functions at each call site.
945     // Profile metadata contains "VP" string identifying this metadata
946     // as value profiling data, then a uint32_t value for the value profiling
947     // kind, a uint64_t value for the total number of times the call is
948     // executed, followed by the function hash and execution count (uint64_t)
949     // pairs for each function.
950     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
951       return;
952 
953     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
954                             (llvm::InstrProfValueKind)ValueKind,
955                             NumValueSites[ValueKind]);
956 
957     NumValueSites[ValueKind]++;
958   }
959 }
960 
961 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
962                                   bool IsInMainFile) {
963   CGM.getPGOStats().addVisited(IsInMainFile);
964   RegionCounts.clear();
965   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
966       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
967   if (auto E = RecordExpected.takeError()) {
968     auto IPE = llvm::InstrProfError::take(std::move(E));
969     if (IPE == llvm::instrprof_error::unknown_function)
970       CGM.getPGOStats().addMissing(IsInMainFile);
971     else if (IPE == llvm::instrprof_error::hash_mismatch)
972       CGM.getPGOStats().addMismatched(IsInMainFile);
973     else if (IPE == llvm::instrprof_error::malformed)
974       // TODO: Consider a more specific warning for this case.
975       CGM.getPGOStats().addMismatched(IsInMainFile);
976     return;
977   }
978   ProfRecord =
979       llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
980   RegionCounts = ProfRecord->Counts;
981 }
982 
/// Compute the divisor used to scale 64-bit weights down.
///
/// For a maximum weight \p MaxWeight, the returned divisor brings every weight
/// up to \p MaxWeight to a value strictly below UINT32_MAX. Weights that
/// already fit need no scaling, so the divisor is 1 in that case.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
990 
/// Scale a single branch weight to 32 bits, adding 1.
///
/// Divides the 64-bit \p Weight by \p Scale and then adds 1 to the quotient.
///
/// The unconditional +1 follows Laplace's Rule of Succession, under which a
/// weight is better derived from the count plus 1.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight that is no smaller than \c Weight, so that the result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Adjusted = Quotient + 1;
  assert(Adjusted <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Adjusted);
}
1006 
1007 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1008                                                     uint64_t FalseCount) {
1009   // Check for empty weights.
1010   if (!TrueCount && !FalseCount)
1011     return nullptr;
1012 
1013   // Calculate how to scale down to 32-bits.
1014   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1015 
1016   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1017   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1018                                       scaleBranchWeight(FalseCount, Scale));
1019 }
1020 
1021 llvm::MDNode *
1022 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1023   // We need at least two elements to create meaningful weights.
1024   if (Weights.size() < 2)
1025     return nullptr;
1026 
1027   // Check for empty weights.
1028   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1029   if (MaxWeight == 0)
1030     return nullptr;
1031 
1032   // Calculate how to scale down to 32-bits.
1033   uint64_t Scale = calculateWeightScale(MaxWeight);
1034 
1035   SmallVector<uint32_t, 16> ScaledWeights;
1036   ScaledWeights.reserve(Weights.size());
1037   for (uint64_t W : Weights)
1038     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1039 
1040   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1041   return MDHelper.createBranchWeights(ScaledWeights);
1042 }
1043 
1044 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1045                                                            uint64_t LoopCount) {
1046   if (!PGO.haveRegionCounts())
1047     return nullptr;
1048   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1049   assert(CondCount.hasValue() && "missing expected loop condition count");
1050   if (*CondCount == 0)
1051     return nullptr;
1052   return createProfileWeights(LoopCount,
1053                               std::max(*CondCount, LoopCount) - LoopCount);
1054 }
1055