1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Loops should be simplified before this analysis.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/BranchProbabilityInfo.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/CFG.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/InstrTypes.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/LLVMContext.h"
29 #include "llvm/IR/Metadata.h"
30 #include "llvm/IR/PassManager.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/BranchProbability.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cassert>
39 #include <cstdint>
40 #include <iterator>
41 #include <utility>
42 
43 using namespace llvm;
44 
45 #define DEBUG_TYPE "branch-prob"
46 
47 INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
48                       "Branch Probability Analysis", false, true)
49 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
50 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
51 INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
52                     "Branch Probability Analysis", false, true)
53 
54 char BranchProbabilityInfoWrapperPass::ID = 0;
55 
56 // Weights are for internal use only. They are used by heuristics to help to
57 // estimate edges' probability. Example:
58 //
59 // Using "Loop Branch Heuristics" we predict weights of edges for the
60 // block BB2.
61 //         ...
62 //          |
63 //          V
64 //         BB1<-+
65 //          |   |
66 //          |   | (Weight = 124)
67 //          V   |
68 //         BB2--+
69 //          |
70 //          | (Weight = 4)
71 //          V
72 //         BB3
73 //
74 // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
75 // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
76 static const uint32_t LBH_TAKEN_WEIGHT = 124;
77 static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
78 
79 /// \brief Unreachable-terminating branch taken probability.
80 ///
81 /// This is the probability for a branch being taken to a block that terminates
82 /// (eventually) in unreachable. These are predicted as unlikely as possible.
83 /// All reachable probability will equally share the remaining part.
84 static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
85 
86 /// \brief Weight for a branch taken going into a cold block.
87 ///
88 /// This is the weight for a branch taken toward a block marked
89 /// cold.  A block is marked cold if it's postdominated by a
90 /// block containing a call to a cold function.  Cold functions
91 /// are those marked with attribute 'cold'.
92 static const uint32_t CC_TAKEN_WEIGHT = 4;
93 
94 /// \brief Weight for a branch not-taken into a cold block.
95 ///
96 /// This is the weight for a branch not taken toward a block marked
97 /// cold.
98 static const uint32_t CC_NONTAKEN_WEIGHT = 64;
99 
100 static const uint32_t PH_TAKEN_WEIGHT = 20;
101 static const uint32_t PH_NONTAKEN_WEIGHT = 12;
102 
103 static const uint32_t ZH_TAKEN_WEIGHT = 20;
104 static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
105 
106 static const uint32_t FPH_TAKEN_WEIGHT = 20;
107 static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
108 
109 /// \brief Invoke-terminating normal branch taken weight
110 ///
111 /// This is the weight for branching to the normal destination of an invoke
112 /// instruction. We expect this to happen most of the time. Set the weight to an
113 /// absurdly high value so that nested loops subsume it.
114 static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;
115 
116 /// \brief Invoke-terminating normal branch not-taken weight.
117 ///
118 /// This is the weight for branching to the unwind destination of an invoke
119 /// instruction. This is essentially never taken.
120 static const uint32_t IH_NONTAKEN_WEIGHT = 1;
121 
122 /// \brief Add \p BB to PostDominatedByUnreachable set if applicable.
123 void
124 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
125   const TerminatorInst *TI = BB->getTerminator();
126   if (TI->getNumSuccessors() == 0) {
127     if (isa<UnreachableInst>(TI) ||
128         // If this block is terminated by a call to
129         // @llvm.experimental.deoptimize then treat it like an unreachable since
130         // the @llvm.experimental.deoptimize call is expected to practically
131         // never execute.
132         BB->getTerminatingDeoptimizeCall())
133       PostDominatedByUnreachable.insert(BB);
134     return;
135   }
136 
137   // If the terminator is an InvokeInst, check only the normal destination block
138   // as the unwind edge of InvokeInst is also very unlikely taken.
139   if (auto *II = dyn_cast<InvokeInst>(TI)) {
140     if (PostDominatedByUnreachable.count(II->getNormalDest()))
141       PostDominatedByUnreachable.insert(BB);
142     return;
143   }
144 
145   for (auto *I : successors(BB))
146     // If any of successor is not post dominated then BB is also not.
147     if (!PostDominatedByUnreachable.count(I))
148       return;
149 
150   PostDominatedByUnreachable.insert(BB);
151 }
152 
153 /// \brief Add \p BB to PostDominatedByColdCall set if applicable.
154 void
155 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
156   assert(!PostDominatedByColdCall.count(BB));
157   const TerminatorInst *TI = BB->getTerminator();
158   if (TI->getNumSuccessors() == 0)
159     return;
160 
161   // If all of successor are post dominated then BB is also done.
162   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
163         return PostDominatedByColdCall.count(SuccBB);
164       })) {
165     PostDominatedByColdCall.insert(BB);
166     return;
167   }
168 
169   // If the terminator is an InvokeInst, check only the normal destination
170   // block as the unwind edge of InvokeInst is also very unlikely taken.
171   if (auto *II = dyn_cast<InvokeInst>(TI))
172     if (PostDominatedByColdCall.count(II->getNormalDest())) {
173       PostDominatedByColdCall.insert(BB);
174       return;
175     }
176 
177   // Otherwise, if the block itself contains a cold function, add it to the
178   // set of blocks post-dominated by a cold call.
179   for (auto &I : *BB)
180     if (const CallInst *CI = dyn_cast<CallInst>(&I))
181       if (CI->hasFnAttr(Attribute::Cold)) {
182         PostDominatedByColdCall.insert(BB);
183         return;
184       }
185 }
186 
187 /// \brief Calculate edge weights for successors lead to unreachable.
188 ///
189 /// Predict that a successor which leads necessarily to an
190 /// unreachable-terminated block as extremely unlikely.
191 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
192   const TerminatorInst *TI = BB->getTerminator();
193   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
194 
195   // Return false here so that edge weights for InvokeInst could be decided
196   // in calcInvokeHeuristics().
197   if (isa<InvokeInst>(TI))
198     return false;
199 
200   SmallVector<unsigned, 4> UnreachableEdges;
201   SmallVector<unsigned, 4> ReachableEdges;
202 
203   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
204     if (PostDominatedByUnreachable.count(*I))
205       UnreachableEdges.push_back(I.getSuccessorIndex());
206     else
207       ReachableEdges.push_back(I.getSuccessorIndex());
208 
209   // Skip probabilities if all were reachable.
210   if (UnreachableEdges.empty())
211     return false;
212 
213   if (ReachableEdges.empty()) {
214     BranchProbability Prob(1, UnreachableEdges.size());
215     for (unsigned SuccIdx : UnreachableEdges)
216       setEdgeProbability(BB, SuccIdx, Prob);
217     return true;
218   }
219 
220   auto UnreachableProb = UR_TAKEN_PROB;
221   auto ReachableProb =
222       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
223       ReachableEdges.size();
224 
225   for (unsigned SuccIdx : UnreachableEdges)
226     setEdgeProbability(BB, SuccIdx, UnreachableProb);
227   for (unsigned SuccIdx : ReachableEdges)
228     setEdgeProbability(BB, SuccIdx, ReachableProb);
229 
230   return true;
231 }
232 
233 // Propagate existing explicit probabilities from either profile data or
234 // 'expect' intrinsic processing. Examine metadata against unreachable
235 // heuristic. The probability of the edge coming to unreachable block is
236 // set to min of metadata and unreachable heuristic.
237 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
238   const TerminatorInst *TI = BB->getTerminator();
239   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
240   if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI))
241     return false;
242 
243   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
244   if (!WeightsNode)
245     return false;
246 
247   // Check that the number of successors is manageable.
248   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
249 
250   // Ensure there are weights for all of the successors. Note that the first
251   // operand to the metadata node is a name, not a weight.
252   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
253     return false;
254 
255   // Build up the final weights that will be used in a temporary buffer.
256   // Compute the sum of all weights to later decide whether they need to
257   // be scaled to fit in 32 bits.
258   uint64_t WeightSum = 0;
259   SmallVector<uint32_t, 2> Weights;
260   SmallVector<unsigned, 2> UnreachableIdxs;
261   SmallVector<unsigned, 2> ReachableIdxs;
262   Weights.reserve(TI->getNumSuccessors());
263   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
264     ConstantInt *Weight =
265         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
266     if (!Weight)
267       return false;
268     assert(Weight->getValue().getActiveBits() <= 32 &&
269            "Too many bits for uint32_t");
270     Weights.push_back(Weight->getZExtValue());
271     WeightSum += Weights.back();
272     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
273       UnreachableIdxs.push_back(i - 1);
274     else
275       ReachableIdxs.push_back(i - 1);
276   }
277   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
278 
279   // If the sum of weights does not fit in 32 bits, scale every weight down
280   // accordingly.
281   uint64_t ScalingFactor =
282       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
283 
284   if (ScalingFactor > 1) {
285     WeightSum = 0;
286     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
287       Weights[i] /= ScalingFactor;
288       WeightSum += Weights[i];
289     }
290   }
291   assert(WeightSum <= UINT32_MAX &&
292          "Expected weights to scale down to 32 bits");
293 
294   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
295     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
296       Weights[i] = 1;
297     WeightSum = TI->getNumSuccessors();
298   }
299 
300   // Set the probability.
301   SmallVector<BranchProbability, 2> BP;
302   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
303     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
304 
305   // Examine the metadata against unreachable heuristic.
306   // If the unreachable heuristic is more strong then we use it for this edge.
307   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
308     auto ToDistribute = BranchProbability::getZero();
309     auto UnreachableProb = UR_TAKEN_PROB;
310     for (auto i : UnreachableIdxs)
311       if (UnreachableProb < BP[i]) {
312         ToDistribute += BP[i] - UnreachableProb;
313         BP[i] = UnreachableProb;
314       }
315 
316     // If we modified the probability of some edges then we must distribute
317     // the difference between reachable blocks.
318     if (ToDistribute > BranchProbability::getZero()) {
319       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
320       for (auto i : ReachableIdxs)
321         BP[i] += PerEdge;
322     }
323   }
324 
325   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
326     setEdgeProbability(BB, i, BP[i]);
327 
328   return true;
329 }
330 
331 /// \brief Calculate edge weights for edges leading to cold blocks.
332 ///
333 /// A cold block is one post-dominated by  a block with a call to a
334 /// cold function.  Those edges are unlikely to be taken, so we give
335 /// them relatively low weight.
336 ///
337 /// Return true if we could compute the weights for cold edges.
338 /// Return false, otherwise.
339 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
340   const TerminatorInst *TI = BB->getTerminator();
341   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
342 
343   // Return false here so that edge weights for InvokeInst could be decided
344   // in calcInvokeHeuristics().
345   if (isa<InvokeInst>(TI))
346     return false;
347 
348   // Determine which successors are post-dominated by a cold block.
349   SmallVector<unsigned, 4> ColdEdges;
350   SmallVector<unsigned, 4> NormalEdges;
351   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
352     if (PostDominatedByColdCall.count(*I))
353       ColdEdges.push_back(I.getSuccessorIndex());
354     else
355       NormalEdges.push_back(I.getSuccessorIndex());
356 
357   // Skip probabilities if no cold edges.
358   if (ColdEdges.empty())
359     return false;
360 
361   if (NormalEdges.empty()) {
362     BranchProbability Prob(1, ColdEdges.size());
363     for (unsigned SuccIdx : ColdEdges)
364       setEdgeProbability(BB, SuccIdx, Prob);
365     return true;
366   }
367 
368   auto ColdProb = BranchProbability::getBranchProbability(
369       CC_TAKEN_WEIGHT,
370       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
371   auto NormalProb = BranchProbability::getBranchProbability(
372       CC_NONTAKEN_WEIGHT,
373       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
374 
375   for (unsigned SuccIdx : ColdEdges)
376     setEdgeProbability(BB, SuccIdx, ColdProb);
377   for (unsigned SuccIdx : NormalEdges)
378     setEdgeProbability(BB, SuccIdx, NormalProb);
379 
380   return true;
381 }
382 
383 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparsion
384 // between two pointer or pointer and NULL will fail.
385 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
386   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
387   if (!BI || !BI->isConditional())
388     return false;
389 
390   Value *Cond = BI->getCondition();
391   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
392   if (!CI || !CI->isEquality())
393     return false;
394 
395   Value *LHS = CI->getOperand(0);
396 
397   if (!LHS->getType()->isPointerTy())
398     return false;
399 
400   assert(CI->getOperand(1)->getType()->isPointerTy());
401 
402   // p != 0   ->   isProb = true
403   // p == 0   ->   isProb = false
404   // p != q   ->   isProb = true
405   // p == q   ->   isProb = false;
406   unsigned TakenIdx = 0, NonTakenIdx = 1;
407   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
408   if (!isProb)
409     std::swap(TakenIdx, NonTakenIdx);
410 
411   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
412                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
413   setEdgeProbability(BB, TakenIdx, TakenProb);
414   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
415   return true;
416 }
417 
418 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
419 // as taken, exiting edges as not-taken.
420 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
421                                                      const LoopInfo &LI) {
422   Loop *L = LI.getLoopFor(BB);
423   if (!L)
424     return false;
425 
426   SmallVector<unsigned, 8> BackEdges;
427   SmallVector<unsigned, 8> ExitingEdges;
428   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
429 
430   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
431     if (!L->contains(*I))
432       ExitingEdges.push_back(I.getSuccessorIndex());
433     else if (L->getHeader() == *I)
434       BackEdges.push_back(I.getSuccessorIndex());
435     else
436       InEdges.push_back(I.getSuccessorIndex());
437   }
438 
439   if (BackEdges.empty() && ExitingEdges.empty())
440     return false;
441 
442   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
443   // normalize them so that they sum up to one.
444   BranchProbability Probs[] = {BranchProbability::getZero(),
445                                BranchProbability::getZero(),
446                                BranchProbability::getZero()};
447   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
448                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
449                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
450   if (!BackEdges.empty())
451     Probs[0] = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
452   if (!InEdges.empty())
453     Probs[1] = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
454   if (!ExitingEdges.empty())
455     Probs[2] = BranchProbability(LBH_NONTAKEN_WEIGHT, Denom);
456 
457   if (uint32_t numBackEdges = BackEdges.size()) {
458     auto Prob = Probs[0] / numBackEdges;
459     for (unsigned SuccIdx : BackEdges)
460       setEdgeProbability(BB, SuccIdx, Prob);
461   }
462 
463   if (uint32_t numInEdges = InEdges.size()) {
464     auto Prob = Probs[1] / numInEdges;
465     for (unsigned SuccIdx : InEdges)
466       setEdgeProbability(BB, SuccIdx, Prob);
467   }
468 
469   if (uint32_t numExitingEdges = ExitingEdges.size()) {
470     auto Prob = Probs[2] / numExitingEdges;
471     for (unsigned SuccIdx : ExitingEdges)
472       setEdgeProbability(BB, SuccIdx, Prob);
473   }
474 
475   return true;
476 }
477 
478 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
479                                                const TargetLibraryInfo *TLI) {
480   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
481   if (!BI || !BI->isConditional())
482     return false;
483 
484   Value *Cond = BI->getCondition();
485   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
486   if (!CI)
487     return false;
488 
489   Value *RHS = CI->getOperand(1);
490   ConstantInt *CV = dyn_cast<ConstantInt>(RHS);
491   if (!CV)
492     return false;
493 
494   // If the LHS is the result of AND'ing a value with a single bit bitmask,
495   // we don't have information about probabilities.
496   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
497     if (LHS->getOpcode() == Instruction::And)
498       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
499         if (AndRHS->getUniqueInteger().isPowerOf2())
500           return false;
501 
502   // Check if the LHS is the return value of a library function
503   LibFunc Func = NumLibFuncs;
504   if (TLI)
505     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
506       if (Function *CalledFn = Call->getCalledFunction())
507         TLI->getLibFunc(*CalledFn, Func);
508 
509   bool isProb;
510   if (Func == LibFunc_strcasecmp ||
511       Func == LibFunc_strcmp ||
512       Func == LibFunc_strncasecmp ||
513       Func == LibFunc_strncmp ||
514       Func == LibFunc_memcmp) {
515     // strcmp and similar functions return zero, negative, or positive, if the
516     // first string is equal, less, or greater than the second. We consider it
517     // likely that the strings are not equal, so a comparison with zero is
518     // probably false, but also a comparison with any other number is also
519     // probably false given that what exactly is returned for nonzero values is
520     // not specified. Any kind of comparison other than equality we know
521     // nothing about.
522     switch (CI->getPredicate()) {
523     case CmpInst::ICMP_EQ:
524       isProb = false;
525       break;
526     case CmpInst::ICMP_NE:
527       isProb = true;
528       break;
529     default:
530       return false;
531     }
532   } else if (CV->isZero()) {
533     switch (CI->getPredicate()) {
534     case CmpInst::ICMP_EQ:
535       // X == 0   ->  Unlikely
536       isProb = false;
537       break;
538     case CmpInst::ICMP_NE:
539       // X != 0   ->  Likely
540       isProb = true;
541       break;
542     case CmpInst::ICMP_SLT:
543       // X < 0   ->  Unlikely
544       isProb = false;
545       break;
546     case CmpInst::ICMP_SGT:
547       // X > 0   ->  Likely
548       isProb = true;
549       break;
550     default:
551       return false;
552     }
553   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
554     // InstCombine canonicalizes X <= 0 into X < 1.
555     // X <= 0   ->  Unlikely
556     isProb = false;
557   } else if (CV->isMinusOne()) {
558     switch (CI->getPredicate()) {
559     case CmpInst::ICMP_EQ:
560       // X == -1  ->  Unlikely
561       isProb = false;
562       break;
563     case CmpInst::ICMP_NE:
564       // X != -1  ->  Likely
565       isProb = true;
566       break;
567     case CmpInst::ICMP_SGT:
568       // InstCombine canonicalizes X >= 0 into X > -1.
569       // X >= 0   ->  Likely
570       isProb = true;
571       break;
572     default:
573       return false;
574     }
575   } else {
576     return false;
577   }
578 
579   unsigned TakenIdx = 0, NonTakenIdx = 1;
580 
581   if (!isProb)
582     std::swap(TakenIdx, NonTakenIdx);
583 
584   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
585                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
586   setEdgeProbability(BB, TakenIdx, TakenProb);
587   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
588   return true;
589 }
590 
591 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
592   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
593   if (!BI || !BI->isConditional())
594     return false;
595 
596   Value *Cond = BI->getCondition();
597   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
598   if (!FCmp)
599     return false;
600 
601   bool isProb;
602   if (FCmp->isEquality()) {
603     // f1 == f2 -> Unlikely
604     // f1 != f2 -> Likely
605     isProb = !FCmp->isTrueWhenEqual();
606   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
607     // !isnan -> Likely
608     isProb = true;
609   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
610     // isnan -> Unlikely
611     isProb = false;
612   } else {
613     return false;
614   }
615 
616   unsigned TakenIdx = 0, NonTakenIdx = 1;
617 
618   if (!isProb)
619     std::swap(TakenIdx, NonTakenIdx);
620 
621   BranchProbability TakenProb(FPH_TAKEN_WEIGHT,
622                               FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
623   setEdgeProbability(BB, TakenIdx, TakenProb);
624   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
625   return true;
626 }
627 
628 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
629   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
630   if (!II)
631     return false;
632 
633   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
634                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
635   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
636   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
637   return true;
638 }
639 
640 void BranchProbabilityInfo::releaseMemory() {
641   Probs.clear();
642 }
643 
644 void BranchProbabilityInfo::print(raw_ostream &OS) const {
645   OS << "---- Branch Probabilities ----\n";
646   // We print the probabilities from the last function the analysis ran over,
647   // or the function it is currently running over.
648   assert(LastF && "Cannot print prior to running over a function");
649   for (const auto &BI : *LastF) {
650     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
651          ++SI) {
652       printEdgeProbability(OS << "  ", &BI, *SI);
653     }
654   }
655 }
656 
657 bool BranchProbabilityInfo::
658 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
659   // Hot probability is at least 4/5 = 80%
660   // FIXME: Compare against a static "hot" BranchProbability.
661   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
662 }
663 
664 const BasicBlock *
665 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
666   auto MaxProb = BranchProbability::getZero();
667   const BasicBlock *MaxSucc = nullptr;
668 
669   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
670     const BasicBlock *Succ = *I;
671     auto Prob = getEdgeProbability(BB, Succ);
672     if (Prob > MaxProb) {
673       MaxProb = Prob;
674       MaxSucc = Succ;
675     }
676   }
677 
678   // Hot probability is at least 4/5 = 80%
679   if (MaxProb > BranchProbability(4, 5))
680     return MaxSucc;
681 
682   return nullptr;
683 }
684 
685 /// Get the raw edge probability for the edge. If can't find it, return a
686 /// default probability 1/N where N is the number of successors. Here an edge is
687 /// specified using PredBlock and an
688 /// index to the successors.
689 BranchProbability
690 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
691                                           unsigned IndexInSuccessors) const {
692   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
693 
694   if (I != Probs.end())
695     return I->second;
696 
697   return {1,
698           static_cast<uint32_t>(std::distance(succ_begin(Src), succ_end(Src)))};
699 }
700 
701 BranchProbability
702 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
703                                           succ_const_iterator Dst) const {
704   return getEdgeProbability(Src, Dst.getSuccessorIndex());
705 }
706 
707 /// Get the raw edge probability calculated for the block pair. This returns the
708 /// sum of all raw edge probabilities from Src to Dst.
709 BranchProbability
710 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
711                                           const BasicBlock *Dst) const {
712   auto Prob = BranchProbability::getZero();
713   bool FoundProb = false;
714   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
715     if (*I == Dst) {
716       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
717       if (MapI != Probs.end()) {
718         FoundProb = true;
719         Prob += MapI->second;
720       }
721     }
722   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
723   return FoundProb ? Prob : BranchProbability(1, succ_num);
724 }
725 
726 /// Set the edge probability for a given edge specified by PredBlock and an
727 /// index to the successors.
728 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
729                                                unsigned IndexInSuccessors,
730                                                BranchProbability Prob) {
731   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
732   Handles.insert(BasicBlockCallbackVH(Src, this));
733   DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << IndexInSuccessors
734                << " successor probability to " << Prob << "\n");
735 }
736 
737 raw_ostream &
738 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
739                                             const BasicBlock *Src,
740                                             const BasicBlock *Dst) const {
741   const BranchProbability Prob = getEdgeProbability(Src, Dst);
742   OS << "edge " << Src->getName() << " -> " << Dst->getName()
743      << " probability is " << Prob
744      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
745 
746   return OS;
747 }
748 
749 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
750   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
751     auto Key = I->first;
752     if (Key.first == BB)
753       Probs.erase(Key);
754   }
755 }
756 
757 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
758                                       const TargetLibraryInfo *TLI) {
759   DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
760                << " ----\n\n");
761   LastF = &F; // Store the last function we ran on for printing.
762   assert(PostDominatedByUnreachable.empty());
763   assert(PostDominatedByColdCall.empty());
764 
765   // Walk the basic blocks in post-order so that we can build up state about
766   // the successors of a block iteratively.
767   for (auto BB : post_order(&F.getEntryBlock())) {
768     DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
769     updatePostDominatedByUnreachable(BB);
770     updatePostDominatedByColdCall(BB);
771     // If there is no at least two successors, no sense to set probability.
772     if (BB->getTerminator()->getNumSuccessors() < 2)
773       continue;
774     if (calcMetadataWeights(BB))
775       continue;
776     if (calcUnreachableHeuristics(BB))
777       continue;
778     if (calcColdCallHeuristics(BB))
779       continue;
780     if (calcLoopBranchHeuristics(BB, LI))
781       continue;
782     if (calcPointerHeuristics(BB))
783       continue;
784     if (calcZeroHeuristics(BB, TLI))
785       continue;
786     if (calcFloatingPointHeuristics(BB))
787       continue;
788     calcInvokeHeuristics(BB);
789   }
790 
791   PostDominatedByUnreachable.clear();
792   PostDominatedByColdCall.clear();
793 }
794 
795 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
796     AnalysisUsage &AU) const {
797   AU.addRequired<LoopInfoWrapperPass>();
798   AU.addRequired<TargetLibraryInfoWrapperPass>();
799   AU.setPreservesAll();
800 }
801 
802 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
803   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
804   const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
805   BPI.calculate(F, LI, &TLI);
806   return false;
807 }
808 
809 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
810 
811 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
812                                              const Module *) const {
813   BPI.print(OS);
814 }
815 
816 AnalysisKey BranchProbabilityAnalysis::Key;
817 BranchProbabilityInfo
818 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
819   BranchProbabilityInfo BPI;
820   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
821   return BPI;
822 }
823 
824 PreservedAnalyses
825 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
826   OS << "Printing analysis results of BPI for function "
827      << "'" << F.getName() << "':"
828      << "\n";
829   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
830   return PreservedAnalyses::all();
831 }
832