1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25 
26 using namespace clang;
27 using namespace CodeGen;
28 
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   FuncName = llvm::getPGOFuncName(Name, Linkage, CGM.getCodeGenOpts().MainFileName);
32 
33   // If we're generating a profile, create a variable for the name.
34   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
35     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
36 }
37 
38 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
39   setFuncName(Fn->getName(), Fn->getLinkage());
40 }
41 
42 namespace {
43 /// \brief Stable hasher for PGO region counters.
44 ///
45 /// PGOHash produces a stable hash of a given function's control flow.
46 ///
47 /// Changing the output of this hash will invalidate all previously generated
48 /// profiles -- i.e., don't do it.
49 ///
50 /// \note  When this hash does eventually change (years?), we still need to
51 /// support old hashes.  We'll need to pull in the version number from the
52 /// profile data format and use the matching hash function.
53 class PGOHash {
54   uint64_t Working;
55   unsigned Count;
56   llvm::MD5 MD5;
57 
58   static const int NumBitsPerType = 6;
59   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
60   static const unsigned TooBig = 1u << NumBitsPerType;
61 
62 public:
63   /// \brief Hash values for AST nodes.
64   ///
65   /// Distinct values for AST nodes that have region counters attached.
66   ///
67   /// These values must be stable.  All new members must be added at the end,
68   /// and no members should be removed.  Changing the enumeration value for an
69   /// AST node will affect the hash of every function that contains that node.
70   enum HashType : unsigned char {
71     None = 0,
72     LabelStmt = 1,
73     WhileStmt,
74     DoStmt,
75     ForStmt,
76     CXXForRangeStmt,
77     ObjCForCollectionStmt,
78     SwitchStmt,
79     CaseStmt,
80     DefaultStmt,
81     IfStmt,
82     CXXTryStmt,
83     CXXCatchStmt,
84     ConditionalOperator,
85     BinaryOperatorLAnd,
86     BinaryOperatorLOr,
87     BinaryConditionalOperator,
88 
89     // Keep this last.  It's for the static assert that follows.
90     LastHashType
91   };
92   static_assert(LastHashType <= TooBig, "Too many types in HashType");
93 
94   // TODO: When this format changes, take in a version number here, and use the
95   // old hash calculation for file formats that used the old hash.
96   PGOHash() : Working(0), Count(0) {}
97   void combine(HashType Type);
98   uint64_t finalize();
99 };
100 const int PGOHash::NumBitsPerType;
101 const unsigned PGOHash::NumTypesPerWord;
102 const unsigned PGOHash::TooBig;
103 
104 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
105 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
106   /// The next counter value to assign.
107   unsigned NextCounter;
108   /// The function hash.
109   PGOHash Hash;
110   /// The map of statements to counters.
111   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
112 
113   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
114       : NextCounter(0), CounterMap(CounterMap) {}
115 
116   // Blocks and lambdas are handled as separate functions, so we need not
117   // traverse them in the parent context.
118   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
119   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
120   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
121 
122   bool VisitDecl(const Decl *D) {
123     switch (D->getKind()) {
124     default:
125       break;
126     case Decl::Function:
127     case Decl::CXXMethod:
128     case Decl::CXXConstructor:
129     case Decl::CXXDestructor:
130     case Decl::CXXConversion:
131     case Decl::ObjCMethod:
132     case Decl::Block:
133     case Decl::Captured:
134       CounterMap[D->getBody()] = NextCounter++;
135       break;
136     }
137     return true;
138   }
139 
140   bool VisitStmt(const Stmt *S) {
141     auto Type = getHashType(S);
142     if (Type == PGOHash::None)
143       return true;
144 
145     CounterMap[S] = NextCounter++;
146     Hash.combine(Type);
147     return true;
148   }
149   PGOHash::HashType getHashType(const Stmt *S) {
150     switch (S->getStmtClass()) {
151     default:
152       break;
153     case Stmt::LabelStmtClass:
154       return PGOHash::LabelStmt;
155     case Stmt::WhileStmtClass:
156       return PGOHash::WhileStmt;
157     case Stmt::DoStmtClass:
158       return PGOHash::DoStmt;
159     case Stmt::ForStmtClass:
160       return PGOHash::ForStmt;
161     case Stmt::CXXForRangeStmtClass:
162       return PGOHash::CXXForRangeStmt;
163     case Stmt::ObjCForCollectionStmtClass:
164       return PGOHash::ObjCForCollectionStmt;
165     case Stmt::SwitchStmtClass:
166       return PGOHash::SwitchStmt;
167     case Stmt::CaseStmtClass:
168       return PGOHash::CaseStmt;
169     case Stmt::DefaultStmtClass:
170       return PGOHash::DefaultStmt;
171     case Stmt::IfStmtClass:
172       return PGOHash::IfStmt;
173     case Stmt::CXXTryStmtClass:
174       return PGOHash::CXXTryStmt;
175     case Stmt::CXXCatchStmtClass:
176       return PGOHash::CXXCatchStmt;
177     case Stmt::ConditionalOperatorClass:
178       return PGOHash::ConditionalOperator;
179     case Stmt::BinaryConditionalOperatorClass:
180       return PGOHash::BinaryConditionalOperator;
181     case Stmt::BinaryOperatorClass: {
182       const BinaryOperator *BO = cast<BinaryOperator>(S);
183       if (BO->getOpcode() == BO_LAnd)
184         return PGOHash::BinaryOperatorLAnd;
185       if (BO->getOpcode() == BO_LOr)
186         return PGOHash::BinaryOperatorLOr;
187       break;
188     }
189     }
190     return PGOHash::None;
191   }
192 };
193 
194 /// A StmtVisitor that propagates the raw counts through the AST and
195 /// records the count at statements where the value may change.
196 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
197   /// PGO state.
198   CodeGenPGO &PGO;
199 
200   /// A flag that is set when the current count should be recorded on the
201   /// next statement, such as at the exit of a loop.
202   bool RecordNextStmtCount;
203 
204   /// The count at the current location in the traversal.
205   uint64_t CurrentCount;
206 
207   /// The map of statements to count values.
208   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
209 
210   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
211   struct BreakContinue {
212     uint64_t BreakCount;
213     uint64_t ContinueCount;
214     BreakContinue() : BreakCount(0), ContinueCount(0) {}
215   };
216   SmallVector<BreakContinue, 8> BreakContinueStack;
217 
218   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
219                       CodeGenPGO &PGO)
220       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
221 
222   void RecordStmtCount(const Stmt *S) {
223     if (RecordNextStmtCount) {
224       CountMap[S] = CurrentCount;
225       RecordNextStmtCount = false;
226     }
227   }
228 
229   /// Set and return the current count.
230   uint64_t setCount(uint64_t Count) {
231     CurrentCount = Count;
232     return Count;
233   }
234 
235   void VisitStmt(const Stmt *S) {
236     RecordStmtCount(S);
237     for (const Stmt *Child : S->children())
238       if (Child)
239         this->Visit(Child);
240   }
241 
242   void VisitFunctionDecl(const FunctionDecl *D) {
243     // Counter tracks entry to the function body.
244     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
245     CountMap[D->getBody()] = BodyCount;
246     Visit(D->getBody());
247   }
248 
249   // Skip lambda expressions. We visit these as FunctionDecls when we're
250   // generating them and aren't interested in the body when generating a
251   // parent context.
252   void VisitLambdaExpr(const LambdaExpr *LE) {}
253 
254   void VisitCapturedDecl(const CapturedDecl *D) {
255     // Counter tracks entry to the capture body.
256     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
257     CountMap[D->getBody()] = BodyCount;
258     Visit(D->getBody());
259   }
260 
261   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
262     // Counter tracks entry to the method body.
263     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
264     CountMap[D->getBody()] = BodyCount;
265     Visit(D->getBody());
266   }
267 
268   void VisitBlockDecl(const BlockDecl *D) {
269     // Counter tracks entry to the block body.
270     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
271     CountMap[D->getBody()] = BodyCount;
272     Visit(D->getBody());
273   }
274 
275   void VisitReturnStmt(const ReturnStmt *S) {
276     RecordStmtCount(S);
277     if (S->getRetValue())
278       Visit(S->getRetValue());
279     CurrentCount = 0;
280     RecordNextStmtCount = true;
281   }
282 
283   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
284     RecordStmtCount(E);
285     if (E->getSubExpr())
286       Visit(E->getSubExpr());
287     CurrentCount = 0;
288     RecordNextStmtCount = true;
289   }
290 
291   void VisitGotoStmt(const GotoStmt *S) {
292     RecordStmtCount(S);
293     CurrentCount = 0;
294     RecordNextStmtCount = true;
295   }
296 
297   void VisitLabelStmt(const LabelStmt *S) {
298     RecordNextStmtCount = false;
299     // Counter tracks the block following the label.
300     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
301     CountMap[S] = BlockCount;
302     Visit(S->getSubStmt());
303   }
304 
305   void VisitBreakStmt(const BreakStmt *S) {
306     RecordStmtCount(S);
307     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
308     BreakContinueStack.back().BreakCount += CurrentCount;
309     CurrentCount = 0;
310     RecordNextStmtCount = true;
311   }
312 
313   void VisitContinueStmt(const ContinueStmt *S) {
314     RecordStmtCount(S);
315     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
316     BreakContinueStack.back().ContinueCount += CurrentCount;
317     CurrentCount = 0;
318     RecordNextStmtCount = true;
319   }
320 
321   void VisitWhileStmt(const WhileStmt *S) {
322     RecordStmtCount(S);
323     uint64_t ParentCount = CurrentCount;
324 
325     BreakContinueStack.push_back(BreakContinue());
326     // Visit the body region first so the break/continue adjustments can be
327     // included when visiting the condition.
328     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
329     CountMap[S->getBody()] = CurrentCount;
330     Visit(S->getBody());
331     uint64_t BackedgeCount = CurrentCount;
332 
333     // ...then go back and propagate counts through the condition. The count
334     // at the start of the condition is the sum of the incoming edges,
335     // the backedge from the end of the loop body, and the edges from
336     // continue statements.
337     BreakContinue BC = BreakContinueStack.pop_back_val();
338     uint64_t CondCount =
339         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
340     CountMap[S->getCond()] = CondCount;
341     Visit(S->getCond());
342     setCount(BC.BreakCount + CondCount - BodyCount);
343     RecordNextStmtCount = true;
344   }
345 
346   void VisitDoStmt(const DoStmt *S) {
347     RecordStmtCount(S);
348     uint64_t LoopCount = PGO.getRegionCount(S);
349 
350     BreakContinueStack.push_back(BreakContinue());
351     // The count doesn't include the fallthrough from the parent scope. Add it.
352     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
353     CountMap[S->getBody()] = BodyCount;
354     Visit(S->getBody());
355     uint64_t BackedgeCount = CurrentCount;
356 
357     BreakContinue BC = BreakContinueStack.pop_back_val();
358     // The count at the start of the condition is equal to the count at the
359     // end of the body, plus any continues.
360     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
361     CountMap[S->getCond()] = CondCount;
362     Visit(S->getCond());
363     setCount(BC.BreakCount + CondCount - LoopCount);
364     RecordNextStmtCount = true;
365   }
366 
367   void VisitForStmt(const ForStmt *S) {
368     RecordStmtCount(S);
369     if (S->getInit())
370       Visit(S->getInit());
371 
372     uint64_t ParentCount = CurrentCount;
373 
374     BreakContinueStack.push_back(BreakContinue());
375     // Visit the body region first. (This is basically the same as a while
376     // loop; see further comments in VisitWhileStmt.)
377     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
378     CountMap[S->getBody()] = BodyCount;
379     Visit(S->getBody());
380     uint64_t BackedgeCount = CurrentCount;
381     BreakContinue BC = BreakContinueStack.pop_back_val();
382 
383     // The increment is essentially part of the body but it needs to include
384     // the count for all the continue statements.
385     if (S->getInc()) {
386       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
387       CountMap[S->getInc()] = IncCount;
388       Visit(S->getInc());
389     }
390 
391     // ...then go back and propagate counts through the condition.
392     uint64_t CondCount =
393         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
394     if (S->getCond()) {
395       CountMap[S->getCond()] = CondCount;
396       Visit(S->getCond());
397     }
398     setCount(BC.BreakCount + CondCount - BodyCount);
399     RecordNextStmtCount = true;
400   }
401 
402   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
403     RecordStmtCount(S);
404     Visit(S->getLoopVarStmt());
405     Visit(S->getRangeStmt());
406     Visit(S->getBeginEndStmt());
407 
408     uint64_t ParentCount = CurrentCount;
409     BreakContinueStack.push_back(BreakContinue());
410     // Visit the body region first. (This is basically the same as a while
411     // loop; see further comments in VisitWhileStmt.)
412     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
413     CountMap[S->getBody()] = BodyCount;
414     Visit(S->getBody());
415     uint64_t BackedgeCount = CurrentCount;
416     BreakContinue BC = BreakContinueStack.pop_back_val();
417 
418     // The increment is essentially part of the body but it needs to include
419     // the count for all the continue statements.
420     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
421     CountMap[S->getInc()] = IncCount;
422     Visit(S->getInc());
423 
424     // ...then go back and propagate counts through the condition.
425     uint64_t CondCount =
426         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
427     CountMap[S->getCond()] = CondCount;
428     Visit(S->getCond());
429     setCount(BC.BreakCount + CondCount - BodyCount);
430     RecordNextStmtCount = true;
431   }
432 
433   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
434     RecordStmtCount(S);
435     Visit(S->getElement());
436     uint64_t ParentCount = CurrentCount;
437     BreakContinueStack.push_back(BreakContinue());
438     // Counter tracks the body of the loop.
439     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
440     CountMap[S->getBody()] = BodyCount;
441     Visit(S->getBody());
442     uint64_t BackedgeCount = CurrentCount;
443     BreakContinue BC = BreakContinueStack.pop_back_val();
444 
445     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
446              BodyCount);
447     RecordNextStmtCount = true;
448   }
449 
450   void VisitSwitchStmt(const SwitchStmt *S) {
451     RecordStmtCount(S);
452     Visit(S->getCond());
453     CurrentCount = 0;
454     BreakContinueStack.push_back(BreakContinue());
455     Visit(S->getBody());
456     // If the switch is inside a loop, add the continue counts.
457     BreakContinue BC = BreakContinueStack.pop_back_val();
458     if (!BreakContinueStack.empty())
459       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
460     // Counter tracks the exit block of the switch.
461     setCount(PGO.getRegionCount(S));
462     RecordNextStmtCount = true;
463   }
464 
465   void VisitSwitchCase(const SwitchCase *S) {
466     RecordNextStmtCount = false;
467     // Counter for this particular case. This counts only jumps from the
468     // switch header and does not include fallthrough from the case before
469     // this one.
470     uint64_t CaseCount = PGO.getRegionCount(S);
471     setCount(CurrentCount + CaseCount);
472     // We need the count without fallthrough in the mapping, so it's more useful
473     // for branch probabilities.
474     CountMap[S] = CaseCount;
475     RecordNextStmtCount = true;
476     Visit(S->getSubStmt());
477   }
478 
479   void VisitIfStmt(const IfStmt *S) {
480     RecordStmtCount(S);
481     uint64_t ParentCount = CurrentCount;
482     Visit(S->getCond());
483 
484     // Counter tracks the "then" part of an if statement. The count for
485     // the "else" part, if it exists, will be calculated from this counter.
486     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
487     CountMap[S->getThen()] = ThenCount;
488     Visit(S->getThen());
489     uint64_t OutCount = CurrentCount;
490 
491     uint64_t ElseCount = ParentCount - ThenCount;
492     if (S->getElse()) {
493       setCount(ElseCount);
494       CountMap[S->getElse()] = ElseCount;
495       Visit(S->getElse());
496       OutCount += CurrentCount;
497     } else
498       OutCount += ElseCount;
499     setCount(OutCount);
500     RecordNextStmtCount = true;
501   }
502 
503   void VisitCXXTryStmt(const CXXTryStmt *S) {
504     RecordStmtCount(S);
505     Visit(S->getTryBlock());
506     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
507       Visit(S->getHandler(I));
508     // Counter tracks the continuation block of the try statement.
509     setCount(PGO.getRegionCount(S));
510     RecordNextStmtCount = true;
511   }
512 
513   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
514     RecordNextStmtCount = false;
515     // Counter tracks the catch statement's handler block.
516     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
517     CountMap[S] = CatchCount;
518     Visit(S->getHandlerBlock());
519   }
520 
521   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
522     RecordStmtCount(E);
523     uint64_t ParentCount = CurrentCount;
524     Visit(E->getCond());
525 
526     // Counter tracks the "true" part of a conditional operator. The
527     // count in the "false" part will be calculated from this counter.
528     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
529     CountMap[E->getTrueExpr()] = TrueCount;
530     Visit(E->getTrueExpr());
531     uint64_t OutCount = CurrentCount;
532 
533     uint64_t FalseCount = setCount(ParentCount - TrueCount);
534     CountMap[E->getFalseExpr()] = FalseCount;
535     Visit(E->getFalseExpr());
536     OutCount += CurrentCount;
537 
538     setCount(OutCount);
539     RecordNextStmtCount = true;
540   }
541 
542   void VisitBinLAnd(const BinaryOperator *E) {
543     RecordStmtCount(E);
544     uint64_t ParentCount = CurrentCount;
545     Visit(E->getLHS());
546     // Counter tracks the right hand side of a logical and operator.
547     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
548     CountMap[E->getRHS()] = RHSCount;
549     Visit(E->getRHS());
550     setCount(ParentCount + RHSCount - CurrentCount);
551     RecordNextStmtCount = true;
552   }
553 
554   void VisitBinLOr(const BinaryOperator *E) {
555     RecordStmtCount(E);
556     uint64_t ParentCount = CurrentCount;
557     Visit(E->getLHS());
558     // Counter tracks the right hand side of a logical or operator.
559     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
560     CountMap[E->getRHS()] = RHSCount;
561     Visit(E->getRHS());
562     setCount(ParentCount + RHSCount - CurrentCount);
563     RecordNextStmtCount = true;
564   }
565 };
566 } // end anonymous namespace
567 
568 void PGOHash::combine(HashType Type) {
569   // Check that we never combine 0 and only have six bits.
570   assert(Type && "Hash is invalid: unexpected type 0");
571   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
572 
573   // Pass through MD5 if enough work has built up.
574   if (Count && Count % NumTypesPerWord == 0) {
575     using namespace llvm::support;
576     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
577     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
578     Working = 0;
579   }
580 
581   // Accumulate the current type.
582   ++Count;
583   Working = Working << NumBitsPerType | Type;
584 }
585 
586 uint64_t PGOHash::finalize() {
587   // Use Working as the hash directly if we never used MD5.
588   if (Count <= NumTypesPerWord)
589     // No need to byte swap here, since none of the math was endian-dependent.
590     // This number will be byte-swapped as required on endianness transitions,
591     // so we will see the same value on the other side.
592     return Working;
593 
594   // Check for remaining work in Working.
595   if (Working)
596     MD5.update(Working);
597 
598   // Finalize the MD5 and return the hash.
599   llvm::MD5::MD5Result Result;
600   MD5.final(Result);
601   using namespace llvm::support;
602   return endian::read<uint64_t, little, unaligned>(Result);
603 }
604 
605 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
606   // Make sure we only emit coverage mapping for one constructor/destructor.
607   // Clang emits several functions for the constructor and the destructor of
608   // a class. Every function is instrumented, but we only want to provide
609   // coverage for one of them. Because of that we only emit the coverage mapping
610   // for the base constructor/destructor.
611   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
612        GD.getCtorType() != Ctor_Base) ||
613       (isa<CXXDestructorDecl>(GD.getDecl()) &&
614        GD.getDtorType() != Dtor_Base)) {
615     SkipCoverageMapping = true;
616   }
617 }
618 
619 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
620   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
621   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
622   if (!InstrumentRegions && !PGOReader)
623     return;
624   if (D->isImplicit())
625     return;
626   CGM.ClearUnusedCoverageMapping(D);
627   setFuncName(Fn);
628 
629   mapRegionCounters(D);
630   if (CGM.getCodeGenOpts().CoverageMapping)
631     emitCounterRegionMapping(D);
632   if (PGOReader) {
633     SourceManager &SM = CGM.getContext().getSourceManager();
634     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
635     computeRegionCounts(D);
636     applyFunctionAttributes(PGOReader, Fn);
637   }
638 }
639 
640 void CodeGenPGO::mapRegionCounters(const Decl *D) {
641   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
642   MapRegionCounters Walker(*RegionCounterMap);
643   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
644     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
645   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
646     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
647   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
648     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
649   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
650     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
651   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
652   NumRegionCounters = Walker.NextCounter;
653   FunctionHash = Walker.Hash.finalize();
654 }
655 
656 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
657   if (SkipCoverageMapping)
658     return;
659   // Don't map the functions inside the system headers
660   auto Loc = D->getBody()->getLocStart();
661   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
662     return;
663 
664   std::string CoverageMapping;
665   llvm::raw_string_ostream OS(CoverageMapping);
666   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
667                                 CGM.getContext().getSourceManager(),
668                                 CGM.getLangOpts(), RegionCounterMap.get());
669   MappingGen.emitCounterMapping(D, OS);
670   OS.flush();
671 
672   if (CoverageMapping.empty())
673     return;
674 
675   CGM.getCoverageMapping()->addFunctionMappingRecord(
676       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
677 }
678 
679 void
680 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
681                                     llvm::GlobalValue::LinkageTypes Linkage) {
682   if (SkipCoverageMapping)
683     return;
684   // Don't map the functions inside the system headers
685   auto Loc = D->getBody()->getLocStart();
686   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
687     return;
688 
689   std::string CoverageMapping;
690   llvm::raw_string_ostream OS(CoverageMapping);
691   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
692                                 CGM.getContext().getSourceManager(),
693                                 CGM.getLangOpts());
694   MappingGen.emitEmptyMapping(D, OS);
695   OS.flush();
696 
697   if (CoverageMapping.empty())
698     return;
699 
700   setFuncName(Name, Linkage);
701   CGM.getCoverageMapping()->addFunctionMappingRecord(
702       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
703 }
704 
705 void CodeGenPGO::computeRegionCounts(const Decl *D) {
706   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
707   ComputeRegionCounts Walker(*StmtCountMap, *this);
708   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
709     Walker.VisitFunctionDecl(FD);
710   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
711     Walker.VisitObjCMethodDecl(MD);
712   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
713     Walker.VisitBlockDecl(BD);
714   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
715     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
716 }
717 
718 void
719 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
720                                     llvm::Function *Fn) {
721   if (!haveRegionCounts())
722     return;
723 
724   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
725   uint64_t FunctionCount = getRegionCount(nullptr);
726   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
727     // Turn on InlineHint attribute for hot functions.
728     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
729     Fn->addFnAttr(llvm::Attribute::InlineHint);
730   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
731     // Turn on Cold attribute for cold functions.
732     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
733     Fn->addFnAttr(llvm::Attribute::Cold);
734 
735   Fn->setEntryCount(FunctionCount);
736 }
737 
738 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
739   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
740     return;
741   if (!Builder.GetInsertBlock())
742     return;
743 
744   unsigned Counter = (*RegionCounterMap)[S];
745   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
746   Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
747                      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
748                       Builder.getInt64(FunctionHash),
749                       Builder.getInt32(NumRegionCounters),
750                       Builder.getInt32(Counter)});
751 }
752 
753 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
754                                   bool IsInMainFile) {
755   CGM.getPGOStats().addVisited(IsInMainFile);
756   RegionCounts.clear();
757   if (std::error_code EC =
758           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
759     if (EC == llvm::instrprof_error::unknown_function)
760       CGM.getPGOStats().addMissing(IsInMainFile);
761     else if (EC == llvm::instrprof_error::hash_mismatch)
762       CGM.getPGOStats().addMismatched(IsInMainFile);
763     else if (EC == llvm::instrprof_error::malformed)
764       // TODO: Consider a more specific warning for this case.
765       CGM.getPGOStats().addMismatched(IsInMainFile);
766     RegionCounts.clear();
767   }
768 }
769 
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, return a divisor such that every weight up to
/// that maximum, once divided by it, is strictly less than UINT32_MAX. The
/// divisor is 1 when the maximum is already below UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
777 
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \param Weight the raw 64-bit weight.
/// \param Scale the non-zero divisor to apply.
/// \returns \c Weight / \c Scale + 1.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, so that the scaled result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is intentional and checked by the assertion above.
  return static_cast<uint32_t>(Scaled);
}
793 
794 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
795                                                     uint64_t FalseCount) {
796   // Check for empty weights.
797   if (!TrueCount && !FalseCount)
798     return nullptr;
799 
800   // Calculate how to scale down to 32-bits.
801   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
802 
803   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
804   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
805                                       scaleBranchWeight(FalseCount, Scale));
806 }
807 
808 llvm::MDNode *
809 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
810   // We need at least two elements to create meaningful weights.
811   if (Weights.size() < 2)
812     return nullptr;
813 
814   // Check for empty weights.
815   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
816   if (MaxWeight == 0)
817     return nullptr;
818 
819   // Calculate how to scale down to 32-bits.
820   uint64_t Scale = calculateWeightScale(MaxWeight);
821 
822   SmallVector<uint32_t, 16> ScaledWeights;
823   ScaledWeights.reserve(Weights.size());
824   for (uint64_t W : Weights)
825     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
826 
827   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
828   return MDHelper.createBranchWeights(ScaledWeights);
829 }
830 
831 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
832                                                            uint64_t LoopCount) {
833   if (!PGO.haveRegionCounts())
834     return nullptr;
835   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
836   assert(CondCount.hasValue() && "missing expected loop condition count");
837   if (*CondCount == 0)
838     return nullptr;
839   return createProfileWeights(LoopCount,
840                               std::max(*CondCount, LoopCount) - LoopCount);
841 }
842