1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Loops should be simplified before this analysis.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/BranchProbabilityInfo.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/TargetLibraryInfo.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InstrTypes.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/BranchProbability.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <cassert>
41 #include <cstdint>
42 #include <iterator>
43 #include <utility>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "branch-prob"
48 
// Hidden boolean command-line option "print-bpi", off by default. Per its
// description it requests printing of the branch probability info; the code
// that consults it is not in this part of the file.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// Hidden string command-line option "print-bpi-func-name". Per its
// description it names the function whose branch probability info is printed.
// Note that, unlike PrintBranchProb above, this object is not declared static.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));
57 
// Register BranchProbabilityInfoWrapperPass under the argument name
// "branch-prob" with the description "Branch Probability Analysis", and list
// the passes it is initialized after: LoopInfoWrapperPass and
// TargetLibraryInfoWrapperPass.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)

// Definition of the wrapper pass's static ID member.
char BranchProbabilityInfoWrapperPass::ID = 0;
66 
// Weights are for internal use only. They are used by heuristics to help to
// estimate edges' probability. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
static const uint32_t LBH_TAKEN_WEIGHT = 124;
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
// (62 is half of LBH_TAKEN_WEIGHT).
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible.
/// All reachable probability will equally share the remaining part.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold.  A block is marked cold if it's postdominated by a
/// block containing a call to a cold function.  Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer heuristic weights, used by calcPointerHeuristics(): the edge
/// predicted to be taken gets PH_TAKEN_WEIGHT / (PH_TAKEN_WEIGHT +
/// PH_NONTAKEN_WEIGHT) and the other edge gets the complement.
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

