1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Versions of the PGO structural hash algorithm.
///
/// The enumerator values are written out explicitly so that inserting or
/// reordering entries can never silently renumber an existing version.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest version; keep it pointing at the latest entry.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59 
60 namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
///
/// Usage: call combine() once per relevant AST node, in a fixed traversal
/// order, then call finalize() to obtain the 64-bit result. Up to
/// NumTypesPerWord node types are packed directly into one 64-bit word;
/// longer sequences are folded through MD5 one word at a time.
class PGOHash {
  /// Node types accumulated since the last flush to MD5, packed
  /// NumBitsPerType bits each with the most recent type in the low bits.
  uint64_t Working;
  /// Total number of types combined so far.
  unsigned Count;
  /// The hash version requested by the creator. This class only stores and
  /// reports it; the per-version choice of types is made by the caller.
  PGOHashVersion HashVersion;
  /// Digest that receives each full packed word once the input outgrows a
  /// single word.
  llvm::MD5 MD5;

  /// Bit width of one packed HashType value.
  static const int NumBitsPerType = 6;
  /// Number of packed types that fit in one 64-bit word (64 / 6 == 10).
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  /// Exclusive upper bound on HashType values that fit in NumBitsPerType bits.
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  // Every real HashType (all values below LastHashType) must fit in
  // NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  /// Create an empty hash that records \p HashVersion for later queries.
  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Mix one AST-node type into the hash. \p Type must be non-zero and below
  /// TooBig; both conditions are asserted.
  void combine(HashType Type);
  /// Produce the 64-bit hash of all types combined so far.
  uint64_t finalize();
  /// The hash version this hasher was constructed with.
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the in-class-initialized static constants.
// Before C++17 these are required whenever the constants are odr-used.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
141 
142 /// Get the PGO hash version used in the given indexed profile.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144                                         CodeGenModule &CGM) {
145   if (PGOReader->getVersion() <= 4)
146     return PGO_HASH_V1;
147   return PGO_HASH_V2;
148 }
149 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// Counters are numbered in traversal order starting at 0 (NextCounter).
/// - A statement gets a counter when its V1 hash type is not None.
/// - The body of a function, method, block or captured declaration gets one
///   in VisitDecl.
///
/// During the same walk, node types are fed into Hash using the hash version
/// given to the constructor.
///
/// NOTE(review): the counter numbering and the order of hash inputs are
/// presumably tied to existing profile data (see the stability warning on
/// PGOHash). Treat the traversal order here as part of that contract.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are skipped as well. A CapturedDecl is walked on its
  // own by CodeGenPGO::mapRegionCounters (see also the Decl::Captured case in
  // VisitDecl).
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of every code-bearing declaration reached by the walk a
  /// fresh counter. The root declaration is visited first, so its body
  /// receives counter 0.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      // NOTE(review): for a nested declaration without a body (e.g. a
      // block-scope function declaration) D->getBody() would be null, which
      // maps nullptr to a counter. Confirm whether that can occur and is
      // harmless.
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// Counter assignment always uses the V1 classification, whatever hash
  /// version is in use.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  ///
  /// The counter is assigned with the V1 classification. The hash input is
  /// then recomputed for the active version when that version is not V1.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Custom if-statement traversal. For hash versions after V1, it marks in
  /// the hash where the "then" and "else" branches start and where the whole
  /// statement ends.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
// The base traversal's result is ignored; every hook in this visitor returns
// true, so the walk is never aborted.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Statements that have region counters are classified the same way in
  /// every version. Comparison operators, jumps, returns, throws and logical
  /// negation are classified only for PGO_HASH_V2. Everything else yields
  /// PGOHash::None.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // V2-only classifications for nodes that do not own a region counter.
    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
