1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25 
26 using namespace clang;
27 using namespace CodeGen;
28 
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32 
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38 
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50 
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55 
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59 
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // Usually, we want to match the function's linkage, but
62   // available_externally and extern_weak both have the wrong semantics.
63   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
67 
68   auto *Value =
69       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
70   FuncNameVar =
71       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72                                Value, "__llvm_profile_name_" + FuncName);
73 
74   // Hide the symbol so that we correctly get a copy for each executable.
75   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
77 }
78 
79 namespace {
80 /// \brief Stable hasher for PGO region counters.
81 ///
82 /// PGOHash produces a stable hash of a given function's control flow.
83 ///
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
86 ///
87 /// \note  When this hash does eventually change (years?), we still need to
88 /// support old hashes.  We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
90 class PGOHash {
91   uint64_t Working;
92   unsigned Count;
93   llvm::MD5 MD5;
94 
95   static const int NumBitsPerType = 6;
96   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97   static const unsigned TooBig = 1u << NumBitsPerType;
98 
99 public:
100   /// \brief Hash values for AST nodes.
101   ///
102   /// Distinct values for AST nodes that have region counters attached.
103   ///
104   /// These values must be stable.  All new members must be added at the end,
105   /// and no members should be removed.  Changing the enumeration value for an
106   /// AST node will affect the hash of every function that contains that node.
107   enum HashType : unsigned char {
108     None = 0,
109     LabelStmt = 1,
110     WhileStmt,
111     DoStmt,
112     ForStmt,
113     CXXForRangeStmt,
114     ObjCForCollectionStmt,
115     SwitchStmt,
116     CaseStmt,
117     DefaultStmt,
118     IfStmt,
119     CXXTryStmt,
120     CXXCatchStmt,
121     ConditionalOperator,
122     BinaryOperatorLAnd,
123     BinaryOperatorLOr,
124     BinaryConditionalOperator,
125 
126     // Keep this last.  It's for the static assert that follows.
127     LastHashType
128   };
129   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130 
131   // TODO: When this format changes, take in a version number here, and use the
132   // old hash calculation for file formats that used the old hash.
133   PGOHash() : Working(0), Count(0) {}
134   void combine(HashType Type);
135   uint64_t finalize();
136 };
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
140 
141 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143   /// The next counter value to assign.
144   unsigned NextCounter;
145   /// The function hash.
146   PGOHash Hash;
147   /// The map of statements to counters.
148   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
149 
150   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151       : NextCounter(0), CounterMap(CounterMap) {}
152 
153   // Blocks and lambdas are handled as separate functions, so we need not
154   // traverse them in the parent context.
155   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
156   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
157   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
158 
159   bool VisitDecl(const Decl *D) {
160     switch (D->getKind()) {
161     default:
162       break;
163     case Decl::Function:
164     case Decl::CXXMethod:
165     case Decl::CXXConstructor:
166     case Decl::CXXDestructor:
167     case Decl::CXXConversion:
168     case Decl::ObjCMethod:
169     case Decl::Block:
170     case Decl::Captured:
171       CounterMap[D->getBody()] = NextCounter++;
172       break;
173     }
174     return true;
175   }
176 
177   bool VisitStmt(const Stmt *S) {
178     auto Type = getHashType(S);
179     if (Type == PGOHash::None)
180       return true;
181 
182     CounterMap[S] = NextCounter++;
183     Hash.combine(Type);
184     return true;
185   }
186   PGOHash::HashType getHashType(const Stmt *S) {
187     switch (S->getStmtClass()) {
188     default:
189       break;
190     case Stmt::LabelStmtClass:
191       return PGOHash::LabelStmt;
192     case Stmt::WhileStmtClass:
193       return PGOHash::WhileStmt;
194     case Stmt::DoStmtClass:
195       return PGOHash::DoStmt;
196     case Stmt::ForStmtClass:
197       return PGOHash::ForStmt;
198     case Stmt::CXXForRangeStmtClass:
199       return PGOHash::CXXForRangeStmt;
200     case Stmt::ObjCForCollectionStmtClass:
201       return PGOHash::ObjCForCollectionStmt;
202     case Stmt::SwitchStmtClass:
203       return PGOHash::SwitchStmt;
204     case Stmt::CaseStmtClass:
205       return PGOHash::CaseStmt;
206     case Stmt::DefaultStmtClass:
207       return PGOHash::DefaultStmt;
208     case Stmt::IfStmtClass:
209       return PGOHash::IfStmt;
210     case Stmt::CXXTryStmtClass:
211       return PGOHash::CXXTryStmt;
212     case Stmt::CXXCatchStmtClass:
213       return PGOHash::CXXCatchStmt;
214     case Stmt::ConditionalOperatorClass:
215       return PGOHash::ConditionalOperator;
216     case Stmt::BinaryConditionalOperatorClass:
217       return PGOHash::BinaryConditionalOperator;
218     case Stmt::BinaryOperatorClass: {
219       const BinaryOperator *BO = cast<BinaryOperator>(S);
220       if (BO->getOpcode() == BO_LAnd)
221         return PGOHash::BinaryOperatorLAnd;
222       if (BO->getOpcode() == BO_LOr)
223         return PGOHash::BinaryOperatorLOr;
224       break;
225     }
226     }
227     return PGOHash::None;
228   }
229 };
230 
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// CodeGenPGO::computeRegionCounts() drives this walker by calling one of the
/// Visit*Decl entry points directly; ConstStmtVisitor dispatch handles every
/// statement below that. Results are written to CountMap, keyed by the
/// statement at which the count was observed.
///
/// Counts flow through two channels: the "current region count" held by
/// CodeGenPGO (read with getCurrentRegionCount(), cleared with
/// setCurrentRegionUnreachable()), and RegionCounter objects, which are
/// declared outside this file. NOTE(review): the exact meaning of
/// RegionCounter::beginRegion / adjustForControlFlow /
/// applyAdjustmentsToRegion / setCurrentRegionCount is defined with that
/// class. The order of these calls in each Visit* method below is
/// significant; do not reorder them without consulting it.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  ///
  /// One entry is pushed for each enclosing loop or switch statement. Break
  /// and continue statements add the current region count to the innermost
  /// entry.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// If a previous statement requested it (RecordNextStmtCount), store the
  /// current region count for \p S and clear the request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = PGO.getCurrentRegionCount();
      RecordNextStmtCount = false;
    }
  }

  /// Fallback for statements without special handling: record a pending
  /// count, then visit every non-null child in order.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (Stmt::const_child_range I = S->children(); I; ++I) {
      if (*I)
        this->Visit(*I);
    }
  }

  /// Entry point for a function: the body's counter begins the top-level
  /// region, and the resulting count is recorded for the body.
  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    RegionCounter Cnt(PGO, D->getBody());
    Cnt.beginRegion();
    CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  /// Entry point for a captured-statement body; same shape as
  /// VisitFunctionDecl.
  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    RegionCounter Cnt(PGO, D->getBody());
    Cnt.beginRegion();
    CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    Visit(D->getBody());
  }

  /// Entry point for an Objective-C method body; same shape as
  /// VisitFunctionDecl.
  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    RegionCounter Cnt(PGO, D->getBody());
    Cnt.beginRegion();
    CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    Visit(D->getBody());
  }

  /// Entry point for a block body; same shape as VisitFunctionDecl.
  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    RegionCounter Cnt(PGO, D->getBody());
    Cnt.beginRegion();
    CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    Visit(D->getBody());
  }

  /// Visit the returned value, then mark the current region unreachable and
  /// ask the next statement to record its count.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    PGO.setCurrentRegionUnreachable();
    RecordNextStmtCount = true;
  }

  /// A goto ends the current region: mark it unreachable and ask the next
  /// statement to record its count.
  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    PGO.setCurrentRegionUnreachable();
    RecordNextStmtCount = true;
  }

  /// A label starts a new region from its own counter. Any pending
  /// record-next request is dropped; the label records its count directly.
  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    RegionCounter Cnt(PGO, S);
    Cnt.beginRegion();
    CountMap[S] = PGO.getCurrentRegionCount();
    Visit(S->getSubStmt());
  }

  /// Add the current count to the innermost loop/switch's BreakCount, then
  /// mark the current region unreachable.
  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
    PGO.setCurrentRegionUnreachable();
    RecordNextStmtCount = true;
  }

  /// Add the current count to the innermost entry's ContinueCount, then mark
  /// the current region unreachable. (When that entry belongs to a switch,
  /// VisitSwitchStmt forwards the continue count to the enclosing entry.)
  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
    PGO.setCurrentRegionUnreachable();
    RecordNextStmtCount = true;
  }

  /// While loop: the body is visited before the condition so that the
  /// continue counts gathered from the body are known when the condition's
  /// count is computed.
  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    // Counter tracks the body of the loop.
    RegionCounter Cnt(PGO, S);
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    Cnt.beginRegion();
    CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    Visit(S->getBody());
    Cnt.adjustForControlFlow();

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
                              BC.ContinueCount);
    CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    Visit(S->getCond());
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    RecordNextStmtCount = true;
  }

  /// Do-while loop: the body region is begun with AddIncomingFallThrough, and
  /// the condition is visited after the body.
  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    // Counter tracks the body of the loop.
    RegionCounter Cnt(PGO, S);
    BreakContinueStack.push_back(BreakContinue());
    Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    Visit(S->getBody());
    Cnt.adjustForControlFlow();

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body. The adjusted count does not include either the
    // fall-through count coming into the loop or the continue count, so add
    // both of those separately. This is coincidentally the same equation as
    // with while loops but for different reasons.
    Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
                              BC.ContinueCount);
    CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    Visit(S->getCond());
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    RecordNextStmtCount = true;
  }

  /// For loop: init (in the enclosing region), then body, then the optional
  /// increment, and finally the optional condition.
  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    // Counter tracks the body of the loop.
    RegionCounter Cnt(PGO, S);
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    Cnt.beginRegion();
    CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    Visit(S->getBody());
    Cnt.adjustForControlFlow();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      CountMap[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();
    }

    BreakContinue BC = BreakContinueStack.pop_back_val();

    // ...then go back and propagate counts through the condition.
    if (S->getCond()) {
      Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
    }
    Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    RecordNextStmtCount = true;
  }

  /// Range-based for: the range and begin/end statements are visited in the
  /// enclosing region. The loop variable plus the body form the counted
  /// region, followed by the increment and then the condition.
  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    Visit(S->getRangeStmt());
    Visit(S->getBeginEndStmt());
    // Counter tracks the body of the loop.
    RegionCounter Cnt(PGO, S);
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    Cnt.beginRegion();
    CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
    Visit(S->getLoopVarStmt());
    Visit(S->getBody());
    Cnt.adjustForControlFlow();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                              BreakContinueStack.back().ContinueCount);
    CountMap[S->getInc()] = PGO.getCurrentRegionCount();
    Visit(S->getInc());
    Cnt.adjustForControlFlow();

    BreakContinue BC = BreakContinueStack.pop_back_val();

    // ...then go back and propagate counts through the condition.
    Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
                              BC.ContinueCount);
    CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    Visit(S->getCond());
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    RecordNextStmtCount = true;
  }

  /// Objective-C for-in: the element is visited in the enclosing region and
  /// the body is the counted region. There is no separate condition count.
  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    // Counter tracks the body of the loop.
    RegionCounter Cnt(PGO, S);
    BreakContinueStack.push_back(BreakContinue());
    Cnt.beginRegion();
    CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    Visit(S->getBody());
    BreakContinue BC = BreakContinueStack.pop_back_val();
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    RecordNextStmtCount = true;
  }

  /// Switch: the condition is visited in the enclosing region. The body then
  /// starts unreachable, and each case/default label begins its own region.
  /// Afterwards, the switch's own counter begins the exit region.
  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    Visit(S->getCond());
    PGO.setCurrentRegionUnreachable();
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    RegionCounter ExitCnt(PGO, S);
    ExitCnt.beginRegion();
    RecordNextStmtCount = true;
  }

  /// Case label: begins this case's region and records its count.
  void VisitCaseStmt(const CaseStmt *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    RegionCounter Cnt(PGO, S);
    Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    CountMap[S] = Cnt.getCount();
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  /// Default label: same treatment as VisitCaseStmt.
  void VisitDefaultStmt(const DefaultStmt *S) {
    RecordNextStmtCount = false;
    // Counter for this default case. This does not include fallthrough from
    // the previous case.
    RegionCounter Cnt(PGO, S);
    Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    CountMap[S] = Cnt.getCount();
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  /// If statement: the condition is visited in the enclosing region. "Then"
  /// begins the counter's region, and the optional "else" uses
  /// beginElseRegion().
  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    RegionCounter Cnt(PGO, S);
    Visit(S->getCond());

    Cnt.beginRegion();
    CountMap[S->getThen()] = PGO.getCurrentRegionCount();
    Visit(S->getThen());
    Cnt.adjustForControlFlow();

    if (S->getElse()) {
      Cnt.beginElseRegion();
      CountMap[S->getElse()] = PGO.getCurrentRegionCount();
      Visit(S->getElse());
      Cnt.adjustForControlFlow();
    }
    Cnt.applyAdjustmentsToRegion(0);
    RecordNextStmtCount = true;
  }

  /// Try statement: the try block and every handler are visited first. The
  /// statement's own counter then begins the continuation region.
  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    RegionCounter Cnt(PGO, S);
    Cnt.beginRegion();
    RecordNextStmtCount = true;
  }

  /// Catch handler: begins its own region and records its count.
  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    RegionCounter Cnt(PGO, S);
    Cnt.beginRegion();
    CountMap[S] = PGO.getCurrentRegionCount();
    Visit(S->getHandlerBlock());
  }

  /// Conditional operators (both "c ? a : b" and the binary "a ?: b" form):
  /// same shape as VisitIfStmt with a mandatory false arm.
  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    RegionCounter Cnt(PGO, E);
    Visit(E->getCond());

    Cnt.beginRegion();
    CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
    Visit(E->getTrueExpr());
    Cnt.adjustForControlFlow();

    Cnt.beginElseRegion();
    CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
    Visit(E->getFalseExpr());
    Cnt.adjustForControlFlow();

    Cnt.applyAdjustmentsToRegion(0);
    RecordNextStmtCount = true;
  }

  /// Logical "and": the LHS is visited in the enclosing region, and the RHS
  /// is the counter's region.
  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    // Counter tracks the right hand side of a logical and operator.
    RegionCounter Cnt(PGO, E);
    Visit(E->getLHS());
    Cnt.beginRegion();
    CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    Visit(E->getRHS());
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(0);
    RecordNextStmtCount = true;
  }

  /// Logical "or": same shape as VisitBinLAnd.
  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    // Counter tracks the right hand side of a logical or operator.
    RegionCounter Cnt(PGO, E);
    Visit(E->getLHS());
    Cnt.beginRegion();
    CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    Visit(E->getRHS());
    Cnt.adjustForControlFlow();
    Cnt.applyAdjustmentsToRegion(0);
    RecordNextStmtCount = true;
  }
};
614 }
615 
616 void PGOHash::combine(HashType Type) {
617   // Check that we never combine 0 and only have six bits.
618   assert(Type && "Hash is invalid: unexpected type 0");
619   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
620 
621   // Pass through MD5 if enough work has built up.
622   if (Count && Count % NumTypesPerWord == 0) {
623     using namespace llvm::support;
624     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
625     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
626     Working = 0;
627   }
628 
629   // Accumulate the current type.
630   ++Count;
631   Working = Working << NumBitsPerType | Type;
632 }
633 
634 uint64_t PGOHash::finalize() {
635   // Use Working as the hash directly if we never used MD5.
636   if (Count <= NumTypesPerWord)
637     // No need to byte swap here, since none of the math was endian-dependent.
638     // This number will be byte-swapped as required on endianness transitions,
639     // so we will see the same value on the other side.
640     return Working;
641 
642   // Check for remaining work in Working.
643   if (Working)
644     MD5.update(Working);
645 
646   // Finalize the MD5 and return the hash.
647   llvm::MD5::MD5Result Result;
648   MD5.final(Result);
649   using namespace llvm::support;
650   return endian::read<uint64_t, little, unaligned>(Result);
651 }
652 
653 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
654   // Make sure we only emit coverage mapping for one constructor/destructor.
655   // Clang emits several functions for the constructor and the destructor of
656   // a class. Every function is instrumented, but we only want to provide
657   // coverage for one of them. Because of that we only emit the coverage mapping
658   // for the base constructor/destructor.
659   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
660        GD.getCtorType() != Ctor_Base) ||
661       (isa<CXXDestructorDecl>(GD.getDecl()) &&
662        GD.getDtorType() != Dtor_Base)) {
663     SkipCoverageMapping = true;
664   }
665 }
666 
667 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
668   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
669   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
670   if (!InstrumentRegions && !PGOReader)
671     return;
672   if (D->isImplicit())
673     return;
674   CGM.ClearUnusedCoverageMapping(D);
675   setFuncName(Fn);
676 
677   mapRegionCounters(D);
678   if (CGM.getCodeGenOpts().CoverageMapping)
679     emitCounterRegionMapping(D);
680   if (PGOReader) {
681     SourceManager &SM = CGM.getContext().getSourceManager();
682     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
683     computeRegionCounts(D);
684     applyFunctionAttributes(PGOReader, Fn);
685   }
686 }
687 
688 void CodeGenPGO::mapRegionCounters(const Decl *D) {
689   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
690   MapRegionCounters Walker(*RegionCounterMap);
691   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
692     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
693   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
694     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
695   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
696     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
697   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
698     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
699   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
700   NumRegionCounters = Walker.NextCounter;
701   FunctionHash = Walker.Hash.finalize();
702 }
703 
704 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
705   if (SkipCoverageMapping)
706     return;
707   // Don't map the functions inside the system headers
708   auto Loc = D->getBody()->getLocStart();
709   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
710     return;
711 
712   std::string CoverageMapping;
713   llvm::raw_string_ostream OS(CoverageMapping);
714   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
715                                 CGM.getContext().getSourceManager(),
716                                 CGM.getLangOpts(), RegionCounterMap.get());
717   MappingGen.emitCounterMapping(D, OS);
718   OS.flush();
719 
720   if (CoverageMapping.empty())
721     return;
722 
723   CGM.getCoverageMapping()->addFunctionMappingRecord(
724       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
725 }
726 
727 void
728 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
729                                     llvm::GlobalValue::LinkageTypes Linkage) {
730   if (SkipCoverageMapping)
731     return;
732   // Don't map the functions inside the system headers
733   auto Loc = D->getBody()->getLocStart();
734   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
735     return;
736 
737   std::string CoverageMapping;
738   llvm::raw_string_ostream OS(CoverageMapping);
739   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
740                                 CGM.getContext().getSourceManager(),
741                                 CGM.getLangOpts());
742   MappingGen.emitEmptyMapping(D, OS);
743   OS.flush();
744 
745   if (CoverageMapping.empty())
746     return;
747 
748   setFuncName(FuncName, Linkage);
749   CGM.getCoverageMapping()->addFunctionMappingRecord(
750       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
751 }
752 
753 void CodeGenPGO::computeRegionCounts(const Decl *D) {
754   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
755   ComputeRegionCounts Walker(*StmtCountMap, *this);
756   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
757     Walker.VisitFunctionDecl(FD);
758   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
759     Walker.VisitObjCMethodDecl(MD);
760   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
761     Walker.VisitBlockDecl(BD);
762   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
763     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
764 }
765 
766 void
767 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
768                                     llvm::Function *Fn) {
769   if (!haveRegionCounts())
770     return;
771 
772   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
773   uint64_t FunctionCount = getRegionCount(0);
774   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
775     // Turn on InlineHint attribute for hot functions.
776     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
777     Fn->addFnAttr(llvm::Attribute::InlineHint);
778   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
779     // Turn on Cold attribute for cold functions.
780     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
781     Fn->addFnAttr(llvm::Attribute::Cold);
782 }
783 
784 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
785   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
786     return;
787   if (!Builder.GetInsertPoint())
788     return;
789   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
790   Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
791                       llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
792                       Builder.getInt64(FunctionHash),
793                       Builder.getInt32(NumRegionCounters),
794                       Builder.getInt32(Counter));
795 }
796 
797 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
798                                   bool IsInMainFile) {
799   CGM.getPGOStats().addVisited(IsInMainFile);
800   RegionCounts.clear();
801   if (std::error_code EC =
802           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
803     if (EC == llvm::instrprof_error::unknown_function)
804       CGM.getPGOStats().addMissing(IsInMainFile);
805     else if (EC == llvm::instrprof_error::hash_mismatch)
806       CGM.getPGOStats().addMismatched(IsInMainFile);
807     else if (EC == llvm::instrprof_error::malformed)
808       // TODO: Consider a more specific warning for this case.
809       CGM.getPGOStats().addMismatched(IsInMainFile);
810     RegionCounts.clear();
811   }
812 }
813 
/// \brief Calculate what to divide by to scale weights.
///
/// Returns 1 when \p MaxWeight already fits below UINT32_MAX. Otherwise it
/// returns MaxWeight / UINT32_MAX + 1, which makes every weight no larger than
/// \p MaxWeight scale to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
821 
/// \brief Scale an individual branch weight (and add 1).
///
/// Divides the 64-bit \p Weight by \p Scale to bring it into 32 bits, then
/// adds 1. Following Laplace's Rule of Succession, it is better to base the
/// weight on the count plus 1, so 1 is added universally.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, so the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Result = Weight / Scale;
  ++Result;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
