1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool> EnableValueProfiling(
26   "enable-value-profiling", llvm::cl::ZeroOrMore,
27   llvm::cl::desc("Enable value profiling"), llvm::cl::init(false));
28 
29 using namespace clang;
30 using namespace CodeGen;
31 
32 void CodeGenPGO::setFuncName(StringRef Name,
33                              llvm::GlobalValue::LinkageTypes Linkage) {
34   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35   FuncName = llvm::getPGOFuncName(
36       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
38 
39   // If we're generating a profile, create a variable for the name.
40   if (CGM.getCodeGenOpts().hasProfileClangInstr())
41     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
42 }
43 
44 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45   setFuncName(Fn->getName(), Fn->getLinkage());
46 }
47 
48 namespace {
49 /// \brief Stable hasher for PGO region counters.
50 ///
51 /// PGOHash produces a stable hash of a given function's control flow.
52 ///
53 /// Changing the output of this hash will invalidate all previously generated
54 /// profiles -- i.e., don't do it.
55 ///
56 /// \note  When this hash does eventually change (years?), we still need to
57 /// support old hashes.  We'll need to pull in the version number from the
58 /// profile data format and use the matching hash function.
59 class PGOHash {
60   uint64_t Working;
61   unsigned Count;
62   llvm::MD5 MD5;
63 
64   static const int NumBitsPerType = 6;
65   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
66   static const unsigned TooBig = 1u << NumBitsPerType;
67 
68 public:
69   /// \brief Hash values for AST nodes.
70   ///
71   /// Distinct values for AST nodes that have region counters attached.
72   ///
73   /// These values must be stable.  All new members must be added at the end,
74   /// and no members should be removed.  Changing the enumeration value for an
75   /// AST node will affect the hash of every function that contains that node.
76   enum HashType : unsigned char {
77     None = 0,
78     LabelStmt = 1,
79     WhileStmt,
80     DoStmt,
81     ForStmt,
82     CXXForRangeStmt,
83     ObjCForCollectionStmt,
84     SwitchStmt,
85     CaseStmt,
86     DefaultStmt,
87     IfStmt,
88     CXXTryStmt,
89     CXXCatchStmt,
90     ConditionalOperator,
91     BinaryOperatorLAnd,
92     BinaryOperatorLOr,
93     BinaryConditionalOperator,
94 
95     // Keep this last.  It's for the static assert that follows.
96     LastHashType
97   };
98   static_assert(LastHashType <= TooBig, "Too many types in HashType");
99 
100   // TODO: When this format changes, take in a version number here, and use the
101   // old hash calculation for file formats that used the old hash.
102   PGOHash() : Working(0), Count(0) {}
103   void combine(HashType Type);
104   uint64_t finalize();
105 };
106 const int PGOHash::NumBitsPerType;
107 const unsigned PGOHash::NumTypesPerWord;
108 const unsigned PGOHash::TooBig;
109 
110 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
111 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
112   /// The next counter value to assign.
113   unsigned NextCounter;
114   /// The function hash.
115   PGOHash Hash;
116   /// The map of statements to counters.
117   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
118 
119   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
120       : NextCounter(0), CounterMap(CounterMap) {}
121 
122   // Blocks and lambdas are handled as separate functions, so we need not
123   // traverse them in the parent context.
124   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
125   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
126   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
127 
128   bool VisitDecl(const Decl *D) {
129     switch (D->getKind()) {
130     default:
131       break;
132     case Decl::Function:
133     case Decl::CXXMethod:
134     case Decl::CXXConstructor:
135     case Decl::CXXDestructor:
136     case Decl::CXXConversion:
137     case Decl::ObjCMethod:
138     case Decl::Block:
139     case Decl::Captured:
140       CounterMap[D->getBody()] = NextCounter++;
141       break;
142     }
143     return true;
144   }
145 
146   bool VisitStmt(const Stmt *S) {
147     auto Type = getHashType(S);
148     if (Type == PGOHash::None)
149       return true;
150 
151     CounterMap[S] = NextCounter++;
152     Hash.combine(Type);
153     return true;
154   }
155   PGOHash::HashType getHashType(const Stmt *S) {
156     switch (S->getStmtClass()) {
157     default:
158       break;
159     case Stmt::LabelStmtClass:
160       return PGOHash::LabelStmt;
161     case Stmt::WhileStmtClass:
162       return PGOHash::WhileStmt;
163     case Stmt::DoStmtClass:
164       return PGOHash::DoStmt;
165     case Stmt::ForStmtClass:
166       return PGOHash::ForStmt;
167     case Stmt::CXXForRangeStmtClass:
168       return PGOHash::CXXForRangeStmt;
169     case Stmt::ObjCForCollectionStmtClass:
170       return PGOHash::ObjCForCollectionStmt;
171     case Stmt::SwitchStmtClass:
172       return PGOHash::SwitchStmt;
173     case Stmt::CaseStmtClass:
174       return PGOHash::CaseStmt;
175     case Stmt::DefaultStmtClass:
176       return PGOHash::DefaultStmt;
177     case Stmt::IfStmtClass:
178       return PGOHash::IfStmt;
179     case Stmt::CXXTryStmtClass:
180       return PGOHash::CXXTryStmt;
181     case Stmt::CXXCatchStmtClass:
182       return PGOHash::CXXCatchStmt;
183     case Stmt::ConditionalOperatorClass:
184       return PGOHash::ConditionalOperator;
185     case Stmt::BinaryConditionalOperatorClass:
186       return PGOHash::BinaryConditionalOperator;
187     case Stmt::BinaryOperatorClass: {
188       const BinaryOperator *BO = cast<BinaryOperator>(S);
189       if (BO->getOpcode() == BO_LAnd)
190         return PGOHash::BinaryOperatorLAnd;
191       if (BO->getOpcode() == BO_LOr)
192         return PGOHash::BinaryOperatorLOr;
193       break;
194     }
195     }
196     return PGOHash::None;
197   }
198 };
199 
200 /// A StmtVisitor that propagates the raw counts through the AST and
201 /// records the count at statements where the value may change.
202 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
203   /// PGO state.
204   CodeGenPGO &PGO;
205 
206   /// A flag that is set when the current count should be recorded on the
207   /// next statement, such as at the exit of a loop.
208   bool RecordNextStmtCount;
209 
210   /// The count at the current location in the traversal.
211   uint64_t CurrentCount;
212 
213   /// The map of statements to count values.
214   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
215 
216   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
217   struct BreakContinue {
218     uint64_t BreakCount;
219     uint64_t ContinueCount;
220     BreakContinue() : BreakCount(0), ContinueCount(0) {}
221   };
222   SmallVector<BreakContinue, 8> BreakContinueStack;
223 
224   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
225                       CodeGenPGO &PGO)
226       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
227 
228   void RecordStmtCount(const Stmt *S) {
229     if (RecordNextStmtCount) {
230       CountMap[S] = CurrentCount;
231       RecordNextStmtCount = false;
232     }
233   }
234 
235   /// Set and return the current count.
236   uint64_t setCount(uint64_t Count) {
237     CurrentCount = Count;
238     return Count;
239   }
240 
241   void VisitStmt(const Stmt *S) {
242     RecordStmtCount(S);
243     for (const Stmt *Child : S->children())
244       if (Child)
245         this->Visit(Child);
246   }
247 
248   void VisitFunctionDecl(const FunctionDecl *D) {
249     // Counter tracks entry to the function body.
250     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
251     CountMap[D->getBody()] = BodyCount;
252     Visit(D->getBody());
253   }
254 
255   // Skip lambda expressions. We visit these as FunctionDecls when we're
256   // generating them and aren't interested in the body when generating a
257   // parent context.
258   void VisitLambdaExpr(const LambdaExpr *LE) {}
259 
260   void VisitCapturedDecl(const CapturedDecl *D) {
261     // Counter tracks entry to the capture body.
262     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
263     CountMap[D->getBody()] = BodyCount;
264     Visit(D->getBody());
265   }
266 
267   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
268     // Counter tracks entry to the method body.
269     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
270     CountMap[D->getBody()] = BodyCount;
271     Visit(D->getBody());
272   }
273 
274   void VisitBlockDecl(const BlockDecl *D) {
275     // Counter tracks entry to the block body.
276     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
277     CountMap[D->getBody()] = BodyCount;
278     Visit(D->getBody());
279   }
280 
281   void VisitReturnStmt(const ReturnStmt *S) {
282     RecordStmtCount(S);
283     if (S->getRetValue())
284       Visit(S->getRetValue());
285     CurrentCount = 0;
286     RecordNextStmtCount = true;
287   }
288 
289   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
290     RecordStmtCount(E);
291     if (E->getSubExpr())
292       Visit(E->getSubExpr());
293     CurrentCount = 0;
294     RecordNextStmtCount = true;
295   }
296 
297   void VisitGotoStmt(const GotoStmt *S) {
298     RecordStmtCount(S);
299     CurrentCount = 0;
300     RecordNextStmtCount = true;
301   }
302 
303   void VisitLabelStmt(const LabelStmt *S) {
304     RecordNextStmtCount = false;
305     // Counter tracks the block following the label.
306     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
307     CountMap[S] = BlockCount;
308     Visit(S->getSubStmt());
309   }
310 
311   void VisitBreakStmt(const BreakStmt *S) {
312     RecordStmtCount(S);
313     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
314     BreakContinueStack.back().BreakCount += CurrentCount;
315     CurrentCount = 0;
316     RecordNextStmtCount = true;
317   }
318 
319   void VisitContinueStmt(const ContinueStmt *S) {
320     RecordStmtCount(S);
321     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
322     BreakContinueStack.back().ContinueCount += CurrentCount;
323     CurrentCount = 0;
324     RecordNextStmtCount = true;
325   }
326 
327   void VisitWhileStmt(const WhileStmt *S) {
328     RecordStmtCount(S);
329     uint64_t ParentCount = CurrentCount;
330 
331     BreakContinueStack.push_back(BreakContinue());
332     // Visit the body region first so the break/continue adjustments can be
333     // included when visiting the condition.
334     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
335     CountMap[S->getBody()] = CurrentCount;
336     Visit(S->getBody());
337     uint64_t BackedgeCount = CurrentCount;
338 
339     // ...then go back and propagate counts through the condition. The count
340     // at the start of the condition is the sum of the incoming edges,
341     // the backedge from the end of the loop body, and the edges from
342     // continue statements.
343     BreakContinue BC = BreakContinueStack.pop_back_val();
344     uint64_t CondCount =
345         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
346     CountMap[S->getCond()] = CondCount;
347     Visit(S->getCond());
348     setCount(BC.BreakCount + CondCount - BodyCount);
349     RecordNextStmtCount = true;
350   }
351 
352   void VisitDoStmt(const DoStmt *S) {
353     RecordStmtCount(S);
354     uint64_t LoopCount = PGO.getRegionCount(S);
355 
356     BreakContinueStack.push_back(BreakContinue());
357     // The count doesn't include the fallthrough from the parent scope. Add it.
358     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
359     CountMap[S->getBody()] = BodyCount;
360     Visit(S->getBody());
361     uint64_t BackedgeCount = CurrentCount;
362 
363     BreakContinue BC = BreakContinueStack.pop_back_val();
364     // The count at the start of the condition is equal to the count at the
365     // end of the body, plus any continues.
366     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
367     CountMap[S->getCond()] = CondCount;
368     Visit(S->getCond());
369     setCount(BC.BreakCount + CondCount - LoopCount);
370     RecordNextStmtCount = true;
371   }
372 
373   void VisitForStmt(const ForStmt *S) {
374     RecordStmtCount(S);
375     if (S->getInit())
376       Visit(S->getInit());
377 
378     uint64_t ParentCount = CurrentCount;
379 
380     BreakContinueStack.push_back(BreakContinue());
381     // Visit the body region first. (This is basically the same as a while
382     // loop; see further comments in VisitWhileStmt.)
383     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
384     CountMap[S->getBody()] = BodyCount;
385     Visit(S->getBody());
386     uint64_t BackedgeCount = CurrentCount;
387     BreakContinue BC = BreakContinueStack.pop_back_val();
388 
389     // The increment is essentially part of the body but it needs to include
390     // the count for all the continue statements.
391     if (S->getInc()) {
392       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
393       CountMap[S->getInc()] = IncCount;
394       Visit(S->getInc());
395     }
396 
397     // ...then go back and propagate counts through the condition.
398     uint64_t CondCount =
399         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
400     if (S->getCond()) {
401       CountMap[S->getCond()] = CondCount;
402       Visit(S->getCond());
403     }
404     setCount(BC.BreakCount + CondCount - BodyCount);
405     RecordNextStmtCount = true;
406   }
407 
408   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
409     RecordStmtCount(S);
410     Visit(S->getLoopVarStmt());
411     Visit(S->getRangeStmt());
412     Visit(S->getBeginStmt());
413     Visit(S->getEndStmt());
414 
415     uint64_t ParentCount = CurrentCount;
416     BreakContinueStack.push_back(BreakContinue());
417     // Visit the body region first. (This is basically the same as a while
418     // loop; see further comments in VisitWhileStmt.)
419     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
420     CountMap[S->getBody()] = BodyCount;
421     Visit(S->getBody());
422     uint64_t BackedgeCount = CurrentCount;
423     BreakContinue BC = BreakContinueStack.pop_back_val();
424 
425     // The increment is essentially part of the body but it needs to include
426     // the count for all the continue statements.
427     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
428     CountMap[S->getInc()] = IncCount;
429     Visit(S->getInc());
430 
431     // ...then go back and propagate counts through the condition.
432     uint64_t CondCount =
433         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
434     CountMap[S->getCond()] = CondCount;
435     Visit(S->getCond());
436     setCount(BC.BreakCount + CondCount - BodyCount);
437     RecordNextStmtCount = true;
438   }
439 
440   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
441     RecordStmtCount(S);
442     Visit(S->getElement());
443     uint64_t ParentCount = CurrentCount;
444     BreakContinueStack.push_back(BreakContinue());
445     // Counter tracks the body of the loop.
446     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
447     CountMap[S->getBody()] = BodyCount;
448     Visit(S->getBody());
449     uint64_t BackedgeCount = CurrentCount;
450     BreakContinue BC = BreakContinueStack.pop_back_val();
451 
452     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
453              BodyCount);
454     RecordNextStmtCount = true;
455   }
456 
457   void VisitSwitchStmt(const SwitchStmt *S) {
458     RecordStmtCount(S);
459     Visit(S->getCond());
460     CurrentCount = 0;
461     BreakContinueStack.push_back(BreakContinue());
462     Visit(S->getBody());
463     // If the switch is inside a loop, add the continue counts.
464     BreakContinue BC = BreakContinueStack.pop_back_val();
465     if (!BreakContinueStack.empty())
466       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
467     // Counter tracks the exit block of the switch.
468     setCount(PGO.getRegionCount(S));
469     RecordNextStmtCount = true;
470   }
471 
472   void VisitSwitchCase(const SwitchCase *S) {
473     RecordNextStmtCount = false;
474     // Counter for this particular case. This counts only jumps from the
475     // switch header and does not include fallthrough from the case before
476     // this one.
477     uint64_t CaseCount = PGO.getRegionCount(S);
478     setCount(CurrentCount + CaseCount);
479     // We need the count without fallthrough in the mapping, so it's more useful
480     // for branch probabilities.
481     CountMap[S] = CaseCount;
482     RecordNextStmtCount = true;
483     Visit(S->getSubStmt());
484   }
485 
486   void VisitIfStmt(const IfStmt *S) {
487     RecordStmtCount(S);
488     uint64_t ParentCount = CurrentCount;
489     Visit(S->getCond());
490 
491     // Counter tracks the "then" part of an if statement. The count for
492     // the "else" part, if it exists, will be calculated from this counter.
493     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
494     CountMap[S->getThen()] = ThenCount;
495     Visit(S->getThen());
496     uint64_t OutCount = CurrentCount;
497 
498     uint64_t ElseCount = ParentCount - ThenCount;
499     if (S->getElse()) {
500       setCount(ElseCount);
501       CountMap[S->getElse()] = ElseCount;
502       Visit(S->getElse());
503       OutCount += CurrentCount;
504     } else
505       OutCount += ElseCount;
506     setCount(OutCount);
507     RecordNextStmtCount = true;
508   }
509 
510   void VisitCXXTryStmt(const CXXTryStmt *S) {
511     RecordStmtCount(S);
512     Visit(S->getTryBlock());
513     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
514       Visit(S->getHandler(I));
515     // Counter tracks the continuation block of the try statement.
516     setCount(PGO.getRegionCount(S));
517     RecordNextStmtCount = true;
518   }
519 
520   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
521     RecordNextStmtCount = false;
522     // Counter tracks the catch statement's handler block.
523     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
524     CountMap[S] = CatchCount;
525     Visit(S->getHandlerBlock());
526   }
527 
528   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
529     RecordStmtCount(E);
530     uint64_t ParentCount = CurrentCount;
531     Visit(E->getCond());
532 
533     // Counter tracks the "true" part of a conditional operator. The
534     // count in the "false" part will be calculated from this counter.
535     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
536     CountMap[E->getTrueExpr()] = TrueCount;
537     Visit(E->getTrueExpr());
538     uint64_t OutCount = CurrentCount;
539 
540     uint64_t FalseCount = setCount(ParentCount - TrueCount);
541     CountMap[E->getFalseExpr()] = FalseCount;
542     Visit(E->getFalseExpr());
543     OutCount += CurrentCount;
544 
545     setCount(OutCount);
546     RecordNextStmtCount = true;
547   }
548 
549   void VisitBinLAnd(const BinaryOperator *E) {
550     RecordStmtCount(E);
551     uint64_t ParentCount = CurrentCount;
552     Visit(E->getLHS());
553     // Counter tracks the right hand side of a logical and operator.
554     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
555     CountMap[E->getRHS()] = RHSCount;
556     Visit(E->getRHS());
557     setCount(ParentCount + RHSCount - CurrentCount);
558     RecordNextStmtCount = true;
559   }
560 
561   void VisitBinLOr(const BinaryOperator *E) {
562     RecordStmtCount(E);
563     uint64_t ParentCount = CurrentCount;
564     Visit(E->getLHS());
565     // Counter tracks the right hand side of a logical or operator.
566     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
567     CountMap[E->getRHS()] = RHSCount;
568     Visit(E->getRHS());
569     setCount(ParentCount + RHSCount - CurrentCount);
570     RecordNextStmtCount = true;
571   }
572 };
573 } // end anonymous namespace
574 
575 void PGOHash::combine(HashType Type) {
576   // Check that we never combine 0 and only have six bits.
577   assert(Type && "Hash is invalid: unexpected type 0");
578   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
579 
580   // Pass through MD5 if enough work has built up.
581   if (Count && Count % NumTypesPerWord == 0) {
582     using namespace llvm::support;
583     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
584     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
585     Working = 0;
586   }
587 
588   // Accumulate the current type.
589   ++Count;
590   Working = Working << NumBitsPerType | Type;
591 }
592 
593 uint64_t PGOHash::finalize() {
594   // Use Working as the hash directly if we never used MD5.
595   if (Count <= NumTypesPerWord)
596     // No need to byte swap here, since none of the math was endian-dependent.
597     // This number will be byte-swapped as required on endianness transitions,
598     // so we will see the same value on the other side.
599     return Working;
600 
601   // Check for remaining work in Working.
602   if (Working)
603     MD5.update(Working);
604 
605   // Finalize the MD5 and return the hash.
606   llvm::MD5::MD5Result Result;
607   MD5.final(Result);
608   using namespace llvm::support;
609   return endian::read<uint64_t, little, unaligned>(Result);
610 }
611 
612 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
613   const Decl *D = GD.getDecl();
614   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
615   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
616   if (!InstrumentRegions && !PGOReader)
617     return;
618   if (D->isImplicit())
619     return;
620   // Constructors and destructors may be represented by several functions in IR.
621   // If so, instrument only base variant, others are implemented by delegation
622   // to the base one, it would be counted twice otherwise.
623   if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
624       ((isa<CXXConstructorDecl>(GD.getDecl()) &&
625         GD.getCtorType() != Ctor_Base) ||
626        (isa<CXXDestructorDecl>(GD.getDecl()) &&
627         GD.getDtorType() != Dtor_Base))) {
628       return;
629   }
630   CGM.ClearUnusedCoverageMapping(D);
631   setFuncName(Fn);
632 
633   mapRegionCounters(D);
634   if (CGM.getCodeGenOpts().CoverageMapping)
635     emitCounterRegionMapping(D);
636   if (PGOReader) {
637     SourceManager &SM = CGM.getContext().getSourceManager();
638     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
639     computeRegionCounts(D);
640     applyFunctionAttributes(PGOReader, Fn);
641   }
642 }
643 
644 void CodeGenPGO::mapRegionCounters(const Decl *D) {
645   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
646   MapRegionCounters Walker(*RegionCounterMap);
647   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
648     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
649   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
650     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
651   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
652     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
653   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
654     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
655   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
656   NumRegionCounters = Walker.NextCounter;
657   FunctionHash = Walker.Hash.finalize();
658 }
659 
660 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
661   if (SkipCoverageMapping)
662     return;
663   // Don't map the functions inside the system headers
664   auto Loc = D->getBody()->getLocStart();
665   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
666     return;
667 
668   std::string CoverageMapping;
669   llvm::raw_string_ostream OS(CoverageMapping);
670   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
671                                 CGM.getContext().getSourceManager(),
672                                 CGM.getLangOpts(), RegionCounterMap.get());
673   MappingGen.emitCounterMapping(D, OS);
674   OS.flush();
675 
676   if (CoverageMapping.empty())
677     return;
678 
679   CGM.getCoverageMapping()->addFunctionMappingRecord(
680       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
681 }
682 
683 void
684 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
685                                     llvm::GlobalValue::LinkageTypes Linkage) {
686   if (SkipCoverageMapping)
687     return;
688   // Don't map the functions inside the system headers
689   auto Loc = D->getBody()->getLocStart();
690   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
691     return;
692 
693   std::string CoverageMapping;
694   llvm::raw_string_ostream OS(CoverageMapping);
695   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
696                                 CGM.getContext().getSourceManager(),
697                                 CGM.getLangOpts());
698   MappingGen.emitEmptyMapping(D, OS);
699   OS.flush();
700 
701   if (CoverageMapping.empty())
702     return;
703 
704   setFuncName(Name, Linkage);
705   CGM.getCoverageMapping()->addFunctionMappingRecord(
706       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
707 }
708 
709 void CodeGenPGO::computeRegionCounts(const Decl *D) {
710   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
711   ComputeRegionCounts Walker(*StmtCountMap, *this);
712   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
713     Walker.VisitFunctionDecl(FD);
714   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
715     Walker.VisitObjCMethodDecl(MD);
716   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
717     Walker.VisitBlockDecl(BD);
718   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
719     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
720 }
721 
722 void
723 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
724                                     llvm::Function *Fn) {
725   if (!haveRegionCounts())
726     return;
727 
728   uint64_t FunctionCount = getRegionCount(nullptr);
729   Fn->setEntryCount(FunctionCount);
730 }
731 
732 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
733   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
734     return;
735   if (!Builder.GetInsertBlock())
736     return;
737 
738   unsigned Counter = (*RegionCounterMap)[S];
739   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
740   Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
741                      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
742                       Builder.getInt64(FunctionHash),
743                       Builder.getInt32(NumRegionCounters),
744                       Builder.getInt32(Counter)});
745 }
746 
747 // This method either inserts a call to the profile run-time during
748 // instrumentation or puts profile data into metadata for PGO use.
749 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
750     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
751 
752   if (!EnableValueProfiling)
753     return;
754 
755   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
756     return;
757 
758   if (isa<llvm::Constant>(ValuePtr))
759     return;
760 
761   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
762   if (InstrumentValueSites && RegionCounterMap) {
763     auto BuilderInsertPoint = Builder.saveIP();
764     Builder.SetInsertPoint(ValueSite);
765     llvm::Value *Args[5] = {
766         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
767         Builder.getInt64(FunctionHash),
768         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
769         Builder.getInt32(ValueKind),
770         Builder.getInt32(NumValueSites[ValueKind]++)
771     };
772     Builder.CreateCall(
773         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
774     Builder.restoreIP(BuilderInsertPoint);
775     return;
776   }
777 
778   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
779   if (PGOReader && haveRegionCounts()) {
780     // We record the top most called three functions at each call site.
781     // Profile metadata contains "VP" string identifying this metadata
782     // as value profiling data, then a uint32_t value for the value profiling
783     // kind, a uint64_t value for the total number of times the call is
784     // executed, followed by the function hash and execution count (uint64_t)
785     // pairs for each function.
786     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
787       return;
788 
789     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
790                             (llvm::InstrProfValueKind)ValueKind,
791                             NumValueSites[ValueKind]);
792 
793     NumValueSites[ValueKind]++;
794   }
795 }
796 
797 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
798                                   bool IsInMainFile) {
799   CGM.getPGOStats().addVisited(IsInMainFile);
800   RegionCounts.clear();
801   llvm::ErrorOr<llvm::InstrProfRecord> RecordErrorOr =
802       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
803   if (std::error_code EC = RecordErrorOr.getError()) {
804     if (EC == llvm::instrprof_error::unknown_function)
805       CGM.getPGOStats().addMissing(IsInMainFile);
806     else if (EC == llvm::instrprof_error::hash_mismatch)
807       CGM.getPGOStats().addMismatched(IsInMainFile);
808     else if (EC == llvm::instrprof_error::malformed)
809       // TODO: Consider a more specific warning for this case.
810       CGM.getPGOStats().addMismatched(IsInMainFile);
811     return;
812   }
813   ProfRecord =
814       llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordErrorOr.get()));
815   RegionCounts = ProfRecord->Counts;
816 }
817 
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Small enough already: no scaling needed.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
825 
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// with a weight no less than \c Weight, so that the scaled result plus one
/// still fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
841 
842 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
843                                                     uint64_t FalseCount) {
844   // Check for empty weights.
845   if (!TrueCount && !FalseCount)
846     return nullptr;
847 
848   // Calculate how to scale down to 32-bits.
849   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
850 
851   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
852   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
853                                       scaleBranchWeight(FalseCount, Scale));
854 }
855 
856 llvm::MDNode *
857 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
858   // We need at least two elements to create meaningful weights.
859   if (Weights.size() < 2)
860     return nullptr;
861 
862   // Check for empty weights.
863   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
864   if (MaxWeight == 0)
865     return nullptr;
866 
867   // Calculate how to scale down to 32-bits.
868   uint64_t Scale = calculateWeightScale(MaxWeight);
869 
870   SmallVector<uint32_t, 16> ScaledWeights;
871   ScaledWeights.reserve(Weights.size());
872   for (uint64_t W : Weights)
873     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
874 
875   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
876   return MDHelper.createBranchWeights(ScaledWeights);
877 }
878 
879 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
880                                                            uint64_t LoopCount) {
881   if (!PGO.haveRegionCounts())
882     return nullptr;
883   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
884   assert(CondCount.hasValue() && "missing expected loop condition count");
885   if (*CondCount == 0)
886     return nullptr;
887   return createProfileWeights(LoopCount,
888                               std::max(*CondCount, LoopCount) - LoopCount);
889 }
890