/// Zero heuristic weights, used by calcZeroHeuristics() in the same
/// taken / (taken + non-taken) fashion as the pointer heuristic weights.
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// Floating point heuristic weights, used by calcFloatingPointHeuristics() in
/// the same taken / (taken + non-taken) fashion as the pointer heuristic
/// weights.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
134 
135 /// Add \p BB to PostDominatedByUnreachable set if applicable.
136 void
137 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
138   const TerminatorInst *TI = BB->getTerminator();
139   if (TI->getNumSuccessors() == 0) {
140     if (isa<UnreachableInst>(TI) ||
141         // If this block is terminated by a call to
142         // @llvm.experimental.deoptimize then treat it like an unreachable since
143         // the @llvm.experimental.deoptimize call is expected to practically
144         // never execute.
145         BB->getTerminatingDeoptimizeCall())
146       PostDominatedByUnreachable.insert(BB);
147     return;
148   }
149 
150   // If the terminator is an InvokeInst, check only the normal destination block
151   // as the unwind edge of InvokeInst is also very unlikely taken.
152   if (auto *II = dyn_cast<InvokeInst>(TI)) {
153     if (PostDominatedByUnreachable.count(II->getNormalDest()))
154       PostDominatedByUnreachable.insert(BB);
155     return;
156   }
157 
158   for (auto *I : successors(BB))
159     // If any of successor is not post dominated then BB is also not.
160     if (!PostDominatedByUnreachable.count(I))
161       return;
162 
163   PostDominatedByUnreachable.insert(BB);
164 }
165 
166 /// Add \p BB to PostDominatedByColdCall set if applicable.
167 void
168 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
169   assert(!PostDominatedByColdCall.count(BB));
170   const TerminatorInst *TI = BB->getTerminator();
171   if (TI->getNumSuccessors() == 0)
172     return;
173 
174   // If all of successor are post dominated then BB is also done.
175   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
176         return PostDominatedByColdCall.count(SuccBB);
177       })) {
178     PostDominatedByColdCall.insert(BB);
179     return;
180   }
181 
182   // If the terminator is an InvokeInst, check only the normal destination
183   // block as the unwind edge of InvokeInst is also very unlikely taken.
184   if (auto *II = dyn_cast<InvokeInst>(TI))
185     if (PostDominatedByColdCall.count(II->getNormalDest())) {
186       PostDominatedByColdCall.insert(BB);
187       return;
188     }
189 
190   // Otherwise, if the block itself contains a cold function, add it to the
191   // set of blocks post-dominated by a cold call.
192   for (auto &I : *BB)
193     if (const CallInst *CI = dyn_cast<CallInst>(&I))
194       if (CI->hasFnAttr(Attribute::Cold)) {
195         PostDominatedByColdCall.insert(BB);
196         return;
197       }
198 }
199 
200 /// Calculate edge weights for successors lead to unreachable.
201 ///
202 /// Predict that a successor which leads necessarily to an
203 /// unreachable-terminated block as extremely unlikely.
204 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
205   const TerminatorInst *TI = BB->getTerminator();
206   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
207 
208   // Return false here so that edge weights for InvokeInst could be decided
209   // in calcInvokeHeuristics().
210   if (isa<InvokeInst>(TI))
211     return false;
212 
213   SmallVector<unsigned, 4> UnreachableEdges;
214   SmallVector<unsigned, 4> ReachableEdges;
215 
216   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
217     if (PostDominatedByUnreachable.count(*I))
218       UnreachableEdges.push_back(I.getSuccessorIndex());
219     else
220       ReachableEdges.push_back(I.getSuccessorIndex());
221 
222   // Skip probabilities if all were reachable.
223   if (UnreachableEdges.empty())
224     return false;
225 
226   if (ReachableEdges.empty()) {
227     BranchProbability Prob(1, UnreachableEdges.size());
228     for (unsigned SuccIdx : UnreachableEdges)
229       setEdgeProbability(BB, SuccIdx, Prob);
230     return true;
231   }
232 
233   auto UnreachableProb = UR_TAKEN_PROB;
234   auto ReachableProb =
235       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
236       ReachableEdges.size();
237 
238   for (unsigned SuccIdx : UnreachableEdges)
239     setEdgeProbability(BB, SuccIdx, UnreachableProb);
240   for (unsigned SuccIdx : ReachableEdges)
241     setEdgeProbability(BB, SuccIdx, ReachableProb);
242 
243   return true;
244 }
245 
246 // Propagate existing explicit probabilities from either profile data or
247 // 'expect' intrinsic processing. Examine metadata against unreachable
248 // heuristic. The probability of the edge coming to unreachable block is
249 // set to min of metadata and unreachable heuristic.
250 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
251   const TerminatorInst *TI = BB->getTerminator();
252   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
253   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
254     return false;
255 
256   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
257   if (!WeightsNode)
258     return false;
259 
260   // Check that the number of successors is manageable.
261   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
262 
263   // Ensure there are weights for all of the successors. Note that the first
264   // operand to the metadata node is a name, not a weight.
265   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
266     return false;
267 
268   // Build up the final weights that will be used in a temporary buffer.
269   // Compute the sum of all weights to later decide whether they need to
270   // be scaled to fit in 32 bits.
271   uint64_t WeightSum = 0;
272   SmallVector<uint32_t, 2> Weights;
273   SmallVector<unsigned, 2> UnreachableIdxs;
274   SmallVector<unsigned, 2> ReachableIdxs;
275   Weights.reserve(TI->getNumSuccessors());
276   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
277     ConstantInt *Weight =
278         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
279     if (!Weight)
280       return false;
281     assert(Weight->getValue().getActiveBits() <= 32 &&
282            "Too many bits for uint32_t");
283     Weights.push_back(Weight->getZExtValue());
284     WeightSum += Weights.back();
285     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
286       UnreachableIdxs.push_back(i - 1);
287     else
288       ReachableIdxs.push_back(i - 1);
289   }
290   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
291 
292   // If the sum of weights does not fit in 32 bits, scale every weight down
293   // accordingly.
294   uint64_t ScalingFactor =
295       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
296 
297   if (ScalingFactor > 1) {
298     WeightSum = 0;
299     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
300       Weights[i] /= ScalingFactor;
301       WeightSum += Weights[i];
302     }
303   }
304   assert(WeightSum <= UINT32_MAX &&
305          "Expected weights to scale down to 32 bits");
306 
307   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
308     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
309       Weights[i] = 1;
310     WeightSum = TI->getNumSuccessors();
311   }
312 
313   // Set the probability.
314   SmallVector<BranchProbability, 2> BP;
315   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
316     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
317 
318   // Examine the metadata against unreachable heuristic.
319   // If the unreachable heuristic is more strong then we use it for this edge.
320   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
321     auto ToDistribute = BranchProbability::getZero();
322     auto UnreachableProb = UR_TAKEN_PROB;
323     for (auto i : UnreachableIdxs)
324       if (UnreachableProb < BP[i]) {
325         ToDistribute += BP[i] - UnreachableProb;
326         BP[i] = UnreachableProb;
327       }
328 
329     // If we modified the probability of some edges then we must distribute
330     // the difference between reachable blocks.
331     if (ToDistribute > BranchProbability::getZero()) {
332       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
333       for (auto i : ReachableIdxs)
334         BP[i] += PerEdge;
335     }
336   }
337 
338   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
339     setEdgeProbability(BB, i, BP[i]);
340 
341   return true;
342 }
343 
344 /// Calculate edge weights for edges leading to cold blocks.
345 ///
346 /// A cold block is one post-dominated by  a block with a call to a
347 /// cold function.  Those edges are unlikely to be taken, so we give
348 /// them relatively low weight.
349 ///
350 /// Return true if we could compute the weights for cold edges.
351 /// Return false, otherwise.
352 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
353   const TerminatorInst *TI = BB->getTerminator();
354   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
355 
356   // Return false here so that edge weights for InvokeInst could be decided
357   // in calcInvokeHeuristics().
358   if (isa<InvokeInst>(TI))
359     return false;
360 
361   // Determine which successors are post-dominated by a cold block.
362   SmallVector<unsigned, 4> ColdEdges;
363   SmallVector<unsigned, 4> NormalEdges;
364   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
365     if (PostDominatedByColdCall.count(*I))
366       ColdEdges.push_back(I.getSuccessorIndex());
367     else
368       NormalEdges.push_back(I.getSuccessorIndex());
369 
370   // Skip probabilities if no cold edges.
371   if (ColdEdges.empty())
372     return false;
373 
374   if (NormalEdges.empty()) {
375     BranchProbability Prob(1, ColdEdges.size());
376     for (unsigned SuccIdx : ColdEdges)
377       setEdgeProbability(BB, SuccIdx, Prob);
378     return true;
379   }
380 
381   auto ColdProb = BranchProbability::getBranchProbability(
382       CC_TAKEN_WEIGHT,
383       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
384   auto NormalProb = BranchProbability::getBranchProbability(
385       CC_NONTAKEN_WEIGHT,
386       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
387 
388   for (unsigned SuccIdx : ColdEdges)
389     setEdgeProbability(BB, SuccIdx, ColdProb);
390   for (unsigned SuccIdx : NormalEdges)
391     setEdgeProbability(BB, SuccIdx, NormalProb);
392 
393   return true;
394 }
395 
396 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
397 // between two pointer or pointer and NULL will fail.
398 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
399   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
400   if (!BI || !BI->isConditional())
401     return false;
402 
403   Value *Cond = BI->getCondition();
404   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
405   if (!CI || !CI->isEquality())
406     return false;
407 
408   Value *LHS = CI->getOperand(0);
409 
410   if (!LHS->getType()->isPointerTy())
411     return false;
412 
413   assert(CI->getOperand(1)->getType()->isPointerTy());
414 
415   // p != 0   ->   isProb = true
416   // p == 0   ->   isProb = false
417   // p != q   ->   isProb = true
418   // p == q   ->   isProb = false;
419   unsigned TakenIdx = 0, NonTakenIdx = 1;
420   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
421   if (!isProb)
422     std::swap(TakenIdx, NonTakenIdx);
423 
424   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
425                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
426   setEdgeProbability(BB, TakenIdx, TakenProb);
427   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
428   return true;
429 }
430 
431 static int getSCCNum(const BasicBlock *BB,
432                      const BranchProbabilityInfo::SccInfo &SccI) {
433   auto SccIt = SccI.SccNums.find(BB);
434   if (SccIt == SccI.SccNums.end())
435     return -1;
436   return SccIt->second;
437 }
438 
439 // Consider any block that is an entry point to the SCC as a header.
440 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
441                         BranchProbabilityInfo::SccInfo &SccI) {
442   assert(getSCCNum(BB, SccI) == SccNum);
443 
444   // Lazily compute the set of headers for a given SCC and cache the results
445   // in the SccHeaderMap.
446   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
447     SccI.SccHeaders.resize(SccNum + 1);
448   auto &HeaderMap = SccI.SccHeaders[SccNum];
449   bool Inserted;
450   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
451   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
452   if (Inserted) {
453     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
454                                  [&](const BasicBlock *Pred) {
455                                    return getSCCNum(Pred, SccI) != SccNum;
456                                  });
457     HeaderMapIt->second = IsHeader;
458     return IsHeader;
459   } else
460     return HeaderMapIt->second;
461 }
462 
463 // Compute the unlikely successors to the block BB in the loop L, specifically
464 // those that are unlikely because this is a loop, and add them to the
465 // UnlikelyBlocks set.
466 static void
467 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
468                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
469   // Sometimes in a loop we have a branch whose condition is made false by
470   // taking it. This is typically something like
471   //  int n = 0;
472   //  while (...) {
473   //    if (++n >= MAX) {
474   //      n = 0;
475   //    }
476   //  }
477   // In this sort of situation taking the branch means that at the very least it
478   // won't be taken again in the next iteration of the loop, so we should
479   // consider it less likely than a typical branch.
480   //
481   // We detect this by looking back through the graph of PHI nodes that sets the
482   // value that the condition depends on, and seeing if we can reach a successor
483   // block which can be determined to make the condition false.
484   //
485   // FIXME: We currently consider unlikely blocks to be half as likely as other
486   // blocks, but if we consider the example above the likelyhood is actually
487   // 1/MAX. We could therefore be more precise in how unlikely we consider
488   // blocks to be, but it would require more careful examination of the form
489   // of the comparison expression.
490   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
491   if (!BI || !BI->isConditional())
492     return;
493 
494   // Check if the branch is based on an instruction compared with a constant
495   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
496   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
497       !isa<Constant>(CI->getOperand(1)))
498     return;
499 
500   // Either the instruction must be a PHI, or a chain of operations involving
501   // constants that ends in a PHI which we can then collapse into a single value
502   // if the PHI value is known.
503   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
504   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
505   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
506   // Collect the instructions until we hit a PHI
507   std::list<BinaryOperator*> InstChain;
508   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
509          isa<Constant>(CmpLHS->getOperand(1))) {
510     // Stop if the chain extends outside of the loop
511     if (!L->contains(CmpLHS))
512       return;
513     InstChain.push_front(dyn_cast<BinaryOperator>(CmpLHS));
514     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
515     if (CmpLHS)
516       CmpPHI = dyn_cast<PHINode>(CmpLHS);
517   }
518   if (!CmpPHI || !L->contains(CmpPHI))
519     return;
520 
521   // Trace the phi node to find all values that come from successors of BB
522   SmallPtrSet<PHINode*, 8> VisitedInsts;
523   SmallVector<PHINode*, 8> WorkList;
524   WorkList.push_back(CmpPHI);
525   VisitedInsts.insert(CmpPHI);
526   while (!WorkList.empty()) {
527     PHINode *P = WorkList.back();
528     WorkList.pop_back();
529     for (BasicBlock *B : P->blocks()) {
530       // Skip blocks that aren't part of the loop
531       if (!L->contains(B))
532         continue;
533       Value *V = P->getIncomingValueForBlock(B);
534       // If the source is a PHI add it to the work list if we haven't
535       // already visited it.
536       if (PHINode *PN = dyn_cast<PHINode>(V)) {
537         if (VisitedInsts.insert(PN).second)
538           WorkList.push_back(PN);
539         continue;
540       }
541       // If this incoming value is a constant and B is a successor of BB, then
542       // we can constant-evaluate the compare to see if it makes the branch be
543       // taken or not.
544       Constant *CmpLHSConst = dyn_cast<Constant>(V);
545       if (!CmpLHSConst ||
546           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
547         continue;
548       // First collapse InstChain
549       for (Instruction *I : InstChain) {
550         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
551                                         dyn_cast<Constant>(I->getOperand(1)),
552                                         true);
553         if (!CmpLHSConst)
554           break;
555       }
556       if (!CmpLHSConst)
557         continue;
558       // Now constant-evaluate the compare
559       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
560                                                   CmpLHSConst, CmpConst, true);
561       // If the result means we don't branch to the block then that block is
562       // unlikely.
563       if (Result &&
564           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
565            (Result->isOneValue() && B == BI->getSuccessor(1))))
566         UnlikelyBlocks.insert(B);
567     }
568   }
569 }
570 
571 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
572 // as taken, exiting edges as not-taken.
573 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
574                                                      const LoopInfo &LI,
575                                                      SccInfo &SccI) {
576   int SccNum;
577   Loop *L = LI.getLoopFor(BB);
578   if (!L) {
579     SccNum = getSCCNum(BB, SccI);
580     if (SccNum < 0)
581       return false;
582   }
583 
584   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
585   if (L)
586     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
587 
588   SmallVector<unsigned, 8> BackEdges;
589   SmallVector<unsigned, 8> ExitingEdges;
590   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
591   SmallVector<unsigned, 8> UnlikelyEdges;
592 
593   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
594     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
595     // irreducible loops.
596     if (L) {
597       if (UnlikelyBlocks.count(*I) != 0)
598         UnlikelyEdges.push_back(I.getSuccessorIndex());
599       else if (!L->contains(*I))
600         ExitingEdges.push_back(I.getSuccessorIndex());
601       else if (L->getHeader() == *I)
602         BackEdges.push_back(I.getSuccessorIndex());
603       else
604         InEdges.push_back(I.getSuccessorIndex());
605     } else {
606       if (getSCCNum(*I, SccI) != SccNum)
607         ExitingEdges.push_back(I.getSuccessorIndex());
608       else if (isSCCHeader(*I, SccNum, SccI))
609         BackEdges.push_back(I.getSuccessorIndex());
610       else
611         InEdges.push_back(I.getSuccessorIndex());
612     }
613   }
614 
615   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
616     return false;
617 
618   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
619   // normalize them so that they sum up to one.
620   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
621                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
622                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
623                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
624 
625   if (uint32_t numBackEdges = BackEdges.size()) {
626     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
627     auto Prob = TakenProb / numBackEdges;
628     for (unsigned SuccIdx : BackEdges)
629       setEdgeProbability(BB, SuccIdx, Prob);
630   }
631 
632   if (uint32_t numInEdges = InEdges.size()) {
633     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
634     auto Prob = TakenProb / numInEdges;
635     for (unsigned SuccIdx : InEdges)
636       setEdgeProbability(BB, SuccIdx, Prob);
637   }
638 
639   if (uint32_t numExitingEdges = ExitingEdges.size()) {
640     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
641                                                        Denom);
642     auto Prob = NotTakenProb / numExitingEdges;
643     for (unsigned SuccIdx : ExitingEdges)
644       setEdgeProbability(BB, SuccIdx, Prob);
645   }
646 
647   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
648     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
649                                                        Denom);
650     auto Prob = UnlikelyProb / numUnlikelyEdges;
651     for (unsigned SuccIdx : UnlikelyEdges)
652       setEdgeProbability(BB, SuccIdx, Prob);
653   }
654 
655   return true;
656 }
657 
658 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
659                                                const TargetLibraryInfo *TLI) {
660   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
661   if (!BI || !BI->isConditional())
662     return false;
663 
664   Value *Cond = BI->getCondition();
665   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
666   if (!CI)
667     return false;
668 
669   Value *RHS = CI->getOperand(1);
670   ConstantInt *CV = dyn_cast<ConstantInt>(RHS);
671   if (!CV)
672     return false;
673 
674   // If the LHS is the result of AND'ing a value with a single bit bitmask,
675   // we don't have information about probabilities.
676   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
677     if (LHS->getOpcode() == Instruction::And)
678       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
679         if (AndRHS->getValue().isPowerOf2())
680           return false;
681 
682   // Check if the LHS is the return value of a library function
683   LibFunc Func = NumLibFuncs;
684   if (TLI)
685     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
686       if (Function *CalledFn = Call->getCalledFunction())
687         TLI->getLibFunc(*CalledFn, Func);
688 
689   bool isProb;
690   if (Func == LibFunc_strcasecmp ||
691       Func == LibFunc_strcmp ||
692       Func == LibFunc_strncasecmp ||
693       Func == LibFunc_strncmp ||
694       Func == LibFunc_memcmp) {
695     // strcmp and similar functions return zero, negative, or positive, if the
696     // first string is equal, less, or greater than the second. We consider it
697     // likely that the strings are not equal, so a comparison with zero is
698     // probably false, but also a comparison with any other number is also
699     // probably false given that what exactly is returned for nonzero values is
700     // not specified. Any kind of comparison other than equality we know
701     // nothing about.
702     switch (CI->getPredicate()) {
703     case CmpInst::ICMP_EQ:
704       isProb = false;
705       break;
706     case CmpInst::ICMP_NE:
707       isProb = true;
708       break;
709     default:
710       return false;
711     }
712   } else if (CV->isZero()) {
713     switch (CI->getPredicate()) {
714     case CmpInst::ICMP_EQ:
715       // X == 0   ->  Unlikely
716       isProb = false;
717       break;
718     case CmpInst::ICMP_NE:
719       // X != 0   ->  Likely
720       isProb = true;
721       break;
722     case CmpInst::ICMP_SLT:
723       // X < 0   ->  Unlikely
724       isProb = false;
725       break;
726     case CmpInst::ICMP_SGT:
727       // X > 0   ->  Likely
728       isProb = true;
729       break;
730     default:
731       return false;
732     }
733   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
734     // InstCombine canonicalizes X <= 0 into X < 1.
735     // X <= 0   ->  Unlikely
736     isProb = false;
737   } else if (CV->isMinusOne()) {
738     switch (CI->getPredicate()) {
739     case CmpInst::ICMP_EQ:
740       // X == -1  ->  Unlikely
741       isProb = false;
742       break;
743     case CmpInst::ICMP_NE:
744       // X != -1  ->  Likely
745       isProb = true;
746       break;
747     case CmpInst::ICMP_SGT:
748       // InstCombine canonicalizes X >= 0 into X > -1.
749       // X >= 0   ->  Likely
750       isProb = true;
751       break;
752     default:
753       return false;
754     }
755   } else {
756     return false;
757   }
758 
759   unsigned TakenIdx = 0, NonTakenIdx = 1;
760 
761   if (!isProb)
762     std::swap(TakenIdx, NonTakenIdx);
763 
764   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
765                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
766   setEdgeProbability(BB, TakenIdx, TakenProb);
767   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
768   return true;
769 }
770 
771 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
772   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
773   if (!BI || !BI->isConditional())
774     return false;
775 
776   Value *Cond = BI->getCondition();
777   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
778   if (!FCmp)
779     return false;
780 
781   bool isProb;
782   if (FCmp->isEquality()) {
783     // f1 == f2 -> Unlikely
784     // f1 != f2 -> Likely
785     isProb = !FCmp->isTrueWhenEqual();
786   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
787     // !isnan -> Likely
788     isProb = true;
789   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
790     // isnan -> Unlikely
791     isProb = false;
792   } else {
793     return false;
794   }
795 
796   unsigned TakenIdx = 0, NonTakenIdx = 1;
797 
798   if (!isProb)
799     std::swap(TakenIdx, NonTakenIdx);
800 
801   BranchProbability TakenProb(FPH_TAKEN_WEIGHT,
802                               FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
803   setEdgeProbability(BB, TakenIdx, TakenProb);
804   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
805   return true;
806 }
807 
808 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
809   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
810   if (!II)
811     return false;
812 
813   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
814                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
815   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
816   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
817   return true;
818 }
819 
820 void BranchProbabilityInfo::releaseMemory() {
821   Probs.clear();
822 }
823 
824 void BranchProbabilityInfo::print(raw_ostream &OS) const {
825   OS << "---- Branch Probabilities ----\n";
826   // We print the probabilities from the last function the analysis ran over,
827   // or the function it is currently running over.
828   assert(LastF && "Cannot print prior to running over a function");
829   for (const auto &BI : *LastF) {
830     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
831          ++SI) {
832       printEdgeProbability(OS << "  ", &BI, *SI);
833     }
834   }
835 }
836 
837 bool BranchProbabilityInfo::
838 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
839   // Hot probability is at least 4/5 = 80%
840   // FIXME: Compare against a static "hot" BranchProbability.
841   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
842 }
843 
844 const BasicBlock *
845 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
846   auto MaxProb = BranchProbability::getZero();
847   const BasicBlock *MaxSucc = nullptr;
848 
849   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
850     const BasicBlock *Succ = *I;
851     auto Prob = getEdgeProbability(BB, Succ);
852     if (Prob > MaxProb) {
853       MaxProb = Prob;
854       MaxSucc = Succ;
855     }
856   }
857 
858   // Hot probability is at least 4/5 = 80%
859   if (MaxProb > BranchProbability(4, 5))
860     return MaxSucc;
861 
862   return nullptr;
863 }
864 
865 /// Get the raw edge probability for the edge. If can't find it, return a
866 /// default probability 1/N where N is the number of successors. Here an edge is
867 /// specified using PredBlock and an
868 /// index to the successors.
869 BranchProbability
870 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
871                                           unsigned IndexInSuccessors) const {
872   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
873 
874   if (I != Probs.end())
875     return I->second;
876 
877   return {1, static_cast<uint32_t>(succ_size(Src))};
878 }
879 
880 BranchProbability
881 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
882                                           succ_const_iterator Dst) const {
883   return getEdgeProbability(Src, Dst.getSuccessorIndex());
884 }
885 
886 /// Get the raw edge probability calculated for the block pair. This returns the
887 /// sum of all raw edge probabilities from Src to Dst.
888 BranchProbability
889 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
890                                           const BasicBlock *Dst) const {
891   auto Prob = BranchProbability::getZero();
892   bool FoundProb = false;
893   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
894     if (*I == Dst) {
895       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
896       if (MapI != Probs.end()) {
897         FoundProb = true;
898         Prob += MapI->second;
899       }
900     }
901   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
902   return FoundProb ? Prob : BranchProbability(1, succ_num);
903 }
904 
905 /// Set the edge probability for a given edge specified by PredBlock and an
906 /// index to the successors.
907 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
908                                                unsigned IndexInSuccessors,
909                                                BranchProbability Prob) {
910   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
911   Handles.insert(BasicBlockCallbackVH(Src, this));
912   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
913                     << IndexInSuccessors << " successor probability to " << Prob
914                     << "\n");
915 }
916 
917 raw_ostream &
918 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
919                                             const BasicBlock *Src,
920                                             const BasicBlock *Dst) const {
921   const BranchProbability Prob = getEdgeProbability(Src, Dst);
922   OS << "edge " << Src->getName() << " -> " << Dst->getName()
923      << " probability is " << Prob
924      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
925 
926   return OS;
927 }
928 
929 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
930   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
931     auto Key = I->first;
932     if (Key.first == BB)
933       Probs.erase(Key);
934   }
935 }
936 
937 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
938                                       const TargetLibraryInfo *TLI) {
939   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
940                     << " ----\n\n");
941   LastF = &F; // Store the last function we ran on for printing.
942   assert(PostDominatedByUnreachable.empty());
943   assert(PostDominatedByColdCall.empty());
944 
945   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
946   // FIXME: We could only calculate this if the CFG is known to be irreducible
947   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
948   int SccNum = 0;
949   SccInfo SccI;
950   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
951        ++It, ++SccNum) {
952     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
953     // catch them.
954     const std::vector<const BasicBlock *> &Scc = *It;
955     if (Scc.size() == 1)
956       continue;
957 
958     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
959     for (auto *BB : Scc) {
960       LLVM_DEBUG(dbgs() << " " << BB->getName());
961       SccI.SccNums[BB] = SccNum;
962     }
963     LLVM_DEBUG(dbgs() << "\n");
964   }
965 
966   // Walk the basic blocks in post-order so that we can build up state about
967   // the successors of a block iteratively.
968   for (auto BB : post_order(&F.getEntryBlock())) {
969     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
970                       << "\n");
971     updatePostDominatedByUnreachable(BB);
972     updatePostDominatedByColdCall(BB);
973     // If there is no at least two successors, no sense to set probability.
974     if (BB->getTerminator()->getNumSuccessors() < 2)
975       continue;
976     if (calcMetadataWeights(BB))
977       continue;
978     if (calcUnreachableHeuristics(BB))
979       continue;
980     if (calcColdCallHeuristics(BB))
981       continue;
982     if (calcLoopBranchHeuristics(BB, LI, SccI))
983       continue;
984     if (calcPointerHeuristics(BB))
985       continue;
986     if (calcZeroHeuristics(BB, TLI))
987       continue;
988     if (calcFloatingPointHeuristics(BB))
989       continue;
990     calcInvokeHeuristics(BB);
991   }
992 
993   PostDominatedByUnreachable.clear();
994   PostDominatedByColdCall.clear();
995 
996   if (PrintBranchProb &&
997       (PrintBranchProbFuncName.empty() ||
998        F.getName().equals(PrintBranchProbFuncName))) {
999     print(dbgs());
1000   }
1001 }
1002 
1003 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1004     AnalysisUsage &AU) const {
1005   // We require DT so it's available when LI is available. The LI updating code
1006   // asserts that DT is also present so if we don't make sure that we have DT
1007   // here, that assert will trigger.
1008   AU.addRequired<DominatorTreeWrapperPass>();
1009   AU.addRequired<LoopInfoWrapperPass>();
1010   AU.addRequired<TargetLibraryInfoWrapperPass>();
1011   AU.setPreservesAll();
1012 }
1013 
1014 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1015   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1016   const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
1017   BPI.calculate(F, LI, &TLI);
1018   return false;
1019 }
1020 
1021 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1022 
1023 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1024                                              const Module *) const {
1025   BPI.print(OS);
1026 }
1027 
1028 AnalysisKey BranchProbabilityAnalysis::Key;
1029 BranchProbabilityInfo
1030 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1031   BranchProbabilityInfo BPI;
1032   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1033   return BPI;
1034 }
1035 
1036 PreservedAnalyses
1037 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1038   OS << "Printing analysis results of BPI for function "
1039      << "'" << F.getName() << "':"
1040      << "\n";
1041   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1042   return PreservedAnalyses::all();
1043 }
1044