837 
838 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
839                                               uint64_t FalseCount) {
840   // Check for empty weights.
841   if (!TrueCount && !FalseCount)
842     return nullptr;
843 
844   // Calculate how to scale down to 32-bits.
845   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
846 
847   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
848   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
849                                       scaleBranchWeight(FalseCount, Scale));
850 }
851 
852 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
853   // We need at least two elements to create meaningful weights.
854   if (Weights.size() < 2)
855     return nullptr;
856 
857   // Check for empty weights.
858   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
859   if (MaxWeight == 0)
860     return nullptr;
861 
862   // Calculate how to scale down to 32-bits.
863   uint64_t Scale = calculateWeightScale(MaxWeight);
864 
865   SmallVector<uint32_t, 16> ScaledWeights;
866   ScaledWeights.reserve(Weights.size());
867   for (uint64_t W : Weights)
868     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
869 
870   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
871   return MDHelper.createBranchWeights(ScaledWeights);
872 }
873 
874 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
875                                             RegionCounter &Cnt) {
876   if (!haveRegionCounts())
877     return nullptr;
878   uint64_t LoopCount = Cnt.getCount();
879   uint64_t CondCount = 0;
880   bool Found = getStmtCount(Cond, CondCount);
881   assert(Found && "missing expected loop condition count");
882   (void)Found;
883   if (CondCount == 0)
884     return nullptr;
885   return createBranchWeights(LoopCount,
886                              std::max(CondCount, LoopCount) - LoopCount);
887 }
888