1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/PostDominators.h"
20 #include "llvm/Analysis/TargetLibraryInfo.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InstrTypes.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BranchProbability.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include <cassert>
43 #include <cstdint>
44 #include <iterator>
45 #include <utility>
46 
47 using namespace llvm;
48 
49 #define DEBUG_TYPE "branch-prob"
50 
/// Hidden command-line flag `-print-bpi` (off by default). Per its description
/// it requests printing of the computed branch probability info; the code that
/// reads it is outside this view.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

/// Hidden command-line option `-print-bpi-func-name`: names the function whose
/// branch probability info should be printed.
/// NOTE(review): this option is not `static`, so it is presumably referenced
/// from another translation unit - confirm before changing its linkage.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));
59 
// Legacy pass-manager registration of the "branch-prob" analysis. It declares
// LoopInfo, TargetLibraryInfo and PostDominatorTree as dependencies.
// NOTE(review): the trailing `false, true` arguments are presumably the
// "CFG only" and "is analysis" flags of the INITIALIZE_PASS macros - confirm
// against the macro definitions.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)

/// Constructs the legacy wrapper pass and registers it with the global
/// PassRegistry through initializeBranchProbabilityInfoWrapperPassPass().
BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
    : FunctionPass(ID) {
  initializeBranchProbabilityInfoWrapperPassPass(
      *PassRegistry::getPassRegistry());
}

// Legacy pass identifier handed to the FunctionPass base constructor above.
// NOTE(review): presumably only its address is significant, as is usual for
// legacy passes.
char BranchProbabilityInfoWrapperPass::ID = 0;
75 
// Weights are for internal use only. Heuristics use them to help
// estimate edge probabilities. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
static const uint32_t LBH_TAKEN_WEIGHT = 124;
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
// (62 == LBH_TAKEN_WEIGHT / 2).
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible
/// (a raw numerator of 1 - presumably the smallest representable non-zero
/// probability). All reachable edges proportionally share the remainder.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold.  A block is marked cold if it's postdominated by a
/// block containing a call to a cold function.  Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer heuristic weights (used by calcPointerHeuristics): an equality
/// comparison between pointers is predicted to be false, so the "not equal"
/// outcome receives PH_TAKEN_WEIGHT and the "equal" outcome PH_NONTAKEN_WEIGHT.
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

// NOTE(review): presumably the weights of the zero-comparison heuristic; their
// users are outside this view.
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

