1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
// Command-line switch named "enable-value-profiling". It may be given zero or
// more times and is false unless specified.
static llvm::cl::opt<bool> EnableValueProfiling(
  "enable-value-profiling", llvm::cl::ZeroOrMore,
  llvm::cl::desc("Enable value profiling"), llvm::cl::init(false));
28 
29 using namespace clang;
30 using namespace CodeGen;
31 
32 void CodeGenPGO::setFuncName(StringRef Name,
33                              llvm::GlobalValue::LinkageTypes Linkage) {
34   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35   FuncName = llvm::getPGOFuncName(
36       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
38 
39   // If we're generating a profile, create a variable for the name.
40   if (CGM.getCodeGenOpts().hasProfileClangInstr())
41     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
42 }
43 
44 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45   setFuncName(Fn->getName(), Fn->getLinkage());
46   // Create PGOFuncName meta data.
47   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
48 }
49 
/// The version of the PGO hash algorithm.
///
/// The enumerators are unsigned; PGO_HASH_V1 and PGO_HASH_V2 take the implicit
/// values 0 and 1, and PGO_HASH_LATEST aliases the newest of them.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};
58 
59 namespace {
/// \brief Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
///
/// Each hasher instance records the hash version it was created for; see
/// getHashVersion().
class PGOHash {
  /// Packed hash types that have not yet been passed to MD5.
  uint64_t Working;
  /// Number of hash types combined so far.
  unsigned Count;
  /// The hash algorithm version this hasher was constructed with.
  PGOHashVersion HashVersion;
  /// Digest state that receives each full word of packed hash types.
  llvm::MD5 MD5;

  /// Width, in bits, of one packed hash type.
  static const int NumBitsPerType = 6;
  /// How many hash types fit in one 64-bit word (64 / 6 = 10).
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  /// Exclusive upper bound for a hash type value (1 << 6 = 64).
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// \brief Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    // Zero is reserved to mean "contributes nothing to the hash".
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  // LastHashType is one past the final real type, so this guarantees every
  // real type is strictly below TooBig and fits in NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  /// Create an empty hasher for the given \p HashVersion.
  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Append \p Type to the hash. Defined out of line.
  void combine(HashType Type);
  /// Produce the final 64-bit hash. Defined out of line.
  uint64_t finalize();
  /// Return the hash version this hasher was constructed with.
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-class definitions for the static constants that PGOHash declares and
// initializes in the class body.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
140 
141 /// Get the PGO hash version used in the given indexed profile.
142 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
143                                         CodeGenModule &CGM) {
144   if (PGOReader->getVersion() <= 4)
145     return PGO_HASH_V1;
146   return PGO_HASH_V2;
147 }
148 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking, it also feeds a PGOHash with one hash type per relevant
/// statement, so that the resulting hash reflects the visited control flow.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  /// Start numbering counters at zero, hashing with \p HashVersion and
  /// recording assignments in the caller-owned \p CounterMap.
  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign the next counter to the body of function-like declarations.
  /// Other declaration kinds are left without a counter.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment is always decided by the V1 hash type, regardless of
    // which hash version is in use for the function hash.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For newer hash versions, re-query the type so that statements known
    // only to that version are hashed as well.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Traverse an if statement, marking the then/else branches and the end of
  /// the statement in the hash when not using the V1 hash.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
// The generated traversal always returns true, whatever the base traversal
// returned.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None when \p S does not contribute to the hash under
  /// that version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    // Statement kinds recognized by every hash version. Comparison operators
    // are only recognized under PGO_HASH_V2.
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement kinds that only PGO_HASH_V2 recognizes.
    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
335 
336 /// A StmtVisitor that propagates the raw counts through the AST and
337 /// records the count at statements where the value may change.
338 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
339   /// PGO state.
340   CodeGenPGO &PGO;
341 
342   /// A flag that is set when the current count should be recorded on the
343   /// next statement, such as at the exit of a loop.
344   bool RecordNextStmtCount;
345 
346   /// The count at the current location in the traversal.
347   uint64_t CurrentCount;
348 
349   /// The map of statements to count values.
350   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
351 
352   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
353   struct BreakContinue {
354     uint64_t BreakCount;
355     uint64_t ContinueCount;
356     BreakContinue() : BreakCount(0), ContinueCount(0) {}
357   };
358   SmallVector<BreakContinue, 8> BreakContinueStack;
359 
360   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
361                       CodeGenPGO &PGO)
362       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
363 
364   void RecordStmtCount(const Stmt *S) {
365     if (RecordNextStmtCount) {
366       CountMap[S] = CurrentCount;
367       RecordNextStmtCount = false;
368     }
369   }
370 
371   /// Set and return the current count.
372   uint64_t setCount(uint64_t Count) {
373     CurrentCount = Count;
374     return Count;
375   }
376 
377   void VisitStmt(const Stmt *S) {
378     RecordStmtCount(S);
379     for (const Stmt *Child : S->children())
380       if (Child)
381         this->Visit(Child);
382   }
383 
384   void VisitFunctionDecl(const FunctionDecl *D) {
385     // Counter tracks entry to the function body.
386     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
387     CountMap[D->getBody()] = BodyCount;
388     Visit(D->getBody());
389   }
390 
391   // Skip lambda expressions. We visit these as FunctionDecls when we're
392   // generating them and aren't interested in the body when generating a
393   // parent context.
394   void VisitLambdaExpr(const LambdaExpr *LE) {}
395 
396   void VisitCapturedDecl(const CapturedDecl *D) {
397     // Counter tracks entry to the capture body.
398     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
399     CountMap[D->getBody()] = BodyCount;
400     Visit(D->getBody());
401   }
402 
403   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
404     // Counter tracks entry to the method body.
405     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
406     CountMap[D->getBody()] = BodyCount;
407     Visit(D->getBody());
408   }
409 
410   void VisitBlockDecl(const BlockDecl *D) {
411     // Counter tracks entry to the block body.
412     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
413     CountMap[D->getBody()] = BodyCount;
414     Visit(D->getBody());
415   }
416 
417   void VisitReturnStmt(const ReturnStmt *S) {
418     RecordStmtCount(S);
419     if (S->getRetValue())
420       Visit(S->getRetValue());
421     CurrentCount = 0;
422     RecordNextStmtCount = true;
423   }
424 
425   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
426     RecordStmtCount(E);
427     if (E->getSubExpr())
428       Visit(E->getSubExpr());
429     CurrentCount = 0;
430     RecordNextStmtCount = true;
431   }
432 
433   void VisitGotoStmt(const GotoStmt *S) {
434     RecordStmtCount(S);
435     CurrentCount = 0;
436     RecordNextStmtCount = true;
437   }
438 
439   void VisitLabelStmt(const LabelStmt *S) {
440     RecordNextStmtCount = false;
441     // Counter tracks the block following the label.
442     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
443     CountMap[S] = BlockCount;
444     Visit(S->getSubStmt());
445   }
446 
447   void VisitBreakStmt(const BreakStmt *S) {
448     RecordStmtCount(S);
449     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
450     BreakContinueStack.back().BreakCount += CurrentCount;
451     CurrentCount = 0;
452     RecordNextStmtCount = true;
453   }
454 
455   void VisitContinueStmt(const ContinueStmt *S) {
456     RecordStmtCount(S);
457     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
458     BreakContinueStack.back().ContinueCount += CurrentCount;
459     CurrentCount = 0;
460     RecordNextStmtCount = true;
461   }
462 
463   void VisitWhileStmt(const WhileStmt *S) {
464     RecordStmtCount(S);
465     uint64_t ParentCount = CurrentCount;
466 
467     BreakContinueStack.push_back(BreakContinue());
468     // Visit the body region first so the break/continue adjustments can be
469     // included when visiting the condition.
470     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
471     CountMap[S->getBody()] = CurrentCount;
472     Visit(S->getBody());
473     uint64_t BackedgeCount = CurrentCount;
474 
475     // ...then go back and propagate counts through the condition. The count
476     // at the start of the condition is the sum of the incoming edges,
477     // the backedge from the end of the loop body, and the edges from
478     // continue statements.
479     BreakContinue BC = BreakContinueStack.pop_back_val();
480     uint64_t CondCount =
481         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
482     CountMap[S->getCond()] = CondCount;
483     Visit(S->getCond());
484     setCount(BC.BreakCount + CondCount - BodyCount);
485     RecordNextStmtCount = true;
486   }
487 
488   void VisitDoStmt(const DoStmt *S) {
489     RecordStmtCount(S);
490     uint64_t LoopCount = PGO.getRegionCount(S);
491 
492     BreakContinueStack.push_back(BreakContinue());
493     // The count doesn't include the fallthrough from the parent scope. Add it.
494     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
495     CountMap[S->getBody()] = BodyCount;
496     Visit(S->getBody());
497     uint64_t BackedgeCount = CurrentCount;
498 
499     BreakContinue BC = BreakContinueStack.pop_back_val();
500     // The count at the start of the condition is equal to the count at the
501     // end of the body, plus any continues.
502     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
503     CountMap[S->getCond()] = CondCount;
504     Visit(S->getCond());
505     setCount(BC.BreakCount + CondCount - LoopCount);
506     RecordNextStmtCount = true;
507   }
508 
509   void VisitForStmt(const ForStmt *S) {
510     RecordStmtCount(S);
511     if (S->getInit())
512       Visit(S->getInit());
513 
514     uint64_t ParentCount = CurrentCount;
515 
516     BreakContinueStack.push_back(BreakContinue());
517     // Visit the body region first. (This is basically the same as a while
518     // loop; see further comments in VisitWhileStmt.)
519     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
520     CountMap[S->getBody()] = BodyCount;
521     Visit(S->getBody());
522     uint64_t BackedgeCount = CurrentCount;
523     BreakContinue BC = BreakContinueStack.pop_back_val();
524 
525     // The increment is essentially part of the body but it needs to include
526     // the count for all the continue statements.
527     if (S->getInc()) {
528       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
529       CountMap[S->getInc()] = IncCount;
530       Visit(S->getInc());
531     }
532 
533     // ...then go back and propagate counts through the condition.
534     uint64_t CondCount =
535         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
536     if (S->getCond()) {
537       CountMap[S->getCond()] = CondCount;
538       Visit(S->getCond());
539     }
540     setCount(BC.BreakCount + CondCount - BodyCount);
541     RecordNextStmtCount = true;
542   }
543 
544   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
545     RecordStmtCount(S);
546     Visit(S->getLoopVarStmt());
547     Visit(S->getRangeStmt());
548     Visit(S->getBeginStmt());
549     Visit(S->getEndStmt());
550 
551     uint64_t ParentCount = CurrentCount;
552     BreakContinueStack.push_back(BreakContinue());
553     // Visit the body region first. (This is basically the same as a while
554     // loop; see further comments in VisitWhileStmt.)
555     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
556     CountMap[S->getBody()] = BodyCount;
557     Visit(S->getBody());
558     uint64_t BackedgeCount = CurrentCount;
559     BreakContinue BC = BreakContinueStack.pop_back_val();
560 
561     // The increment is essentially part of the body but it needs to include
562     // the count for all the continue statements.
563     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
564     CountMap[S->getInc()] = IncCount;
565     Visit(S->getInc());
566 
567     // ...then go back and propagate counts through the condition.
568     uint64_t CondCount =
569         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
570     CountMap[S->getCond()] = CondCount;
571     Visit(S->getCond());
572     setCount(BC.BreakCount + CondCount - BodyCount);
573     RecordNextStmtCount = true;
574   }
575 
576   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
577     RecordStmtCount(S);
578     Visit(S->getElement());
579     uint64_t ParentCount = CurrentCount;
580     BreakContinueStack.push_back(BreakContinue());
581     // Counter tracks the body of the loop.
582     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
583     CountMap[S->getBody()] = BodyCount;
584     Visit(S->getBody());
585     uint64_t BackedgeCount = CurrentCount;
586     BreakContinue BC = BreakContinueStack.pop_back_val();
587 
588     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
589              BodyCount);
590     RecordNextStmtCount = true;
591   }
592 
593   void VisitSwitchStmt(const SwitchStmt *S) {
594     RecordStmtCount(S);
595     if (S->getInit())
596       Visit(S->getInit());
597     Visit(S->getCond());
598     CurrentCount = 0;
599     BreakContinueStack.push_back(BreakContinue());
600     Visit(S->getBody());
601     // If the switch is inside a loop, add the continue counts.
602     BreakContinue BC = BreakContinueStack.pop_back_val();
603     if (!BreakContinueStack.empty())
604       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
605     // Counter tracks the exit block of the switch.
606     setCount(PGO.getRegionCount(S));
607     RecordNextStmtCount = true;
608   }
609 
610   void VisitSwitchCase(const SwitchCase *S) {
611     RecordNextStmtCount = false;
612     // Counter for this particular case. This counts only jumps from the
613     // switch header and does not include fallthrough from the case before
614     // this one.
615     uint64_t CaseCount = PGO.getRegionCount(S);
616     setCount(CurrentCount + CaseCount);
617     // We need the count without fallthrough in the mapping, so it's more useful
618     // for branch probabilities.
619     CountMap[S] = CaseCount;
620     RecordNextStmtCount = true;
621     Visit(S->getSubStmt());
622   }
623 
624   void VisitIfStmt(const IfStmt *S) {
625     RecordStmtCount(S);
626     uint64_t ParentCount = CurrentCount;
627     if (S->getInit())
628       Visit(S->getInit());
629     Visit(S->getCond());
630 
631     // Counter tracks the "then" part of an if statement. The count for
632     // the "else" part, if it exists, will be calculated from this counter.
633     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
634     CountMap[S->getThen()] = ThenCount;
635     Visit(S->getThen());
636     uint64_t OutCount = CurrentCount;
637 
638     uint64_t ElseCount = ParentCount - ThenCount;
639     if (S->getElse()) {
640       setCount(ElseCount);
641       CountMap[S->getElse()] = ElseCount;
642       Visit(S->getElse());
643       OutCount += CurrentCount;
644     } else
645       OutCount += ElseCount;
646     setCount(OutCount);
647     RecordNextStmtCount = true;
648   }
649 
650   void VisitCXXTryStmt(const CXXTryStmt *S) {
651     RecordStmtCount(S);
652     Visit(S->getTryBlock());
653     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
654       Visit(S->getHandler(I));
655     // Counter tracks the continuation block of the try statement.
656     setCount(PGO.getRegionCount(S));
657     RecordNextStmtCount = true;
658   }
659 
660   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
661     RecordNextStmtCount = false;
662     // Counter tracks the catch statement's handler block.
663     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
664     CountMap[S] = CatchCount;
665     Visit(S->getHandlerBlock());
666   }
667 
668   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
669     RecordStmtCount(E);
670     uint64_t ParentCount = CurrentCount;
671     Visit(E->getCond());
672 
673     // Counter tracks the "true" part of a conditional operator. The
674     // count in the "false" part will be calculated from this counter.
675     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
676     CountMap[E->getTrueExpr()] = TrueCount;
677     Visit(E->getTrueExpr());
678     uint64_t OutCount = CurrentCount;
679 
680     uint64_t FalseCount = setCount(ParentCount - TrueCount);
681     CountMap[E->getFalseExpr()] = FalseCount;
682     Visit(E->getFalseExpr());
683     OutCount += CurrentCount;
684 
685     setCount(OutCount);
686     RecordNextStmtCount = true;
687   }
688 
689   void VisitBinLAnd(const BinaryOperator *E) {
690     RecordStmtCount(E);
691     uint64_t ParentCount = CurrentCount;
692     Visit(E->getLHS());
693     // Counter tracks the right hand side of a logical and operator.
694     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
695     CountMap[E->getRHS()] = RHSCount;
696     Visit(E->getRHS());
697     setCount(ParentCount + RHSCount - CurrentCount);
698     RecordNextStmtCount = true;
699   }
700 
701   void VisitBinLOr(const BinaryOperator *E) {
702     RecordStmtCount(E);
703     uint64_t ParentCount = CurrentCount;
704     Visit(E->getLHS());
705     // Counter tracks the right hand side of a logical or operator.
706     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
707     CountMap[E->getRHS()] = RHSCount;
708     Visit(E->getRHS());
709     setCount(ParentCount + RHSCount - CurrentCount);
710     RecordNextStmtCount = true;
711   }
712 };
713 } // end anonymous namespace
714 
715 void PGOHash::combine(HashType Type) {
716   // Check that we never combine 0 and only have six bits.
717   assert(Type && "Hash is invalid: unexpected type 0");
718   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
719 
720   // Pass through MD5 if enough work has built up.
721   if (Count && Count % NumTypesPerWord == 0) {
722     using namespace llvm::support;
723     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
724     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
725     Working = 0;
726   }
727 
728   // Accumulate the current type.
729   ++Count;
730   Working = Working << NumBitsPerType | Type;
731 }
732 
733 uint64_t PGOHash::finalize() {
734   // Use Working as the hash directly if we never used MD5.
735   if (Count <= NumTypesPerWord)
736     // No need to byte swap here, since none of the math was endian-dependent.
737     // This number will be byte-swapped as required on endianness transitions,
738     // so we will see the same value on the other side.
739     return Working;
740 
741   // Check for remaining work in Working.
742   if (Working)
743     MD5.update(Working);
744 
745   // Finalize the MD5 and return the hash.
746   llvm::MD5::MD5Result Result;
747   MD5.final(Result);
748   using namespace llvm::support;
749   return Result.low();
750 }
751 
752 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
753   const Decl *D = GD.getDecl();
754   if (!D->hasBody())
755     return;
756 
757   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
758   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
759   if (!InstrumentRegions && !PGOReader)
760     return;
761   if (D->isImplicit())
762     return;
763   // Constructors and destructors may be represented by several functions in IR.
764   // If so, instrument only base variant, others are implemented by delegation
765   // to the base one, it would be counted twice otherwise.
766   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
767     if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
768       return;
769 
770     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
771       if (GD.getCtorType() != Ctor_Base &&
772           CodeGenFunction::IsConstructorDelegationValid(CCD))
773         return;
774   }
775   CGM.ClearUnusedCoverageMapping(D);
776   setFuncName(Fn);
777 
778   mapRegionCounters(D);
779   if (CGM.getCodeGenOpts().CoverageMapping)
780     emitCounterRegionMapping(D);
781   if (PGOReader) {
782     SourceManager &SM = CGM.getContext().getSourceManager();
783     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
784     computeRegionCounts(D);
785     applyFunctionAttributes(PGOReader, Fn);
786   }
787 }
788 
789 void CodeGenPGO::mapRegionCounters(const Decl *D) {
790   // Use the latest hash version when inserting instrumentation, but use the
791   // version in the indexed profile if we're reading PGO data.
792   PGOHashVersion HashVersion = PGO_HASH_LATEST;
793   if (auto *PGOReader = CGM.getPGOReader())
794     HashVersion = getPGOHashVersion(PGOReader, CGM);
795 
796   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
797   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
798   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
799     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
800   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
801     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
802   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
803     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
804   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
805     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
806   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
807   NumRegionCounters = Walker.NextCounter;
808   FunctionHash = Walker.Hash.finalize();
809 }
810 
811 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
812   if (!D->getBody())
813     return true;
814 
815   // Don't map the functions in system headers.
816   const auto &SM = CGM.getContext().getSourceManager();
817   auto Loc = D->getBody()->getLocStart();
818   return SM.isInSystemHeader(Loc);
819 }
820 
821 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
822   if (skipRegionMappingForDecl(D))
823     return;
824 
825   std::string CoverageMapping;
826   llvm::raw_string_ostream OS(CoverageMapping);
827   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
828                                 CGM.getContext().getSourceManager(),
829                                 CGM.getLangOpts(), RegionCounterMap.get());
830   MappingGen.emitCounterMapping(D, OS);
831   OS.flush();
832 
833   if (CoverageMapping.empty())
834     return;
835 
836   CGM.getCoverageMapping()->addFunctionMappingRecord(
837       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
838 }
839 
840 void
841 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
842                                     llvm::GlobalValue::LinkageTypes Linkage) {
843   if (skipRegionMappingForDecl(D))
844     return;
845 
846   std::string CoverageMapping;
847   llvm::raw_string_ostream OS(CoverageMapping);
848   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
849                                 CGM.getContext().getSourceManager(),
850                                 CGM.getLangOpts());
851   MappingGen.emitEmptyMapping(D, OS);
852   OS.flush();
853 
854   if (CoverageMapping.empty())
855     return;
856 
857   setFuncName(Name, Linkage);
858   CGM.getCoverageMapping()->addFunctionMappingRecord(
859       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
860 }
861 
862 void CodeGenPGO::computeRegionCounts(const Decl *D) {
863   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
864   ComputeRegionCounts Walker(*StmtCountMap, *this);
865   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
866     Walker.VisitFunctionDecl(FD);
867   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
868     Walker.VisitObjCMethodDecl(MD);
869   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
870     Walker.VisitBlockDecl(BD);
871   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
872     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
873 }
874 
875 void
876 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
877                                     llvm::Function *Fn) {
878   if (!haveRegionCounts())
879     return;
880 
881   uint64_t FunctionCount = getRegionCount(nullptr);
882   Fn->setEntryCount(FunctionCount);
883 }
884 
885 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
886                                       llvm::Value *StepV) {
887   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
888     return;
889   if (!Builder.GetInsertBlock())
890     return;
891 
892   unsigned Counter = (*RegionCounterMap)[S];
893   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
894 
895   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
896                          Builder.getInt64(FunctionHash),
897                          Builder.getInt32(NumRegionCounters),
898                          Builder.getInt32(Counter), StepV};
899   if (!StepV)
900     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
901                        makeArrayRef(Args, 4));
902   else
903     Builder.CreateCall(
904         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
905         makeArrayRef(Args));
906 }
907 
908 // This method either inserts a call to the profile run-time during
909 // instrumentation or puts profile data into metadata for PGO use.
910 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
911     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
912 
913   if (!EnableValueProfiling)
914     return;
915 
916   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
917     return;
918 
919   if (isa<llvm::Constant>(ValuePtr))
920     return;
921 
922   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
923   if (InstrumentValueSites && RegionCounterMap) {
924     auto BuilderInsertPoint = Builder.saveIP();
925     Builder.SetInsertPoint(ValueSite);
926     llvm::Value *Args[5] = {
927         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
928         Builder.getInt64(FunctionHash),
929         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
930         Builder.getInt32(ValueKind),
931         Builder.getInt32(NumValueSites[ValueKind]++)
932     };
933     Builder.CreateCall(
934         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
935     Builder.restoreIP(BuilderInsertPoint);
936     return;
937   }
938 
939   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
940   if (PGOReader && haveRegionCounts()) {
941     // We record the top most called three functions at each call site.
942     // Profile metadata contains "VP" string identifying this metadata
943     // as value profiling data, then a uint32_t value for the value profiling
944     // kind, a uint64_t value for the total number of times the call is
945     // executed, followed by the function hash and execution count (uint64_t)
946     // pairs for each function.
947     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
948       return;
949 
950     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
951                             (llvm::InstrProfValueKind)ValueKind,
952                             NumValueSites[ValueKind]);
953 
954     NumValueSites[ValueKind]++;
955   }
956 }
957 
958 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
959                                   bool IsInMainFile) {
960   CGM.getPGOStats().addVisited(IsInMainFile);
961   RegionCounts.clear();
962   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
963       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
964   if (auto E = RecordExpected.takeError()) {
965     auto IPE = llvm::InstrProfError::take(std::move(E));
966     if (IPE == llvm::instrprof_error::unknown_function)
967       CGM.getPGOStats().addMissing(IsInMainFile);
968     else if (IPE == llvm::instrprof_error::hash_mismatch)
969       CGM.getPGOStats().addMismatched(IsInMainFile);
970     else if (IPE == llvm::instrprof_error::malformed)
971       // TODO: Consider a more specific warning for this case.
972       CGM.getPGOStats().addMismatched(IsInMainFile);
973     return;
974   }
975   ProfRecord =
976       llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
977   RegionCounts = ProfRecord->Counts;
978 }
979 
/// \brief Compute the divisor used to scale weights down to 32 bits.
///
/// Given the maximum weight, the returned divisor brings every weight that
/// is no greater than \p MaxWeight to strictly less than UINT32_MAX. Weights
/// already below UINT32_MAX need no scaling, so the divisor is 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
987 
/// \brief Scale a single branch weight to 32 bits, then add 1.
///
/// Divides the 64-bit \p Weight by \p Scale to bring it into 32-bit range.
///
/// Following Laplace's Rule of Succession, the weight is computed from the
/// count plus 1, so 1 is added unconditionally (a zero count yields 1).
///
/// \pre \p Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight that is no less than \p Weight; otherwise the
/// result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1003 
1004 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1005                                                     uint64_t FalseCount) {
1006   // Check for empty weights.
1007   if (!TrueCount && !FalseCount)
1008     return nullptr;
1009 
1010   // Calculate how to scale down to 32-bits.
1011   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1012 
1013   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1014   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1015                                       scaleBranchWeight(FalseCount, Scale));
1016 }
1017 
1018 llvm::MDNode *
1019 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1020   // We need at least two elements to create meaningful weights.
1021   if (Weights.size() < 2)
1022     return nullptr;
1023 
1024   // Check for empty weights.
1025   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1026   if (MaxWeight == 0)
1027     return nullptr;
1028 
1029   // Calculate how to scale down to 32-bits.
1030   uint64_t Scale = calculateWeightScale(MaxWeight);
1031 
1032   SmallVector<uint32_t, 16> ScaledWeights;
1033   ScaledWeights.reserve(Weights.size());
1034   for (uint64_t W : Weights)
1035     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1036 
1037   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1038   return MDHelper.createBranchWeights(ScaledWeights);
1039 }
1040 
1041 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1042                                                            uint64_t LoopCount) {
1043   if (!PGO.haveRegionCounts())
1044     return nullptr;
1045   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1046   assert(CondCount.hasValue() && "missing expected loop condition count");
1047   if (*CondCount == 0)
1048     return nullptr;
1049   return createProfileWeights(LoopCount,
1050                               std::max(*CondCount, LoopCount) - LoopCount);
1051 }
1052