341 
342 /// A StmtVisitor that propagates the raw counts through the AST and
343 /// records the count at statements where the value may change.
344 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
345   /// PGO state.
346   CodeGenPGO &PGO;
347 
348   /// A flag that is set when the current count should be recorded on the
349   /// next statement, such as at the exit of a loop.
350   bool RecordNextStmtCount;
351 
352   /// The count at the current location in the traversal.
353   uint64_t CurrentCount;
354 
355   /// The map of statements to count values.
356   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357 
358   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359   struct BreakContinue {
360     uint64_t BreakCount;
361     uint64_t ContinueCount;
362     BreakContinue() : BreakCount(0), ContinueCount(0) {}
363   };
364   SmallVector<BreakContinue, 8> BreakContinueStack;
365 
366   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367                       CodeGenPGO &PGO)
368       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
369 
370   void RecordStmtCount(const Stmt *S) {
371     if (RecordNextStmtCount) {
372       CountMap[S] = CurrentCount;
373       RecordNextStmtCount = false;
374     }
375   }
376 
377   /// Set and return the current count.
378   uint64_t setCount(uint64_t Count) {
379     CurrentCount = Count;
380     return Count;
381   }
382 
383   void VisitStmt(const Stmt *S) {
384     RecordStmtCount(S);
385     for (const Stmt *Child : S->children())
386       if (Child)
387         this->Visit(Child);
388   }
389 
390   void VisitFunctionDecl(const FunctionDecl *D) {
391     // Counter tracks entry to the function body.
392     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393     CountMap[D->getBody()] = BodyCount;
394     Visit(D->getBody());
395   }
396 
397   // Skip lambda expressions. We visit these as FunctionDecls when we're
398   // generating them and aren't interested in the body when generating a
399   // parent context.
400   void VisitLambdaExpr(const LambdaExpr *LE) {}
401 
402   void VisitCapturedDecl(const CapturedDecl *D) {
403     // Counter tracks entry to the capture body.
404     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405     CountMap[D->getBody()] = BodyCount;
406     Visit(D->getBody());
407   }
408 
409   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410     // Counter tracks entry to the method body.
411     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412     CountMap[D->getBody()] = BodyCount;
413     Visit(D->getBody());
414   }
415 
416   void VisitBlockDecl(const BlockDecl *D) {
417     // Counter tracks entry to the block body.
418     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419     CountMap[D->getBody()] = BodyCount;
420     Visit(D->getBody());
421   }
422 
423   void VisitReturnStmt(const ReturnStmt *S) {
424     RecordStmtCount(S);
425     if (S->getRetValue())
426       Visit(S->getRetValue());
427     CurrentCount = 0;
428     RecordNextStmtCount = true;
429   }
430 
431   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
432     RecordStmtCount(E);
433     if (E->getSubExpr())
434       Visit(E->getSubExpr());
435     CurrentCount = 0;
436     RecordNextStmtCount = true;
437   }
438 
439   void VisitGotoStmt(const GotoStmt *S) {
440     RecordStmtCount(S);
441     CurrentCount = 0;
442     RecordNextStmtCount = true;
443   }
444 
445   void VisitLabelStmt(const LabelStmt *S) {
446     RecordNextStmtCount = false;
447     // Counter tracks the block following the label.
448     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449     CountMap[S] = BlockCount;
450     Visit(S->getSubStmt());
451   }
452 
453   void VisitBreakStmt(const BreakStmt *S) {
454     RecordStmtCount(S);
455     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456     BreakContinueStack.back().BreakCount += CurrentCount;
457     CurrentCount = 0;
458     RecordNextStmtCount = true;
459   }
460 
461   void VisitContinueStmt(const ContinueStmt *S) {
462     RecordStmtCount(S);
463     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464     BreakContinueStack.back().ContinueCount += CurrentCount;
465     CurrentCount = 0;
466     RecordNextStmtCount = true;
467   }
468 
469   void VisitWhileStmt(const WhileStmt *S) {
470     RecordStmtCount(S);
471     uint64_t ParentCount = CurrentCount;
472 
473     BreakContinueStack.push_back(BreakContinue());
474     // Visit the body region first so the break/continue adjustments can be
475     // included when visiting the condition.
476     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477     CountMap[S->getBody()] = CurrentCount;
478     Visit(S->getBody());
479     uint64_t BackedgeCount = CurrentCount;
480 
481     // ...then go back and propagate counts through the condition. The count
482     // at the start of the condition is the sum of the incoming edges,
483     // the backedge from the end of the loop body, and the edges from
484     // continue statements.
485     BreakContinue BC = BreakContinueStack.pop_back_val();
486     uint64_t CondCount =
487         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488     CountMap[S->getCond()] = CondCount;
489     Visit(S->getCond());
490     setCount(BC.BreakCount + CondCount - BodyCount);
491     RecordNextStmtCount = true;
492   }
493 
494   void VisitDoStmt(const DoStmt *S) {
495     RecordStmtCount(S);
496     uint64_t LoopCount = PGO.getRegionCount(S);
497 
498     BreakContinueStack.push_back(BreakContinue());
499     // The count doesn't include the fallthrough from the parent scope. Add it.
500     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501     CountMap[S->getBody()] = BodyCount;
502     Visit(S->getBody());
503     uint64_t BackedgeCount = CurrentCount;
504 
505     BreakContinue BC = BreakContinueStack.pop_back_val();
506     // The count at the start of the condition is equal to the count at the
507     // end of the body, plus any continues.
508     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509     CountMap[S->getCond()] = CondCount;
510     Visit(S->getCond());
511     setCount(BC.BreakCount + CondCount - LoopCount);
512     RecordNextStmtCount = true;
513   }
514 
515   void VisitForStmt(const ForStmt *S) {
516     RecordStmtCount(S);
517     if (S->getInit())
518       Visit(S->getInit());
519 
520     uint64_t ParentCount = CurrentCount;
521 
522     BreakContinueStack.push_back(BreakContinue());
523     // Visit the body region first. (This is basically the same as a while
524     // loop; see further comments in VisitWhileStmt.)
525     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526     CountMap[S->getBody()] = BodyCount;
527     Visit(S->getBody());
528     uint64_t BackedgeCount = CurrentCount;
529     BreakContinue BC = BreakContinueStack.pop_back_val();
530 
531     // The increment is essentially part of the body but it needs to include
532     // the count for all the continue statements.
533     if (S->getInc()) {
534       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535       CountMap[S->getInc()] = IncCount;
536       Visit(S->getInc());
537     }
538 
539     // ...then go back and propagate counts through the condition.
540     uint64_t CondCount =
541         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542     if (S->getCond()) {
543       CountMap[S->getCond()] = CondCount;
544       Visit(S->getCond());
545     }
546     setCount(BC.BreakCount + CondCount - BodyCount);
547     RecordNextStmtCount = true;
548   }
549 
550   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
551     RecordStmtCount(S);
552     if (S->getInit())
553       Visit(S->getInit());
554     Visit(S->getLoopVarStmt());
555     Visit(S->getRangeStmt());
556     Visit(S->getBeginStmt());
557     Visit(S->getEndStmt());
558 
559     uint64_t ParentCount = CurrentCount;
560     BreakContinueStack.push_back(BreakContinue());
561     // Visit the body region first. (This is basically the same as a while
562     // loop; see further comments in VisitWhileStmt.)
563     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564     CountMap[S->getBody()] = BodyCount;
565     Visit(S->getBody());
566     uint64_t BackedgeCount = CurrentCount;
567     BreakContinue BC = BreakContinueStack.pop_back_val();
568 
569     // The increment is essentially part of the body but it needs to include
570     // the count for all the continue statements.
571     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572     CountMap[S->getInc()] = IncCount;
573     Visit(S->getInc());
574 
575     // ...then go back and propagate counts through the condition.
576     uint64_t CondCount =
577         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578     CountMap[S->getCond()] = CondCount;
579     Visit(S->getCond());
580     setCount(BC.BreakCount + CondCount - BodyCount);
581     RecordNextStmtCount = true;
582   }
583 
584   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585     RecordStmtCount(S);
586     Visit(S->getElement());
587     uint64_t ParentCount = CurrentCount;
588     BreakContinueStack.push_back(BreakContinue());
589     // Counter tracks the body of the loop.
590     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591     CountMap[S->getBody()] = BodyCount;
592     Visit(S->getBody());
593     uint64_t BackedgeCount = CurrentCount;
594     BreakContinue BC = BreakContinueStack.pop_back_val();
595 
596     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597              BodyCount);
598     RecordNextStmtCount = true;
599   }
600 
601   void VisitSwitchStmt(const SwitchStmt *S) {
602     RecordStmtCount(S);
603     if (S->getInit())
604       Visit(S->getInit());
605     Visit(S->getCond());
606     CurrentCount = 0;
607     BreakContinueStack.push_back(BreakContinue());
608     Visit(S->getBody());
609     // If the switch is inside a loop, add the continue counts.
610     BreakContinue BC = BreakContinueStack.pop_back_val();
611     if (!BreakContinueStack.empty())
612       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613     // Counter tracks the exit block of the switch.
614     setCount(PGO.getRegionCount(S));
615     RecordNextStmtCount = true;
616   }
617 
618   void VisitSwitchCase(const SwitchCase *S) {
619     RecordNextStmtCount = false;
620     // Counter for this particular case. This counts only jumps from the
621     // switch header and does not include fallthrough from the case before
622     // this one.
623     uint64_t CaseCount = PGO.getRegionCount(S);
624     setCount(CurrentCount + CaseCount);
625     // We need the count without fallthrough in the mapping, so it's more useful
626     // for branch probabilities.
627     CountMap[S] = CaseCount;
628     RecordNextStmtCount = true;
629     Visit(S->getSubStmt());
630   }
631 
632   void VisitIfStmt(const IfStmt *S) {
633     RecordStmtCount(S);
634     uint64_t ParentCount = CurrentCount;
635     if (S->getInit())
636       Visit(S->getInit());
637     Visit(S->getCond());
638 
639     // Counter tracks the "then" part of an if statement. The count for
640     // the "else" part, if it exists, will be calculated from this counter.
641     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642     CountMap[S->getThen()] = ThenCount;
643     Visit(S->getThen());
644     uint64_t OutCount = CurrentCount;
645 
646     uint64_t ElseCount = ParentCount - ThenCount;
647     if (S->getElse()) {
648       setCount(ElseCount);
649       CountMap[S->getElse()] = ElseCount;
650       Visit(S->getElse());
651       OutCount += CurrentCount;
652     } else
653       OutCount += ElseCount;
654     setCount(OutCount);
655     RecordNextStmtCount = true;
656   }
657 
658   void VisitCXXTryStmt(const CXXTryStmt *S) {
659     RecordStmtCount(S);
660     Visit(S->getTryBlock());
661     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
662       Visit(S->getHandler(I));
663     // Counter tracks the continuation block of the try statement.
664     setCount(PGO.getRegionCount(S));
665     RecordNextStmtCount = true;
666   }
667 
668   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669     RecordNextStmtCount = false;
670     // Counter tracks the catch statement's handler block.
671     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672     CountMap[S] = CatchCount;
673     Visit(S->getHandlerBlock());
674   }
675 
676   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677     RecordStmtCount(E);
678     uint64_t ParentCount = CurrentCount;
679     Visit(E->getCond());
680 
681     // Counter tracks the "true" part of a conditional operator. The
682     // count in the "false" part will be calculated from this counter.
683     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684     CountMap[E->getTrueExpr()] = TrueCount;
685     Visit(E->getTrueExpr());
686     uint64_t OutCount = CurrentCount;
687 
688     uint64_t FalseCount = setCount(ParentCount - TrueCount);
689     CountMap[E->getFalseExpr()] = FalseCount;
690     Visit(E->getFalseExpr());
691     OutCount += CurrentCount;
692 
693     setCount(OutCount);
694     RecordNextStmtCount = true;
695   }
696 
697   void VisitBinLAnd(const BinaryOperator *E) {
698     RecordStmtCount(E);
699     uint64_t ParentCount = CurrentCount;
700     Visit(E->getLHS());
701     // Counter tracks the right hand side of a logical and operator.
702     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703     CountMap[E->getRHS()] = RHSCount;
704     Visit(E->getRHS());
705     setCount(ParentCount + RHSCount - CurrentCount);
706     RecordNextStmtCount = true;
707   }
708 
709   void VisitBinLOr(const BinaryOperator *E) {
710     RecordStmtCount(E);
711     uint64_t ParentCount = CurrentCount;
712     Visit(E->getLHS());
713     // Counter tracks the right hand side of a logical or operator.
714     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715     CountMap[E->getRHS()] = RHSCount;
716     Visit(E->getRHS());
717     setCount(ParentCount + RHSCount - CurrentCount);
718     RecordNextStmtCount = true;
719   }
720 };
721 } // end anonymous namespace
722 
723 void PGOHash::combine(HashType Type) {
724   // Check that we never combine 0 and only have six bits.
725   assert(Type && "Hash is invalid: unexpected type 0");
726   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
727 
728   // Pass through MD5 if enough work has built up.
729   if (Count && Count % NumTypesPerWord == 0) {
730     using namespace llvm::support;
731     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
732     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
733     Working = 0;
734   }
735 
736   // Accumulate the current type.
737   ++Count;
738   Working = Working << NumBitsPerType | Type;
739 }
740 
741 uint64_t PGOHash::finalize() {
742   // Use Working as the hash directly if we never used MD5.
743   if (Count <= NumTypesPerWord)
744     // No need to byte swap here, since none of the math was endian-dependent.
745     // This number will be byte-swapped as required on endianness transitions,
746     // so we will see the same value on the other side.
747     return Working;
748 
749   // Check for remaining work in Working.
750   if (Working) {
751     using namespace llvm::support;
752     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
753     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
754   }
755 
756   // Finalize the MD5 and return the hash.
757   llvm::MD5::MD5Result Result;
758   MD5.final(Result);
759   return Result.low();
760 }
761 
762 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
763   const Decl *D = GD.getDecl();
764   if (!D->hasBody())
765     return;
766 
767   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
768   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
769   if (!InstrumentRegions && !PGOReader)
770     return;
771   if (D->isImplicit())
772     return;
773   // Constructors and destructors may be represented by several functions in IR.
774   // If so, instrument only base variant, others are implemented by delegation
775   // to the base one, it would be counted twice otherwise.
776   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
777     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
778       if (GD.getCtorType() != Ctor_Base &&
779           CodeGenFunction::IsConstructorDelegationValid(CCD))
780         return;
781   }
782   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
783     return;
784 
785   CGM.ClearUnusedCoverageMapping(D);
786   setFuncName(Fn);
787 
788   mapRegionCounters(D);
789   if (CGM.getCodeGenOpts().CoverageMapping)
790     emitCounterRegionMapping(D);
791   if (PGOReader) {
792     SourceManager &SM = CGM.getContext().getSourceManager();
793     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
794     computeRegionCounts(D);
795     applyFunctionAttributes(PGOReader, Fn);
796   }
797 }
798 
799 void CodeGenPGO::mapRegionCounters(const Decl *D) {
800   // Use the latest hash version when inserting instrumentation, but use the
801   // version in the indexed profile if we're reading PGO data.
802   PGOHashVersion HashVersion = PGO_HASH_LATEST;
803   if (auto *PGOReader = CGM.getPGOReader())
804     HashVersion = getPGOHashVersion(PGOReader, CGM);
805 
806   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
807   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
808   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
809     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
810   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
811     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
812   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
813     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
814   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
815     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
816   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
817   NumRegionCounters = Walker.NextCounter;
818   FunctionHash = Walker.Hash.finalize();
819 }
820 
821 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
822   if (!D->getBody())
823     return true;
824 
825   // Don't map the functions in system headers.
826   const auto &SM = CGM.getContext().getSourceManager();
827   auto Loc = D->getBody()->getBeginLoc();
828   return SM.isInSystemHeader(Loc);
829 }
830 
831 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
832   if (skipRegionMappingForDecl(D))
833     return;
834 
835   std::string CoverageMapping;
836   llvm::raw_string_ostream OS(CoverageMapping);
837   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
838                                 CGM.getContext().getSourceManager(),
839                                 CGM.getLangOpts(), RegionCounterMap.get());
840   MappingGen.emitCounterMapping(D, OS);
841   OS.flush();
842 
843   if (CoverageMapping.empty())
844     return;
845 
846   CGM.getCoverageMapping()->addFunctionMappingRecord(
847       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
848 }
849 
850 void
851 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
852                                     llvm::GlobalValue::LinkageTypes Linkage) {
853   if (skipRegionMappingForDecl(D))
854     return;
855 
856   std::string CoverageMapping;
857   llvm::raw_string_ostream OS(CoverageMapping);
858   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
859                                 CGM.getContext().getSourceManager(),
860                                 CGM.getLangOpts());
861   MappingGen.emitEmptyMapping(D, OS);
862   OS.flush();
863 
864   if (CoverageMapping.empty())
865     return;
866 
867   setFuncName(Name, Linkage);
868   CGM.getCoverageMapping()->addFunctionMappingRecord(
869       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
870 }
871 
872 void CodeGenPGO::computeRegionCounts(const Decl *D) {
873   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
874   ComputeRegionCounts Walker(*StmtCountMap, *this);
875   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
876     Walker.VisitFunctionDecl(FD);
877   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
878     Walker.VisitObjCMethodDecl(MD);
879   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
880     Walker.VisitBlockDecl(BD);
881   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
882     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
883 }
884 
885 void
886 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
887                                     llvm::Function *Fn) {
888   if (!haveRegionCounts())
889     return;
890 
891   uint64_t FunctionCount = getRegionCount(nullptr);
892   Fn->setEntryCount(FunctionCount);
893 }
894 
895 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
896                                       llvm::Value *StepV) {
897   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
898     return;
899   if (!Builder.GetInsertBlock())
900     return;
901 
902   unsigned Counter = (*RegionCounterMap)[S];
903   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
904 
905   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
906                          Builder.getInt64(FunctionHash),
907                          Builder.getInt32(NumRegionCounters),
908                          Builder.getInt32(Counter), StepV};
909   if (!StepV)
910     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
911                        makeArrayRef(Args, 4));
912   else
913     Builder.CreateCall(
914         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
915         makeArrayRef(Args));
916 }
917 
918 // This method either inserts a call to the profile run-time during
919 // instrumentation or puts profile data into metadata for PGO use.
920 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
921     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
922 
923   if (!EnableValueProfiling)
924     return;
925 
926   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
927     return;
928 
929   if (isa<llvm::Constant>(ValuePtr))
930     return;
931 
932   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
933   if (InstrumentValueSites && RegionCounterMap) {
934     auto BuilderInsertPoint = Builder.saveIP();
935     Builder.SetInsertPoint(ValueSite);
936     llvm::Value *Args[5] = {
937         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
938         Builder.getInt64(FunctionHash),
939         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
940         Builder.getInt32(ValueKind),
941         Builder.getInt32(NumValueSites[ValueKind]++)
942     };
943     Builder.CreateCall(
944         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
945     Builder.restoreIP(BuilderInsertPoint);
946     return;
947   }
948 
949   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
950   if (PGOReader && haveRegionCounts()) {
951     // We record the top most called three functions at each call site.
952     // Profile metadata contains "VP" string identifying this metadata
953     // as value profiling data, then a uint32_t value for the value profiling
954     // kind, a uint64_t value for the total number of times the call is
955     // executed, followed by the function hash and execution count (uint64_t)
956     // pairs for each function.
957     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
958       return;
959 
960     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
961                             (llvm::InstrProfValueKind)ValueKind,
962                             NumValueSites[ValueKind]);
963 
964     NumValueSites[ValueKind]++;
965   }
966 }
967 
968 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
969                                   bool IsInMainFile) {
970   CGM.getPGOStats().addVisited(IsInMainFile);
971   RegionCounts.clear();
972   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
973       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
974   if (auto E = RecordExpected.takeError()) {
975     auto IPE = llvm::InstrProfError::take(std::move(E));
976     if (IPE == llvm::instrprof_error::unknown_function)
977       CGM.getPGOStats().addMissing(IsInMainFile);
978     else if (IPE == llvm::instrprof_error::hash_mismatch)
979       CGM.getPGOStats().addMismatched(IsInMainFile);
980     else if (IPE == llvm::instrprof_error::malformed)
981       // TODO: Consider a more specific warning for this case.
982       CGM.getPGOStats().addMismatched(IsInMainFile);
983     return;
984   }
985   ProfRecord =
986       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
987   RegionCounts = ProfRecord->Counts;
988 }
989 
/// Compute the divisor used to bring 64-bit weights into 32-bit range.
///
/// For a maximum weight \p MaxWeight, the returned divisor D satisfies
/// MaxWeight / D < UINT32_MAX, so every weight no larger than MaxWeight
/// scales to strictly less than UINT32_MAX. Weights already below UINT32_MAX
/// are left unscaled (divisor 1).
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
997 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (typically the maximum of all weights being scaled);
/// otherwise the scaled value may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The precondition guarantees the value fits; make the narrowing explicit.
  return static_cast<uint32_t>(Scaled);
}
1013 
1014 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1015                                                     uint64_t FalseCount) {
1016   // Check for empty weights.
1017   if (!TrueCount && !FalseCount)
1018     return nullptr;
1019 
1020   // Calculate how to scale down to 32-bits.
1021   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1022 
1023   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1024   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1025                                       scaleBranchWeight(FalseCount, Scale));
1026 }
1027 
1028 llvm::MDNode *
1029 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1030   // We need at least two elements to create meaningful weights.
1031   if (Weights.size() < 2)
1032     return nullptr;
1033 
1034   // Check for empty weights.
1035   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1036   if (MaxWeight == 0)
1037     return nullptr;
1038 
1039   // Calculate how to scale down to 32-bits.
1040   uint64_t Scale = calculateWeightScale(MaxWeight);
1041 
1042   SmallVector<uint32_t, 16> ScaledWeights;
1043   ScaledWeights.reserve(Weights.size());
1044   for (uint64_t W : Weights)
1045     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1046 
1047   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1048   return MDHelper.createBranchWeights(ScaledWeights);
1049 }
1050 
1051 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1052                                                            uint64_t LoopCount) {
1053   if (!PGO.haveRegionCounts())
1054     return nullptr;
1055   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1056   if (!CondCount || *CondCount == 0)
1057     return nullptr;
1058   return createProfileWeights(LoopCount,
1059                               std::max(*CondCount, LoopCount) - LoopCount);
1060 }
1061