// Floating-point comparison heuristic weights; their users are outside this
// view.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// This is the weight for an ordered floating point comparison.
static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// This is the weight for an unordered floating point comparison, meaning
/// one or both of the operands are NaN. Usually it is used to test for an
/// exceptional case, so the result is unlikely.
static const uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
150 
151 BranchProbabilityInfo::SccInfo::SccInfo(const Function &F) {
152   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
153   // FIXME: We could only calculate this if the CFG is known to be irreducible
154   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
155   int SccNum = 0;
156   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
157        ++It, ++SccNum) {
158     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
159     // catch them.
160     const std::vector<const BasicBlock *> &Scc = *It;
161     if (Scc.size() == 1)
162       continue;
163 
164     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
165     for (const auto *BB : Scc) {
166       LLVM_DEBUG(dbgs() << " " << BB->getName());
167       SccNums[BB] = SccNum;
168       calculateSccBlockType(BB, SccNum);
169     }
170     LLVM_DEBUG(dbgs() << "\n");
171   }
172 }
173 
174 int BranchProbabilityInfo::SccInfo::getSCCNum(const BasicBlock *BB) const {
175   auto SccIt = SccNums.find(BB);
176   if (SccIt == SccNums.end())
177     return -1;
178   return SccIt->second;
179 }
180 
181 void BranchProbabilityInfo::SccInfo::getSccEnterBlocks(
182     int SccNum, SmallVectorImpl<BasicBlock *> &Enters) const {
183 
184   for (auto MapIt : SccBlocks[SccNum]) {
185     const auto *BB = MapIt.first;
186     if (isSCCHeader(BB, SccNum))
187       for (const auto *Pred : predecessors(BB))
188         if (getSCCNum(Pred) != SccNum)
189           Enters.push_back(const_cast<BasicBlock *>(BB));
190   }
191 }
192 
193 void BranchProbabilityInfo::SccInfo::getSccExitBlocks(
194     int SccNum, SmallVectorImpl<BasicBlock *> &Exits) const {
195   for (auto MapIt : SccBlocks[SccNum]) {
196     const auto *BB = MapIt.first;
197     if (isSCCExitingBlock(BB, SccNum))
198       for (const auto *Succ : successors(BB))
199         if (getSCCNum(Succ) != SccNum)
200           Exits.push_back(const_cast<BasicBlock *>(BB));
201   }
202 }
203 
204 uint32_t BranchProbabilityInfo::SccInfo::getSccBlockType(const BasicBlock *BB,
205                                                          int SccNum) const {
206   assert(getSCCNum(BB) == SccNum);
207 
208   assert(SccBlocks.size() > static_cast<unsigned>(SccNum) && "Unknown SCC");
209   const auto &SccBlockTypes = SccBlocks[SccNum];
210 
211   auto It = SccBlockTypes.find(BB);
212   if (It != SccBlockTypes.end()) {
213     return It->second;
214   }
215   return Inner;
216 }
217 
218 void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
219                                                            int SccNum) {
220   assert(getSCCNum(BB) == SccNum);
221   uint32_t BlockType = Inner;
222 
223   if (llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
224                    [&](const BasicBlock *Pred) {
225         // Consider any block that is an entry point to the SCC as
226         // a header.
227         return getSCCNum(Pred) != SccNum;
228       }))
229     BlockType |= Header;
230 
231   if (llvm::any_of(
232           make_range(succ_begin(BB), succ_end(BB)),
233           [&](const BasicBlock *Succ) { return getSCCNum(Succ) != SccNum; }))
234     BlockType |= Exiting;
235 
236   // Lazily compute the set of headers for a given SCC and cache the results
237   // in the SccHeaderMap.
238   if (SccBlocks.size() <= static_cast<unsigned>(SccNum))
239     SccBlocks.resize(SccNum + 1);
240   auto &SccBlockTypes = SccBlocks[SccNum];
241 
242   if (BlockType != Inner) {
243     bool IsInserted;
244     std::tie(std::ignore, IsInserted) =
245         SccBlockTypes.insert(std::make_pair(BB, BlockType));
246     assert(IsInserted && "Duplicated block in SCC");
247   }
248 }
249 
250 BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
251                                             const LoopInfo &LI,
252                                             const SccInfo &SccI)
253     : BB(BB) {
254   LD.first = LI.getLoopFor(BB);
255   if (!LD.first) {
256     LD.second = SccI.getSCCNum(BB);
257   }
258 }
259 
260 bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
261   const auto &SrcBlock = Edge.first;
262   const auto &DstBlock = Edge.second;
263   return (DstBlock.getLoop() &&
264           !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
265          // Assume that SCCs can't be nested.
266          (DstBlock.getSccNum() != -1 &&
267           SrcBlock.getSccNum() != DstBlock.getSccNum());
268 }
269 
270 bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
271   return isLoopEnteringEdge({Edge.second, Edge.first});
272 }
273 
274 bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
275     const LoopEdge &Edge) const {
276   return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
277 }
278 
279 bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
280   const auto &SrcBlock = Edge.first;
281   const auto &DstBlock = Edge.second;
282   return SrcBlock.belongsToSameLoop(DstBlock) &&
283          ((DstBlock.getLoop() &&
284            DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
285           (DstBlock.getSccNum() != -1 &&
286            SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
287 }
288 
289 void BranchProbabilityInfo::getLoopEnterBlocks(
290     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
291   if (LB.getLoop()) {
292     auto *Header = LB.getLoop()->getHeader();
293     Enters.append(pred_begin(Header), pred_end(Header));
294   } else {
295     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
296     SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
297   }
298 }
299 
300 void BranchProbabilityInfo::getLoopExitBlocks(
301     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
302   if (LB.getLoop()) {
303     LB.getLoop()->getExitBlocks(Exits);
304   } else {
305     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
306     SccI->getSccExitBlocks(LB.getSccNum(), Exits);
307   }
308 }
309 
310 static void UpdatePDTWorklist(const BasicBlock *BB, PostDominatorTree *PDT,
311                               SmallVectorImpl<const BasicBlock *> &WorkList,
312                               SmallPtrSetImpl<const BasicBlock *> &TargetSet) {
313   SmallVector<BasicBlock *, 8> Descendants;
314   SmallPtrSet<const BasicBlock *, 16> NewItems;
315 
316   PDT->getDescendants(const_cast<BasicBlock *>(BB), Descendants);
317   for (auto *BB : Descendants)
318     if (TargetSet.insert(BB).second)
319       for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
320         if (!TargetSet.count(*PI))
321           NewItems.insert(*PI);
322   WorkList.insert(WorkList.end(), NewItems.begin(), NewItems.end());
323 }
324 
325 /// Compute a set of basic blocks that are post-dominated by unreachables.
326 void BranchProbabilityInfo::computePostDominatedByUnreachable(
327     const Function &F, PostDominatorTree *PDT) {
328   SmallVector<const BasicBlock *, 8> WorkList;
329   for (auto &BB : F) {
330     const Instruction *TI = BB.getTerminator();
331     if (TI->getNumSuccessors() == 0) {
332       if (isa<UnreachableInst>(TI) ||
333           // If this block is terminated by a call to
334           // @llvm.experimental.deoptimize then treat it like an unreachable
335           // since the @llvm.experimental.deoptimize call is expected to
336           // practically never execute.
337           BB.getTerminatingDeoptimizeCall())
338         UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByUnreachable);
339     }
340   }
341 
342   while (!WorkList.empty()) {
343     const BasicBlock *BB = WorkList.pop_back_val();
344     if (PostDominatedByUnreachable.count(BB))
345       continue;
346     // If the terminator is an InvokeInst, check only the normal destination
347     // block as the unwind edge of InvokeInst is also very unlikely taken.
348     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
349       if (PostDominatedByUnreachable.count(II->getNormalDest()))
350         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
351     }
352     // If all the successors are unreachable, BB is unreachable as well.
353     else if (!successors(BB).empty() &&
354              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
355                return PostDominatedByUnreachable.count(Succ);
356              }))
357       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
358   }
359 }
360 
361 /// compute a set of basic blocks that are post-dominated by ColdCalls.
362 void BranchProbabilityInfo::computePostDominatedByColdCall(
363     const Function &F, PostDominatorTree *PDT) {
364   SmallVector<const BasicBlock *, 8> WorkList;
365   for (auto &BB : F)
366     for (auto &I : BB)
367       if (const CallInst *CI = dyn_cast<CallInst>(&I))
368         if (CI->hasFnAttr(Attribute::Cold))
369           UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByColdCall);
370 
371   while (!WorkList.empty()) {
372     const BasicBlock *BB = WorkList.pop_back_val();
373 
374     // If the terminator is an InvokeInst, check only the normal destination
375     // block as the unwind edge of InvokeInst is also very unlikely taken.
376     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
377       if (PostDominatedByColdCall.count(II->getNormalDest()))
378         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
379     }
380     // If all of successor are post dominated then BB is also done.
381     else if (!successors(BB).empty() &&
382              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
383                return PostDominatedByColdCall.count(Succ);
384              }))
385       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
386   }
387 }
388 
389 /// Calculate edge weights for successors lead to unreachable.
390 ///
391 /// Predict that a successor which leads necessarily to an
392 /// unreachable-terminated block as extremely unlikely.
393 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
394   const Instruction *TI = BB->getTerminator();
395   (void) TI;
396   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
397   assert(!isa<InvokeInst>(TI) &&
398          "Invokes should have already been handled by calcInvokeHeuristics");
399 
400   SmallVector<unsigned, 4> UnreachableEdges;
401   SmallVector<unsigned, 4> ReachableEdges;
402 
403   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
404     if (PostDominatedByUnreachable.count(*I))
405       UnreachableEdges.push_back(I.getSuccessorIndex());
406     else
407       ReachableEdges.push_back(I.getSuccessorIndex());
408 
409   // Skip probabilities if all were reachable.
410   if (UnreachableEdges.empty())
411     return false;
412 
413   SmallVector<BranchProbability, 4> EdgeProbabilities(
414       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
415   if (ReachableEdges.empty()) {
416     BranchProbability Prob(1, UnreachableEdges.size());
417     for (unsigned SuccIdx : UnreachableEdges)
418       EdgeProbabilities[SuccIdx] = Prob;
419     setEdgeProbability(BB, EdgeProbabilities);
420     return true;
421   }
422 
423   auto UnreachableProb = UR_TAKEN_PROB;
424   auto ReachableProb =
425       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
426       ReachableEdges.size();
427 
428   for (unsigned SuccIdx : UnreachableEdges)
429     EdgeProbabilities[SuccIdx] = UnreachableProb;
430   for (unsigned SuccIdx : ReachableEdges)
431     EdgeProbabilities[SuccIdx] = ReachableProb;
432 
433   setEdgeProbability(BB, EdgeProbabilities);
434   return true;
435 }
436 
// Propagate existing explicit probabilities from either profile data or
// 'expect' intrinsic processing. Examine metadata against unreachable
// heuristic. The probability of the edge coming to unreachable block is
// set to min of metadata and unreachable heuristic.
//
// Returns false without setting anything when BB's terminator is not a
// br/switch/indirectbr/invoke, has no !prof metadata, has a weight count that
// does not match its successor count, or has a non-ConstantInt weight
// operand. Otherwise sets probabilities for all successor edges and returns
// true.
bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  const Instruction *TI = BB->getTerminator();
  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  // Only these terminator kinds are handled here.
  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI) ||
        isa<InvokeInst>(TI)))
    return false;

  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
  if (!WeightsNode)
    return false;

  // Check that the number of successors is manageable.
  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

  // Ensure there are weights for all of the successors. Note that the first
  // operand to the metadata node is a name, not a weight.
  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
    return false;

  // Build up the final weights that will be used in a temporary buffer.
  // Compute the sum of all weights to later decide whether they need to
  // be scaled to fit in 32 bits. At the same time, partition successor
  // indices by whether the successor is post-dominated by unreachable.
  uint64_t WeightSum = 0;
  SmallVector<uint32_t, 2> Weights;
  SmallVector<unsigned, 2> UnreachableIdxs;
  SmallVector<unsigned, 2> ReachableIdxs;
  Weights.reserve(TI->getNumSuccessors());
  // Metadata operand I (I >= 1) holds the weight of successor I - 1.
  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
    ConstantInt *Weight =
        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
    if (!Weight)
      return false;
    assert(Weight->getValue().getActiveBits() <= 32 &&
           "Too many bits for uint32_t");
    Weights.push_back(Weight->getZExtValue());
    WeightSum += Weights.back();
    if (PostDominatedByUnreachable.count(TI->getSuccessor(I - 1)))
      UnreachableIdxs.push_back(I - 1);
    else
      ReachableIdxs.push_back(I - 1);
  }
  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");

  // If the sum of weights does not fit in 32 bits, scale every weight down
  // accordingly.
  uint64_t ScalingFactor =
      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;

  if (ScalingFactor > 1) {
    WeightSum = 0;
    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
      Weights[I] /= ScalingFactor;
      WeightSum += Weights[I];
    }
  }
  assert(WeightSum <= UINT32_MAX &&
         "Expected weights to scale down to 32 bits");

  // All-zero weights (which would give a zero denominator below), or every
  // successor leads to unreachable: fall back to a uniform distribution.
  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
      Weights[I] = 1;
    WeightSum = TI->getNumSuccessors();
  }

  // Set the probability.
  SmallVector<BranchProbability, 2> BP;
  for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
    BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });

  // Examine the metadata against unreachable heuristic.
  // If the unreachable heuristic is stronger, use it for this edge. With
  // no unreachable successors, or no reachable ones, the metadata is used
  // as-is.
  if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) {
    setEdgeProbability(BB, BP);
    return true;
  }

  // Cap every unreachable edge at UR_TAKEN_PROB.
  auto UnreachableProb = UR_TAKEN_PROB;
  for (auto I : UnreachableIdxs)
    if (UnreachableProb < BP[I]) {
      BP[I] = UnreachableProb;
    }

  // Sum of all edge probabilities must be 1.0. If we modified the probability
  // of some edges then we must distribute the introduced difference over the
  // reachable blocks.
  //
  // Proportional distribution: the relation between probabilities of the
  // reachable edges is kept unchanged. That is for any reachable edges i and j:
  //   newBP[i] / newBP[j] == oldBP[i] / oldBP[j] =>
  //   newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K
  // Where K is independent of i,j.
  //   newBP[i] == oldBP[i] * K
  // We need to find K.
  // Make sum of all reachables of the left and right parts:
  //   sum_of_reachable(newBP) == K * sum_of_reachable(oldBP)
  // Sum of newBP must be equal to 1.0:
  //   sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 =>
  //   sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP)
  // Where sum_of_unreachable(newBP) is what has been just changed.
  // Finally:
  //   K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) =>
  //   K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP)
  BranchProbability NewUnreachableSum = BranchProbability::getZero();
  for (auto I : UnreachableIdxs)
    NewUnreachableSum += BP[I];

  BranchProbability NewReachableSum =
      BranchProbability::getOne() - NewUnreachableSum;

  BranchProbability OldReachableSum = BranchProbability::getZero();
  for (auto I : ReachableIdxs)
    OldReachableSum += BP[I];

  if (OldReachableSum != NewReachableSum) { // Anything to distribute?
    if (OldReachableSum.isZero()) {
      // If all oldBP[i] are zeroes then the proportional distribution results
      // in all zero probabilities and the error stays big. In this case we
      // evenly spread NewReachableSum over the reachable edges.
      BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size();
      for (auto I : ReachableIdxs)
        BP[I] = PerEdge;
    } else {
      for (auto I : ReachableIdxs) {
        // We use uint64_t to avoid double rounding error of the following
        // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum
        // The formula is taken from the private constructor
        // BranchProbability(uint32_t Numerator, uint32_t Denominator)
        uint64_t Mul = static_cast<uint64_t>(NewReachableSum.getNumerator()) *
                       BP[I].getNumerator();
        uint32_t Div = static_cast<uint32_t>(
            divideNearest(Mul, OldReachableSum.getNumerator()));
        BP[I] = BranchProbability::getRaw(Div);
      }
    }
  }

  setEdgeProbability(BB, BP);

  return true;
}
581 
582 /// Calculate edge weights for edges leading to cold blocks.
583 ///
584 /// A cold block is one post-dominated by  a block with a call to a
585 /// cold function.  Those edges are unlikely to be taken, so we give
586 /// them relatively low weight.
587 ///
588 /// Return true if we could compute the weights for cold edges.
589 /// Return false, otherwise.
590 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
591   const Instruction *TI = BB->getTerminator();
592   (void) TI;
593   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
594   assert(!isa<InvokeInst>(TI) &&
595          "Invokes should have already been handled by calcInvokeHeuristics");
596 
597   // Determine which successors are post-dominated by a cold block.
598   SmallVector<unsigned, 4> ColdEdges;
599   SmallVector<unsigned, 4> NormalEdges;
600   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
601     if (PostDominatedByColdCall.count(*I))
602       ColdEdges.push_back(I.getSuccessorIndex());
603     else
604       NormalEdges.push_back(I.getSuccessorIndex());
605 
606   // Skip probabilities if no cold edges.
607   if (ColdEdges.empty())
608     return false;
609 
610   SmallVector<BranchProbability, 4> EdgeProbabilities(
611       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
612   if (NormalEdges.empty()) {
613     BranchProbability Prob(1, ColdEdges.size());
614     for (unsigned SuccIdx : ColdEdges)
615       EdgeProbabilities[SuccIdx] = Prob;
616     setEdgeProbability(BB, EdgeProbabilities);
617     return true;
618   }
619 
620   auto ColdProb = BranchProbability::getBranchProbability(
621       CC_TAKEN_WEIGHT,
622       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
623   auto NormalProb = BranchProbability::getBranchProbability(
624       CC_NONTAKEN_WEIGHT,
625       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
626 
627   for (unsigned SuccIdx : ColdEdges)
628     EdgeProbabilities[SuccIdx] = ColdProb;
629   for (unsigned SuccIdx : NormalEdges)
630     EdgeProbabilities[SuccIdx] = NormalProb;
631 
632   setEdgeProbability(BB, EdgeProbabilities);
633   return true;
634 }
635 
636 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
637 // between two pointer or pointer and NULL will fail.
638 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
639   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
640   if (!BI || !BI->isConditional())
641     return false;
642 
643   Value *Cond = BI->getCondition();
644   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
645   if (!CI || !CI->isEquality())
646     return false;
647 
648   Value *LHS = CI->getOperand(0);
649 
650   if (!LHS->getType()->isPointerTy())
651     return false;
652 
653   assert(CI->getOperand(1)->getType()->isPointerTy());
654 
655   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
656                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
657   BranchProbability UntakenProb(PH_NONTAKEN_WEIGHT,
658                                 PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
659 
660   // p != 0   ->   isProb = true
661   // p == 0   ->   isProb = false
662   // p != q   ->   isProb = true
663   // p == q   ->   isProb = false;
664   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
665   if (!isProb)
666     std::swap(TakenProb, UntakenProb);
667 
668   setEdgeProbability(
669       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
670   return true;
671 }
672 
// Compute the unlikely successors to the block BB in the loop L, specifically
// those that are unlikely because this is a loop, and add them to the
// UnlikelyBlocks set.
//
// BB:             block whose terminator is examined. Nothing is added unless
//                 it ends in a conditional branch on a CmpInst whose LHS is an
//                 Instruction and whose RHS is a Constant.
// L:              loop containing BB. The compared value's operation chain
//                 and PHIs must lie inside L, and only incoming values from
//                 blocks inside L are considered.
// UnlikelyBlocks: output; unlikely successors of BB are inserted (existing
//                 contents are left untouched).
static void
computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
  // Sometimes in a loop we have a branch whose condition is made false by
  // taking it. This is typically something like
  //  int n = 0;
  //  while (...) {
  //    if (++n >= MAX) {
  //      n = 0;
  //    }
  //  }
  // In this sort of situation taking the branch means that at the very least it
  // won't be taken again in the next iteration of the loop, so we should
  // consider it less likely than a typical branch.
  //
  // We detect this by looking back through the graph of PHI nodes that sets the
  // value that the condition depends on, and seeing if we can reach a successor
  // block which can be determined to make the condition false.
  //
  // FIXME: We currently consider unlikely blocks to be half as likely as other
  // blocks, but if we consider the example above the likelyhood is actually
  // 1/MAX. We could therefore be more precise in how unlikely we consider
  // blocks to be, but it would require more careful examination of the form
  // of the comparison expression.
  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  if (!BI || !BI->isConditional())
    return;

  // Check if the branch is based on an instruction compared with a constant
  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
  if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
      !isa<Constant>(CI->getOperand(1)))
    return;

  // Either the instruction must be a PHI, or a chain of operations involving
  // constants that ends in a PHI which we can then collapse into a single value
  // if the PHI value is known.
  Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
  PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
  Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
  // Collect the instructions until we hit a PHI. Only binary operators whose
  // second operand is a Constant are followed, always through operand 0.
  SmallVector<BinaryOperator *, 1> InstChain;
  while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
         isa<Constant>(CmpLHS->getOperand(1))) {
    // Stop if the chain extends outside of the loop
    if (!L->contains(CmpLHS))
      return;
    InstChain.push_back(cast<BinaryOperator>(CmpLHS));
    CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
    if (CmpLHS)
      CmpPHI = dyn_cast<PHINode>(CmpLHS);
  }
  if (!CmpPHI || !L->contains(CmpPHI))
    return;

  // Trace the phi node to find all values that come from successors of BB.
  // Nested PHIs are followed transitively; VisitedInsts prevents revisiting.
  SmallPtrSet<PHINode*, 8> VisitedInsts;
  SmallVector<PHINode*, 8> WorkList;
  WorkList.push_back(CmpPHI);
  VisitedInsts.insert(CmpPHI);
  while (!WorkList.empty()) {
    PHINode *P = WorkList.back();
    WorkList.pop_back();
    for (BasicBlock *B : P->blocks()) {
      // Skip blocks that aren't part of the loop
      if (!L->contains(B))
        continue;
      Value *V = P->getIncomingValueForBlock(B);
      // If the source is a PHI add it to the work list if we haven't
      // already visited it.
      if (PHINode *PN = dyn_cast<PHINode>(V)) {
        if (VisitedInsts.insert(PN).second)
          WorkList.push_back(PN);
        continue;
      }
      // If this incoming value is a constant and B is a successor of BB, then
      // we can constant-evaluate the compare to see if it makes the branch be
      // taken or not.
      Constant *CmpLHSConst = dyn_cast<Constant>(V);
      if (!CmpLHSConst || !llvm::is_contained(successors(BB), B))
        continue;
      // First collapse InstChain: re-apply the collected operations, innermost
      // (closest to the PHI) first, giving up if folding fails.
      for (Instruction *I : llvm::reverse(InstChain)) {
        CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
                                        cast<Constant>(I->getOperand(1)), true);
        if (!CmpLHSConst)
          break;
      }
      if (!CmpLHSConst)
        continue;
      // Now constant-evaluate the compare
      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
                                                  CmpLHSConst, CmpConst, true);
      // If the result means we don't branch to the block then that block is
      // unlikely: a false result with B as the "true" successor (index 0), or
      // a true result with B as the "false" successor (index 1).
      if (Result &&
          ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
           (Result->isOneValue() && B == BI->getSuccessor(1))))
        UnlikelyBlocks.insert(B);
    }
  }
}
778 
779 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
780 // as taken, exiting edges as not-taken.
781 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
782                                                      const LoopInfo &LI) {
783   LoopBlock LB(BB, LI, *SccI.get());
784   if (!LB.belongsToLoop())
785     return false;
786 
787   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
788   if (LB.getLoop())
789     computeUnlikelySuccessors(BB, LB.getLoop(), UnlikelyBlocks);
790 
791   SmallVector<unsigned, 8> BackEdges;
792   SmallVector<unsigned, 8> ExitingEdges;
793   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
794   SmallVector<unsigned, 8> UnlikelyEdges;
795 
796   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
797     LoopBlock SuccLB(*I, LI, *SccI.get());
798     LoopEdge Edge(LB, SuccLB);
799     bool IsUnlikelyEdge =
800         LB.getLoop() && (UnlikelyBlocks.find(*I) != UnlikelyBlocks.end());
801 
802     if (IsUnlikelyEdge)
803       UnlikelyEdges.push_back(I.getSuccessorIndex());
804     else if (isLoopExitingEdge(Edge))
805       ExitingEdges.push_back(I.getSuccessorIndex());
806     else if (isLoopBackEdge(Edge))
807       BackEdges.push_back(I.getSuccessorIndex());
808     else {
809       InEdges.push_back(I.getSuccessorIndex());
810     }
811   }
812 
813   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
814     return false;
815 
816   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
817   // normalize them so that they sum up to one.
818   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
819                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
820                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
821                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
822 
823   SmallVector<BranchProbability, 4> EdgeProbabilities(
824       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
825   if (uint32_t numBackEdges = BackEdges.size()) {
826     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
827     auto Prob = TakenProb / numBackEdges;
828     for (unsigned SuccIdx : BackEdges)
829       EdgeProbabilities[SuccIdx] = Prob;
830   }
831 
832   if (uint32_t numInEdges = InEdges.size()) {
833     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
834     auto Prob = TakenProb / numInEdges;
835     for (unsigned SuccIdx : InEdges)
836       EdgeProbabilities[SuccIdx] = Prob;
837   }
838 
839   if (uint32_t numExitingEdges = ExitingEdges.size()) {
840     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
841                                                        Denom);
842     auto Prob = NotTakenProb / numExitingEdges;
843     for (unsigned SuccIdx : ExitingEdges)
844       EdgeProbabilities[SuccIdx] = Prob;
845   }
846 
847   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
848     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
849                                                        Denom);
850     auto Prob = UnlikelyProb / numUnlikelyEdges;
851     for (unsigned SuccIdx : UnlikelyEdges)
852       EdgeProbabilities[SuccIdx] = Prob;
853   }
854 
855   setEdgeProbability(BB, EdgeProbabilities);
856   return true;
857 }
858 
859 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
860                                                const TargetLibraryInfo *TLI) {
861   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
862   if (!BI || !BI->isConditional())
863     return false;
864 
865   Value *Cond = BI->getCondition();
866   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
867   if (!CI)
868     return false;
869 
870   auto GetConstantInt = [](Value *V) {
871     if (auto *I = dyn_cast<BitCastInst>(V))
872       return dyn_cast<ConstantInt>(I->getOperand(0));
873     return dyn_cast<ConstantInt>(V);
874   };
875 
876   Value *RHS = CI->getOperand(1);
877   ConstantInt *CV = GetConstantInt(RHS);
878   if (!CV)
879     return false;
880 
881   // If the LHS is the result of AND'ing a value with a single bit bitmask,
882   // we don't have information about probabilities.
883   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
884     if (LHS->getOpcode() == Instruction::And)
885       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
886         if (AndRHS->getValue().isPowerOf2())
887           return false;
888 
889   // Check if the LHS is the return value of a library function
890   LibFunc Func = NumLibFuncs;
891   if (TLI)
892     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
893       if (Function *CalledFn = Call->getCalledFunction())
894         TLI->getLibFunc(*CalledFn, Func);
895 
896   bool isProb;
897   if (Func == LibFunc_strcasecmp ||
898       Func == LibFunc_strcmp ||
899       Func == LibFunc_strncasecmp ||
900       Func == LibFunc_strncmp ||
901       Func == LibFunc_memcmp ||
902       Func == LibFunc_bcmp) {
903     // strcmp and similar functions return zero, negative, or positive, if the
904     // first string is equal, less, or greater than the second. We consider it
905     // likely that the strings are not equal, so a comparison with zero is
906     // probably false, but also a comparison with any other number is also
907     // probably false given that what exactly is returned for nonzero values is
908     // not specified. Any kind of comparison other than equality we know
909     // nothing about.
910     switch (CI->getPredicate()) {
911     case CmpInst::ICMP_EQ:
912       isProb = false;
913       break;
914     case CmpInst::ICMP_NE:
915       isProb = true;
916       break;
917     default:
918       return false;
919     }
920   } else if (CV->isZero()) {
921     switch (CI->getPredicate()) {
922     case CmpInst::ICMP_EQ:
923       // X == 0   ->  Unlikely
924       isProb = false;
925       break;
926     case CmpInst::ICMP_NE:
927       // X != 0   ->  Likely
928       isProb = true;
929       break;
930     case CmpInst::ICMP_SLT:
931       // X < 0   ->  Unlikely
932       isProb = false;
933       break;
934     case CmpInst::ICMP_SGT:
935       // X > 0   ->  Likely
936       isProb = true;
937       break;
938     default:
939       return false;
940     }
941   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
942     // InstCombine canonicalizes X <= 0 into X < 1.
943     // X <= 0   ->  Unlikely
944     isProb = false;
945   } else if (CV->isMinusOne()) {
946     switch (CI->getPredicate()) {
947     case CmpInst::ICMP_EQ:
948       // X == -1  ->  Unlikely
949       isProb = false;
950       break;
951     case CmpInst::ICMP_NE:
952       // X != -1  ->  Likely
953       isProb = true;
954       break;
955     case CmpInst::ICMP_SGT:
956       // InstCombine canonicalizes X >= 0 into X > -1.
957       // X >= 0   ->  Likely
958       isProb = true;
959       break;
960     default:
961       return false;
962     }
963   } else {
964     return false;
965   }
966 
967   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
968                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
969   BranchProbability UntakenProb(ZH_NONTAKEN_WEIGHT,
970                                 ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
971   if (!isProb)
972     std::swap(TakenProb, UntakenProb);
973 
974   setEdgeProbability(
975       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
976   return true;
977 }
978 
979 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
980   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
981   if (!BI || !BI->isConditional())
982     return false;
983 
984   Value *Cond = BI->getCondition();
985   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
986   if (!FCmp)
987     return false;
988 
989   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
990   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
991   bool isProb;
992   if (FCmp->isEquality()) {
993     // f1 == f2 -> Unlikely
994     // f1 != f2 -> Likely
995     isProb = !FCmp->isTrueWhenEqual();
996   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
997     // !isnan -> Likely
998     isProb = true;
999     TakenWeight = FPH_ORD_WEIGHT;
1000     NontakenWeight = FPH_UNO_WEIGHT;
1001   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
1002     // isnan -> Unlikely
1003     isProb = false;
1004     TakenWeight = FPH_ORD_WEIGHT;
1005     NontakenWeight = FPH_UNO_WEIGHT;
1006   } else {
1007     return false;
1008   }
1009 
1010   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
1011   BranchProbability UntakenProb(NontakenWeight, TakenWeight + NontakenWeight);
1012   if (!isProb)
1013     std::swap(TakenProb, UntakenProb);
1014 
1015   setEdgeProbability(
1016       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
1017   return true;
1018 }
1019 
1020 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
1021   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
1022   if (!II)
1023     return false;
1024 
1025   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
1026                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
1027   setEdgeProbability(
1028       BB, SmallVector<BranchProbability, 2>({TakenProb, TakenProb.getCompl()}));
1029   return true;
1030 }
1031 
1032 void BranchProbabilityInfo::releaseMemory() {
1033   Probs.clear();
1034   Handles.clear();
1035 }
1036 
1037 bool BranchProbabilityInfo::invalidate(Function &, const PreservedAnalyses &PA,
1038                                        FunctionAnalysisManager::Invalidator &) {
1039   // Check whether the analysis, all analyses on functions, or the function's
1040   // CFG have been preserved.
1041   auto PAC = PA.getChecker<BranchProbabilityAnalysis>();
1042   return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
1043            PAC.preservedSet<CFGAnalyses>());
1044 }
1045 
1046 void BranchProbabilityInfo::print(raw_ostream &OS) const {
1047   OS << "---- Branch Probabilities ----\n";
1048   // We print the probabilities from the last function the analysis ran over,
1049   // or the function it is currently running over.
1050   assert(LastF && "Cannot print prior to running over a function");
1051   for (const auto &BI : *LastF) {
1052     for (const_succ_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
1053          ++SI) {
1054       printEdgeProbability(OS << "  ", &BI, *SI);
1055     }
1056   }
1057 }
1058 
1059 bool BranchProbabilityInfo::
1060 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
1061   // Hot probability is at least 4/5 = 80%
1062   // FIXME: Compare against a static "hot" BranchProbability.
1063   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
1064 }
1065 
1066 const BasicBlock *
1067 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
1068   auto MaxProb = BranchProbability::getZero();
1069   const BasicBlock *MaxSucc = nullptr;
1070 
1071   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
1072     const BasicBlock *Succ = *I;
1073     auto Prob = getEdgeProbability(BB, Succ);
1074     if (Prob > MaxProb) {
1075       MaxProb = Prob;
1076       MaxSucc = Succ;
1077     }
1078   }
1079 
1080   // Hot probability is at least 4/5 = 80%
1081   if (MaxProb > BranchProbability(4, 5))
1082     return MaxSucc;
1083 
1084   return nullptr;
1085 }
1086 
1087 /// Get the raw edge probability for the edge. If can't find it, return a
1088 /// default probability 1/N where N is the number of successors. Here an edge is
1089 /// specified using PredBlock and an
1090 /// index to the successors.
1091 BranchProbability
1092 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1093                                           unsigned IndexInSuccessors) const {
1094   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
1095 
1096   if (I != Probs.end())
1097     return I->second;
1098 
1099   return {1, static_cast<uint32_t>(succ_size(Src))};
1100 }
1101 
1102 BranchProbability
1103 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1104                                           const_succ_iterator Dst) const {
1105   return getEdgeProbability(Src, Dst.getSuccessorIndex());
1106 }
1107 
1108 /// Get the raw edge probability calculated for the block pair. This returns the
1109 /// sum of all raw edge probabilities from Src to Dst.
1110 BranchProbability
1111 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1112                                           const BasicBlock *Dst) const {
1113   auto Prob = BranchProbability::getZero();
1114   bool FoundProb = false;
1115   uint32_t EdgeCount = 0;
1116   for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
1117     if (*I == Dst) {
1118       ++EdgeCount;
1119       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
1120       if (MapI != Probs.end()) {
1121         FoundProb = true;
1122         Prob += MapI->second;
1123       }
1124     }
1125   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
1126   return FoundProb ? Prob : BranchProbability(EdgeCount, succ_num);
1127 }
1128 
1129 /// Set the edge probability for a given edge specified by PredBlock and an
1130 /// index to the successors.
1131 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
1132                                                unsigned IndexInSuccessors,
1133                                                BranchProbability Prob) {
1134   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
1135   Handles.insert(BasicBlockCallbackVH(Src, this));
1136   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
1137                     << IndexInSuccessors << " successor probability to " << Prob
1138                     << "\n");
1139 }
1140 
1141 /// Set the edge probability for all edges at once.
1142 void BranchProbabilityInfo::setEdgeProbability(
1143     const BasicBlock *Src, const SmallVectorImpl<BranchProbability> &Probs) {
1144   assert(Src->getTerminator()->getNumSuccessors() == Probs.size());
1145   if (Probs.size() == 0)
1146     return; // Nothing to set.
1147 
1148   uint64_t TotalNumerator = 0;
1149   for (unsigned SuccIdx = 0; SuccIdx < Probs.size(); ++SuccIdx) {
1150     setEdgeProbability(Src, SuccIdx, Probs[SuccIdx]);
1151     TotalNumerator += Probs[SuccIdx].getNumerator();
1152   }
1153 
1154   // Because of rounding errors the total probability cannot be checked to be
1155   // 1.0 exactly. That is TotalNumerator == BranchProbability::getDenominator.
1156   // Instead, every single probability in Probs must be as accurate as possible.
1157   // This results in error 1/denominator at most, thus the total absolute error
1158   // should be within Probs.size / BranchProbability::getDenominator.
1159   assert(TotalNumerator <= BranchProbability::getDenominator() + Probs.size());
1160   assert(TotalNumerator >= BranchProbability::getDenominator() - Probs.size());
1161 }
1162 
1163 raw_ostream &
1164 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
1165                                             const BasicBlock *Src,
1166                                             const BasicBlock *Dst) const {
1167   const BranchProbability Prob = getEdgeProbability(Src, Dst);
1168   OS << "edge " << Src->getName() << " -> " << Dst->getName()
1169      << " probability is " << Prob
1170      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
1171 
1172   return OS;
1173 }
1174 
1175 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
1176   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
1177     auto MapI = Probs.find(std::make_pair(BB, I.getSuccessorIndex()));
1178     if (MapI != Probs.end())
1179       Probs.erase(MapI);
1180   }
1181 }
1182 
1183 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
1184                                       const TargetLibraryInfo *TLI,
1185                                       PostDominatorTree *PDT) {
1186   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
1187                     << " ----\n\n");
1188   LastF = &F; // Store the last function we ran on for printing.
1189   assert(PostDominatedByUnreachable.empty());
1190   assert(PostDominatedByColdCall.empty());
1191 
1192   SccI = std::make_unique<SccInfo>(F);
1193 
1194   std::unique_ptr<PostDominatorTree> PDTPtr;
1195 
1196   if (!PDT) {
1197     PDTPtr = std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
1198     PDT = PDTPtr.get();
1199   }
1200 
1201   computePostDominatedByUnreachable(F, PDT);
1202   computePostDominatedByColdCall(F, PDT);
1203 
1204   // Walk the basic blocks in post-order so that we can build up state about
1205   // the successors of a block iteratively.
1206   for (auto BB : post_order(&F.getEntryBlock())) {
1207     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
1208                       << "\n");
1209     // If there is no at least two successors, no sense to set probability.
1210     if (BB->getTerminator()->getNumSuccessors() < 2)
1211       continue;
1212     if (calcMetadataWeights(BB))
1213       continue;
1214     if (calcInvokeHeuristics(BB))
1215       continue;
1216     if (calcUnreachableHeuristics(BB))
1217       continue;
1218     if (calcColdCallHeuristics(BB))
1219       continue;
1220     if (calcLoopBranchHeuristics(BB, LI))
1221       continue;
1222     if (calcPointerHeuristics(BB))
1223       continue;
1224     if (calcZeroHeuristics(BB, TLI))
1225       continue;
1226     if (calcFloatingPointHeuristics(BB))
1227       continue;
1228   }
1229 
1230   PostDominatedByUnreachable.clear();
1231   PostDominatedByColdCall.clear();
1232   SccI.reset();
1233 
1234   if (PrintBranchProb &&
1235       (PrintBranchProbFuncName.empty() ||
1236        F.getName().equals(PrintBranchProbFuncName))) {
1237     print(dbgs());
1238   }
1239 }
1240 
/// Declare the legacy-pass-manager analyses this wrapper depends on. The
/// wrapper also declares, via setPreservesAll(), that it preserves every
/// analysis.
void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // We require DT so it's available when LI is available. The LI updating code
  // asserts that DT is also present so if we don't make sure that we have DT
  // here, that assert will trigger.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<PostDominatorTreeWrapperPass>();
  AU.setPreservesAll();
}
1252 
1253 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1254   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1255   const TargetLibraryInfo &TLI =
1256       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1257   PostDominatorTree &PDT =
1258       getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
1259   BPI.calculate(F, LI, &TLI, &PDT);
1260   return false;
1261 }
1262 
1263 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1264 
1265 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1266                                              const Module *) const {
1267   BPI.print(OS);
1268 }
1269 
1270 AnalysisKey BranchProbabilityAnalysis::Key;
1271 BranchProbabilityInfo
1272 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1273   BranchProbabilityInfo BPI;
1274   BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
1275                 &AM.getResult<TargetLibraryAnalysis>(F),
1276                 &AM.getResult<PostDominatorTreeAnalysis>(F));
1277   return BPI;
1278 }
1279 
1280 PreservedAnalyses
1281 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1282   OS << "Printing analysis results of BPI for function "
1283      << "'" << F.getName() << "':"
1284      << "\n";
1285   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1286   return PreservedAnalyses::all();
1287 }
1288