1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25 
26 using namespace clang;
27 using namespace CodeGen;
28 
29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   StringRef RawFuncName = Name;
32 
33   // Function names may be prefixed with a binary '1' to indicate
34   // that the backend should not modify the symbols due to any platform
35   // naming convention. Do not include that '1' in the PGO profile name.
36   if (RawFuncName[0] == '\1')
37     RawFuncName = RawFuncName.substr(1);
38 
39   FuncName = RawFuncName;
40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
41     // For local symbols, prepend the main file name to distinguish them.
42     // Do not include the full path in the file name since there's no guarantee
43     // that it will stay the same, e.g., if the files are checked out from
44     // version control in different locations.
45     if (CGM.getCodeGenOpts().MainFileName.empty())
46       FuncName = FuncName.insert(0, "<unknown>:");
47     else
48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
49   }
50 
51   // If we're generating a profile, create a variable for the name.
52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
53     createFuncNameVar(Linkage);
54 }
55 
56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
57   setFuncName(Fn->getName(), Fn->getLinkage());
58 }
59 
60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
61   // Usually, we want to match the function's linkage, but
62   // available_externally and extern_weak both have the wrong semantics.
63   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
64     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
65   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
66     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
67 
68   auto *Value =
69       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
70   FuncNameVar =
71       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
72                                Value, "__llvm_profile_name_" + FuncName);
73 
74   // Hide the symbol so that we correctly get a copy for each executable.
75   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
76     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
77 }
78 
79 namespace {
80 /// \brief Stable hasher for PGO region counters.
81 ///
82 /// PGOHash produces a stable hash of a given function's control flow.
83 ///
84 /// Changing the output of this hash will invalidate all previously generated
85 /// profiles -- i.e., don't do it.
86 ///
87 /// \note  When this hash does eventually change (years?), we still need to
88 /// support old hashes.  We'll need to pull in the version number from the
89 /// profile data format and use the matching hash function.
90 class PGOHash {
91   uint64_t Working;
92   unsigned Count;
93   llvm::MD5 MD5;
94 
95   static const int NumBitsPerType = 6;
96   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
97   static const unsigned TooBig = 1u << NumBitsPerType;
98 
99 public:
100   /// \brief Hash values for AST nodes.
101   ///
102   /// Distinct values for AST nodes that have region counters attached.
103   ///
104   /// These values must be stable.  All new members must be added at the end,
105   /// and no members should be removed.  Changing the enumeration value for an
106   /// AST node will affect the hash of every function that contains that node.
107   enum HashType : unsigned char {
108     None = 0,
109     LabelStmt = 1,
110     WhileStmt,
111     DoStmt,
112     ForStmt,
113     CXXForRangeStmt,
114     ObjCForCollectionStmt,
115     SwitchStmt,
116     CaseStmt,
117     DefaultStmt,
118     IfStmt,
119     CXXTryStmt,
120     CXXCatchStmt,
121     ConditionalOperator,
122     BinaryOperatorLAnd,
123     BinaryOperatorLOr,
124     BinaryConditionalOperator,
125 
126     // Keep this last.  It's for the static assert that follows.
127     LastHashType
128   };
129   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130 
131   // TODO: When this format changes, take in a version number here, and use the
132   // old hash calculation for file formats that used the old hash.
133   PGOHash() : Working(0), Count(0) {}
134   void combine(HashType Type);
135   uint64_t finalize();
136 };
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
140 
141   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
142   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
143     /// The next counter value to assign.
144     unsigned NextCounter;
145     /// The function hash.
146     PGOHash Hash;
147     /// The map of statements to counters.
148     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
149 
150     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
151         : NextCounter(0), CounterMap(CounterMap) {}
152 
153     // Blocks and lambdas are handled as separate functions, so we need not
154     // traverse them in the parent context.
155     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
156     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
157     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
158 
159     bool VisitDecl(const Decl *D) {
160       switch (D->getKind()) {
161       default:
162         break;
163       case Decl::Function:
164       case Decl::CXXMethod:
165       case Decl::CXXConstructor:
166       case Decl::CXXDestructor:
167       case Decl::CXXConversion:
168       case Decl::ObjCMethod:
169       case Decl::Block:
170       case Decl::Captured:
171         CounterMap[D->getBody()] = NextCounter++;
172         break;
173       }
174       return true;
175     }
176 
177     bool VisitStmt(const Stmt *S) {
178       auto Type = getHashType(S);
179       if (Type == PGOHash::None)
180         return true;
181 
182       CounterMap[S] = NextCounter++;
183       Hash.combine(Type);
184       return true;
185     }
186     PGOHash::HashType getHashType(const Stmt *S) {
187       switch (S->getStmtClass()) {
188       default:
189         break;
190       case Stmt::LabelStmtClass:
191         return PGOHash::LabelStmt;
192       case Stmt::WhileStmtClass:
193         return PGOHash::WhileStmt;
194       case Stmt::DoStmtClass:
195         return PGOHash::DoStmt;
196       case Stmt::ForStmtClass:
197         return PGOHash::ForStmt;
198       case Stmt::CXXForRangeStmtClass:
199         return PGOHash::CXXForRangeStmt;
200       case Stmt::ObjCForCollectionStmtClass:
201         return PGOHash::ObjCForCollectionStmt;
202       case Stmt::SwitchStmtClass:
203         return PGOHash::SwitchStmt;
204       case Stmt::CaseStmtClass:
205         return PGOHash::CaseStmt;
206       case Stmt::DefaultStmtClass:
207         return PGOHash::DefaultStmt;
208       case Stmt::IfStmtClass:
209         return PGOHash::IfStmt;
210       case Stmt::CXXTryStmtClass:
211         return PGOHash::CXXTryStmt;
212       case Stmt::CXXCatchStmtClass:
213         return PGOHash::CXXCatchStmt;
214       case Stmt::ConditionalOperatorClass:
215         return PGOHash::ConditionalOperator;
216       case Stmt::BinaryConditionalOperatorClass:
217         return PGOHash::BinaryConditionalOperator;
218       case Stmt::BinaryOperatorClass: {
219         const BinaryOperator *BO = cast<BinaryOperator>(S);
220         if (BO->getOpcode() == BO_LAnd)
221           return PGOHash::BinaryOperatorLAnd;
222         if (BO->getOpcode() == BO_LOr)
223           return PGOHash::BinaryOperatorLOr;
224         break;
225       }
226       }
227       return PGOHash::None;
228     }
229   };
230 
  /// A StmtVisitor that propagates the raw counts through the AST and
  /// records the count at statements where the value may change.
  ///
  /// Results are written into CountMap. The "current region count" is state
  /// held by the CodeGenPGO object and is manipulated through RegionCounter
  /// helpers, so the order of statements inside each handler matters.
  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    /// PGO state.
    CodeGenPGO &PGO;

    /// A flag that is set when the current count should be recorded on the
    /// next statement, such as at the exit of a loop.
    bool RecordNextStmtCount;

    /// The map of statements to count values.
    llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

    /// BreakContinueStack - Keep counts of breaks and continues inside loops.
    /// One entry is pushed per enclosing loop or switch; break/continue
    /// statements add the current region count to the innermost entry.
    struct BreakContinue {
      uint64_t BreakCount;
      uint64_t ContinueCount;
      BreakContinue() : BreakCount(0), ContinueCount(0) {}
    };
    SmallVector<BreakContinue, 8> BreakContinueStack;

    /// Bind the visitor to the output map and the PGO state. No count is
    /// pending initially.
    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                        CodeGenPGO &PGO)
        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

    /// If a count is pending, store the current region count for \p S and
    /// clear the pending flag; otherwise do nothing.
    void RecordStmtCount(const Stmt *S) {
      if (RecordNextStmtCount) {
        CountMap[S] = PGO.getCurrentRegionCount();
        RecordNextStmtCount = false;
      }
    }

    /// Fallback for statements without a dedicated handler: record any
    /// pending count on \p S, then visit every non-null child.
    void VisitStmt(const Stmt *S) {
      RecordStmtCount(S);
      for (Stmt::const_child_range I = S->children(); I; ++I) {
        if (*I)
         this->Visit(*I);
      }
    }

    /// Entry point for a function: begin the body's region, record its
    /// count, and walk the body.
    void VisitFunctionDecl(const FunctionDecl *D) {
      // Counter tracks entry to the function body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    // Skip lambda expressions. We visit these as FunctionDecls when we're
    // generating them and aren't interested in the body when generating a
    // parent context.
    void VisitLambdaExpr(const LambdaExpr *LE) {}

    /// Entry point for a captured statement's declaration; same steps as
    /// VisitFunctionDecl.
    void VisitCapturedDecl(const CapturedDecl *D) {
      // Counter tracks entry to the capture body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Entry point for an Objective-C method; same steps as
    /// VisitFunctionDecl.
    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
      // Counter tracks entry to the method body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Entry point for a block; same steps as VisitFunctionDecl.
    void VisitBlockDecl(const BlockDecl *D) {
      // Counter tracks entry to the block body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// A return ends the current region: visit the returned expression (if
    /// any), mark the region unreachable, and request that the next
    /// statement record its count.
    void VisitReturnStmt(const ReturnStmt *S) {
      RecordStmtCount(S);
      if (S->getRetValue())
        Visit(S->getRetValue());
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A goto ends the current region in the same way a return does.
    void VisitGotoStmt(const GotoStmt *S) {
      RecordStmtCount(S);
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A label starts a new region with its own counter. Any pending count
    /// is discarded because the label's count is recorded explicitly.
    void VisitLabelStmt(const LabelStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the block following the label.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getSubStmt());
    }

    /// Add the current count to the innermost break total, then end the
    /// current region.
    void VisitBreakStmt(const BreakStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
      BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Add the current count to the innermost continue total, then end the
    /// current region.
    void VisitContinueStmt(const ContinueStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
      BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a while loop. The body is visited before
    /// the condition so that break/continue totals are known when the
    /// condition's count is computed.
    void VisitWhileStmt(const WhileStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first so the break/continue adjustments can be
      // included when visiting the condition.
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // ...then go back and propagate counts through the condition. The count
      // at the start of the condition is the sum of the incoming edges,
      // the backedge from the end of the loop body, and the edges from
      // continue statements.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a do-while loop. The body region is begun
    /// with the incoming fall-through added, then the condition is handled
    /// with the same formula as in VisitWhileStmt.
    void VisitDoStmt(const DoStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();
      // The count at the start of the condition is equal to the count at the
      // end of the body. The adjusted count does not include either the
      // fall-through count coming into the loop or the continue count, so add
      // both of those separately. This is coincidentally the same equation as
      // with while loops but for different reasons.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a for loop. The init statement is visited
    /// in the enclosing region; body, increment and condition are each
    /// optional except the body and are handled in that order.
    void VisitForStmt(const ForStmt *S) {
      RecordStmtCount(S);
      if (S->getInit())
        Visit(S->getInit());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      if (S->getInc()) {
        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                  BreakContinueStack.back().ContinueCount);
        CountMap[S->getInc()] = PGO.getCurrentRegionCount();
        Visit(S->getInc());
        Cnt.adjustForControlFlow();
      }

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      if (S->getCond()) {
        Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                  Cnt.getAdjustedCount() +
                                  BC.ContinueCount);
        CountMap[S->getCond()] = PGO.getCurrentRegionCount();
        Visit(S->getCond());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a C++ range-based for loop. The range and
    /// begin/end statements are visited in the enclosing region; the loop
    /// variable statement is counted as the start of the body region.
    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      RecordStmtCount(S);
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      CountMap[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through an Objective-C fast-enumeration loop. The
    /// element is visited in the enclosing region; only the body gets its
    /// own region.
    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      RecordStmtCount(S);
      Visit(S->getElement());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a switch. The region is marked unreachable
    /// before the body is visited; the case and default handlers begin their
    /// own regions. Continue counts collected inside the switch are handed
    /// to the enclosing stack entry, if there is one.
    void VisitSwitchStmt(const SwitchStmt *S) {
      RecordStmtCount(S);
      Visit(S->getCond());
      PGO.setCurrentRegionUnreachable();
      BreakContinueStack.push_back(BreakContinue());
      Visit(S->getBody());
      // If the switch is inside a loop, add the continue counts.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      if (!BreakContinueStack.empty())
        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
      // Counter tracks the exit block of the switch.
      RegionCounter ExitCnt(PGO, S);
      ExitCnt.beginRegion();
      RecordNextStmtCount = true;
    }

    /// Begin the region for one case label. The value recorded for the case
    /// itself is the counter's own count; the flag is then set so the
    /// sub-statement records the current region count.
    void VisitCaseStmt(const CaseStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this particular case. This counts only jumps from the
      // switch header and does not include fallthrough from the case before
      // this one.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    /// Begin the region for the default label; same steps as VisitCaseStmt.
    void VisitDefaultStmt(const DefaultStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this default case. This does not include fallthrough from
      // the previous case.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    /// Propagate counts through an if statement: the condition is visited in
    /// the enclosing region, then the "then" region and, when present, the
    /// "else" region.
    void VisitIfStmt(const IfStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the "then" part of an if statement. The count for
      // the "else" part, if it exists, will be calculated from this counter.
      RegionCounter Cnt(PGO, S);
      Visit(S->getCond());

      Cnt.beginRegion();
      CountMap[S->getThen()] = PGO.getCurrentRegionCount();
      Visit(S->getThen());
      Cnt.adjustForControlFlow();

      if (S->getElse()) {
        Cnt.beginElseRegion();
        CountMap[S->getElse()] = PGO.getCurrentRegionCount();
        Visit(S->getElse());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// Visit the try block and each handler, then begin the region that
    /// follows the whole try statement.
    void VisitCXXTryStmt(const CXXTryStmt *S) {
      RecordStmtCount(S);
      Visit(S->getTryBlock());
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
      // Counter tracks the continuation block of the try statement.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      RecordNextStmtCount = true;
    }

    /// Begin the region for a catch handler, record its count, and visit
    /// the handler block. Any pending count is discarded.
    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the catch statement's handler block.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getHandlerBlock());
    }

    /// Propagate counts through a conditional operator; structured like
    /// VisitIfStmt, except that the false expression is always visited.
    void VisitAbstractConditionalOperator(
        const AbstractConditionalOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the "true" part of a conditional operator. The
      // count in the "false" part will be calculated from this counter.
      RegionCounter Cnt(PGO, E);
      Visit(E->getCond());

      Cnt.beginRegion();
      CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getTrueExpr());
      Cnt.adjustForControlFlow();

      Cnt.beginElseRegion();
      CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getFalseExpr());
      Cnt.adjustForControlFlow();

      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a logical AND: the left operand is visited
    /// in the enclosing region and the right operand in its own region.
    void VisitBinLAnd(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical and operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// Propagate counts through a logical OR; same steps as VisitBinLAnd.
    void VisitBinLOr(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical or operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
  };
617 }
618 
619 void PGOHash::combine(HashType Type) {
620   // Check that we never combine 0 and only have six bits.
621   assert(Type && "Hash is invalid: unexpected type 0");
622   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
623 
624   // Pass through MD5 if enough work has built up.
625   if (Count && Count % NumTypesPerWord == 0) {
626     using namespace llvm::support;
627     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
628     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
629     Working = 0;
630   }
631 
632   // Accumulate the current type.
633   ++Count;
634   Working = Working << NumBitsPerType | Type;
635 }
636 
637 uint64_t PGOHash::finalize() {
638   // Use Working as the hash directly if we never used MD5.
639   if (Count <= NumTypesPerWord)
640     // No need to byte swap here, since none of the math was endian-dependent.
641     // This number will be byte-swapped as required on endianness transitions,
642     // so we will see the same value on the other side.
643     return Working;
644 
645   // Check for remaining work in Working.
646   if (Working)
647     MD5.update(Working);
648 
649   // Finalize the MD5 and return the hash.
650   llvm::MD5::MD5Result Result;
651   MD5.final(Result);
652   using namespace llvm::support;
653   return endian::read<uint64_t, little, unaligned>(Result);
654 }
655 
656 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
657   // Make sure we only emit coverage mapping for one constructor/destructor.
658   // Clang emits several functions for the constructor and the destructor of
659   // a class. Every function is instrumented, but we only want to provide
660   // coverage for one of them. Because of that we only emit the coverage mapping
661   // for the base constructor/destructor.
662   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
663        GD.getCtorType() != Ctor_Base) ||
664       (isa<CXXDestructorDecl>(GD.getDecl()) &&
665        GD.getDtorType() != Dtor_Base)) {
666     SkipCoverageMapping = true;
667   }
668 }
669 
670 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
671   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
672   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
673   if (!InstrumentRegions && !PGOReader)
674     return;
675   if (D->isImplicit())
676     return;
677   CGM.ClearUnusedCoverageMapping(D);
678   setFuncName(Fn);
679 
680   mapRegionCounters(D);
681   if (CGM.getCodeGenOpts().CoverageMapping)
682     emitCounterRegionMapping(D);
683   if (PGOReader) {
684     SourceManager &SM = CGM.getContext().getSourceManager();
685     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
686     computeRegionCounts(D);
687     applyFunctionAttributes(PGOReader, Fn);
688   }
689 }
690 
691 void CodeGenPGO::mapRegionCounters(const Decl *D) {
692   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
693   MapRegionCounters Walker(*RegionCounterMap);
694   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
695     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
696   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
697     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
698   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
699     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
700   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
701     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
702   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
703   NumRegionCounters = Walker.NextCounter;
704   FunctionHash = Walker.Hash.finalize();
705 }
706 
707 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
708   if (SkipCoverageMapping)
709     return;
710   // Don't map the functions inside the system headers
711   auto Loc = D->getBody()->getLocStart();
712   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
713     return;
714 
715   std::string CoverageMapping;
716   llvm::raw_string_ostream OS(CoverageMapping);
717   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
718                                 CGM.getContext().getSourceManager(),
719                                 CGM.getLangOpts(), RegionCounterMap.get());
720   MappingGen.emitCounterMapping(D, OS);
721   OS.flush();
722 
723   if (CoverageMapping.empty())
724     return;
725 
726   CGM.getCoverageMapping()->addFunctionMappingRecord(
727       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
728 }
729 
730 void
731 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
732                                     llvm::GlobalValue::LinkageTypes Linkage) {
733   if (SkipCoverageMapping)
734     return;
735   // Don't map the functions inside the system headers
736   auto Loc = D->getBody()->getLocStart();
737   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
738     return;
739 
740   std::string CoverageMapping;
741   llvm::raw_string_ostream OS(CoverageMapping);
742   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
743                                 CGM.getContext().getSourceManager(),
744                                 CGM.getLangOpts());
745   MappingGen.emitEmptyMapping(D, OS);
746   OS.flush();
747 
748   if (CoverageMapping.empty())
749     return;
750 
751   setFuncName(FuncName, Linkage);
752   CGM.getCoverageMapping()->addFunctionMappingRecord(
753       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
754 }
755 
756 void CodeGenPGO::computeRegionCounts(const Decl *D) {
757   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
758   ComputeRegionCounts Walker(*StmtCountMap, *this);
759   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
760     Walker.VisitFunctionDecl(FD);
761   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
762     Walker.VisitObjCMethodDecl(MD);
763   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
764     Walker.VisitBlockDecl(BD);
765   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
766     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
767 }
768 
769 void
770 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
771                                     llvm::Function *Fn) {
772   if (!haveRegionCounts())
773     return;
774 
775   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
776   uint64_t FunctionCount = getRegionCount(0);
777   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
778     // Turn on InlineHint attribute for hot functions.
779     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
780     Fn->addFnAttr(llvm::Attribute::InlineHint);
781   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
782     // Turn on Cold attribute for cold functions.
783     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
784     Fn->addFnAttr(llvm::Attribute::Cold);
785 }
786 
787 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
788   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
789     return;
790   if (!Builder.GetInsertPoint())
791     return;
792   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
793   Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
794                       llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
795                       Builder.getInt64(FunctionHash),
796                       Builder.getInt32(NumRegionCounters),
797                       Builder.getInt32(Counter));
798 }
799 
800 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
801                                   bool IsInMainFile) {
802   CGM.getPGOStats().addVisited(IsInMainFile);
803   RegionCounts.clear();
804   if (std::error_code EC =
805           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
806     if (EC == llvm::instrprof_error::unknown_function)
807       CGM.getPGOStats().addMissing(IsInMainFile);
808     else if (EC == llvm::instrprof_error::hash_mismatch)
809       CGM.getPGOStats().addMismatched(IsInMainFile);
810     else if (EC == llvm::instrprof_error::malformed)
811       // TODO: Consider a more specific warning for this case.
812       CGM.getPGOStats().addMismatched(IsInMainFile);
813     RegionCounts.clear();
814   }
815 }
816 
/// \brief Pick the divisor used to shrink 64-bit weights to 32 bits.
///
/// A maximum weight below UINT32_MAX needs no scaling, so the divisor is 1.
/// Otherwise the divisor is chosen so that dividing any weight no greater
/// than \p MaxWeight by it yields a value strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
824 
/// \brief Shrink one 64-bit branch weight to 32 bits and bias it by one.
///
/// The weight is divided by \p Scale and then incremented. The increment
/// follows Laplace's Rule of Succession: weights are better derived from the
/// count plus 1, so 1 is added unconditionally.
///
/// \pre \p Scale is non-zero and was produced by \a calculateWeightScale()
/// from a maximum weight that is at least \p Weight, so the biased result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Biased = Weight / Scale + 1;
  assert(Biased <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Biased);
}
840 
841 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
842                                               uint64_t FalseCount) {
843   // Check for empty weights.
844   if (!TrueCount && !FalseCount)
845     return nullptr;
846 
847   // Calculate how to scale down to 32-bits.
848   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
849 
850   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
851   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
852                                       scaleBranchWeight(FalseCount, Scale));
853 }
854 
855 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
856   // We need at least two elements to create meaningful weights.
857   if (Weights.size() < 2)
858     return nullptr;
859 
860   // Check for empty weights.
861   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
862   if (MaxWeight == 0)
863     return nullptr;
864 
865   // Calculate how to scale down to 32-bits.
866   uint64_t Scale = calculateWeightScale(MaxWeight);
867 
868   SmallVector<uint32_t, 16> ScaledWeights;
869   ScaledWeights.reserve(Weights.size());
870   for (uint64_t W : Weights)
871     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
872 
873   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
874   return MDHelper.createBranchWeights(ScaledWeights);
875 }
876 
877 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
878                                             RegionCounter &Cnt) {
879   if (!haveRegionCounts())
880     return nullptr;
881   uint64_t LoopCount = Cnt.getCount();
882   uint64_t CondCount = 0;
883   bool Found = getStmtCount(Cond, CondCount);
884   assert(Found && "missing expected loop condition count");
885   (void)Found;
886   if (CondCount == 0)
887     return nullptr;
888   return createBranchWeights(LoopCount,
889                              std::max(CondCount, LoopCount) - LoopCount);
890 }
891