1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/CFG.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/LLVMContext.h"
30 #include "llvm/IR/Metadata.h"
31 #include "llvm/IR/PassManager.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/InitializePasses.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/BranchProbability.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/raw_ostream.h"
41 #include <cassert>
42 #include <cstdint>
43 #include <iterator>
44 #include <utility>
45 
46 using namespace llvm;
47 
48 #define DEBUG_TYPE "branch-prob"
49 
50 static cl::opt<bool> PrintBranchProb(
51     "print-bpi", cl::init(false), cl::Hidden,
52     cl::desc("Print the branch probability info."));
53 
54 cl::opt<std::string> PrintBranchProbFuncName(
55     "print-bpi-func-name", cl::Hidden,
56     cl::desc("The option to specify the name of the function "
57              "whose branch probability info is printed."));
58 
// Register the legacy pass-manager wrapper under the command-line name
// "branch-prob" and record its dependencies on LoopInfo and TargetLibraryInfo.
// The trailing (false, true) arguments are the INITIALIZE_PASS_* macros'
// cfg-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)
65 
66 BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
67     : FunctionPass(ID) {
68   initializeBranchProbabilityInfoWrapperPassPass(
69       *PassRegistry::getPassRegistry());
70 }
71 
72 char BranchProbabilityInfoWrapperPass::ID = 0;
73 
// Weights are for internal use only. The heuristics use them to estimate the
// probability of edges. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125

/// Loop-branch weight for back-edges and for edges that stay inside the loop.
static constexpr uint32_t LBH_TAKEN_WEIGHT = 124;
/// Loop-branch weight for edges that exit the loop.
static constexpr uint32_t LBH_NONTAKEN_WEIGHT = 4;
/// Loop-branch weight for edges inside a loop that are considered unlikely:
/// half of LBH_TAKEN_WEIGHT.
static constexpr uint32_t LBH_UNLIKELY_WEIGHT = 62;
98 
99 /// Unreachable-terminating branch taken probability.
100 ///
101 /// This is the probability for a branch being taken to a block that terminates
102 /// (eventually) in unreachable. These are predicted as unlikely as possible.
103 /// All reachable probability will equally share the remaining part.
104 static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
105 
/// Cold-call heuristic weight for a branch taken toward a cold block.
///
/// A block is cold if it is post-dominated by a block containing a call to a
/// cold function, i.e. a function marked with attribute 'cold'.
static constexpr uint32_t CC_TAKEN_WEIGHT = 4;

/// Cold-call heuristic weight for a branch not taken toward a cold block.
static constexpr uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer-heuristic weights (likely / unlikely outcome).
static constexpr uint32_t PH_TAKEN_WEIGHT = 20;
static constexpr uint32_t PH_NONTAKEN_WEIGHT = 12;

/// Zero-compare heuristic weights (likely / unlikely outcome).
static constexpr uint32_t ZH_TAKEN_WEIGHT = 20;
static constexpr uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// Floating-point equality heuristic weights (likely / unlikely outcome).
static constexpr uint32_t FPH_TAKEN_WEIGHT = 20;
static constexpr uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// Weight of the "ordered" outcome of a floating-point comparison (neither
/// operand is NaN).
static constexpr uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// Weight of the "unordered" outcome of a floating-point comparison (one or
/// both operands are NaN). Such a test usually checks for an exceptional
/// case, so this outcome is unlikely.
static constexpr uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke heuristic weight for the normal destination of an invoke.
///
/// We expect the normal destination to be taken most of the time. The weight
/// is an absurdly high value so that nested loops subsume it.
static constexpr uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke heuristic weight for the unwind destination of an invoke, which is
/// essentially never taken.
static constexpr uint32_t IH_NONTAKEN_WEIGHT = 1;
148 
149 /// Add \p BB to PostDominatedByUnreachable set if applicable.
150 void
151 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
152   const Instruction *TI = BB->getTerminator();
153   if (TI->getNumSuccessors() == 0) {
154     if (isa<UnreachableInst>(TI) ||
155         // If this block is terminated by a call to
156         // @llvm.experimental.deoptimize then treat it like an unreachable since
157         // the @llvm.experimental.deoptimize call is expected to practically
158         // never execute.
159         BB->getTerminatingDeoptimizeCall())
160       PostDominatedByUnreachable.insert(BB);
161     return;
162   }
163 
164   // If the terminator is an InvokeInst, check only the normal destination block
165   // as the unwind edge of InvokeInst is also very unlikely taken.
166   if (auto *II = dyn_cast<InvokeInst>(TI)) {
167     if (PostDominatedByUnreachable.count(II->getNormalDest()))
168       PostDominatedByUnreachable.insert(BB);
169     return;
170   }
171 
172   for (auto *I : successors(BB))
173     // If any of successor is not post dominated then BB is also not.
174     if (!PostDominatedByUnreachable.count(I))
175       return;
176 
177   PostDominatedByUnreachable.insert(BB);
178 }
179 
180 /// Add \p BB to PostDominatedByColdCall set if applicable.
181 void
182 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
183   assert(!PostDominatedByColdCall.count(BB));
184   const Instruction *TI = BB->getTerminator();
185   if (TI->getNumSuccessors() == 0)
186     return;
187 
188   // If all of successor are post dominated then BB is also done.
189   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
190         return PostDominatedByColdCall.count(SuccBB);
191       })) {
192     PostDominatedByColdCall.insert(BB);
193     return;
194   }
195 
196   // If the terminator is an InvokeInst, check only the normal destination
197   // block as the unwind edge of InvokeInst is also very unlikely taken.
198   if (auto *II = dyn_cast<InvokeInst>(TI))
199     if (PostDominatedByColdCall.count(II->getNormalDest())) {
200       PostDominatedByColdCall.insert(BB);
201       return;
202     }
203 
204   // Otherwise, if the block itself contains a cold function, add it to the
205   // set of blocks post-dominated by a cold call.
206   for (auto &I : *BB)
207     if (const CallInst *CI = dyn_cast<CallInst>(&I))
208       if (CI->hasFnAttr(Attribute::Cold)) {
209         PostDominatedByColdCall.insert(BB);
210         return;
211       }
212 }
213 
214 /// Calculate edge weights for successors lead to unreachable.
215 ///
216 /// Predict that a successor which leads necessarily to an
217 /// unreachable-terminated block as extremely unlikely.
218 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
219   const Instruction *TI = BB->getTerminator();
220   (void) TI;
221   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
222   assert(!isa<InvokeInst>(TI) &&
223          "Invokes should have already been handled by calcInvokeHeuristics");
224 
225   SmallVector<unsigned, 4> UnreachableEdges;
226   SmallVector<unsigned, 4> ReachableEdges;
227 
228   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
229     if (PostDominatedByUnreachable.count(*I))
230       UnreachableEdges.push_back(I.getSuccessorIndex());
231     else
232       ReachableEdges.push_back(I.getSuccessorIndex());
233 
234   // Skip probabilities if all were reachable.
235   if (UnreachableEdges.empty())
236     return false;
237 
238   if (ReachableEdges.empty()) {
239     BranchProbability Prob(1, UnreachableEdges.size());
240     for (unsigned SuccIdx : UnreachableEdges)
241       setEdgeProbability(BB, SuccIdx, Prob);
242     return true;
243   }
244 
245   auto UnreachableProb = UR_TAKEN_PROB;
246   auto ReachableProb =
247       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
248       ReachableEdges.size();
249 
250   for (unsigned SuccIdx : UnreachableEdges)
251     setEdgeProbability(BB, SuccIdx, UnreachableProb);
252   for (unsigned SuccIdx : ReachableEdges)
253     setEdgeProbability(BB, SuccIdx, ReachableProb);
254 
255   return true;
256 }
257 
258 // Propagate existing explicit probabilities from either profile data or
259 // 'expect' intrinsic processing. Examine metadata against unreachable
260 // heuristic. The probability of the edge coming to unreachable block is
261 // set to min of metadata and unreachable heuristic.
262 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
263   const Instruction *TI = BB->getTerminator();
264   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
265   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
266     return false;
267 
268   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
269   if (!WeightsNode)
270     return false;
271 
272   // Check that the number of successors is manageable.
273   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
274 
275   // Ensure there are weights for all of the successors. Note that the first
276   // operand to the metadata node is a name, not a weight.
277   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
278     return false;
279 
280   // Build up the final weights that will be used in a temporary buffer.
281   // Compute the sum of all weights to later decide whether they need to
282   // be scaled to fit in 32 bits.
283   uint64_t WeightSum = 0;
284   SmallVector<uint32_t, 2> Weights;
285   SmallVector<unsigned, 2> UnreachableIdxs;
286   SmallVector<unsigned, 2> ReachableIdxs;
287   Weights.reserve(TI->getNumSuccessors());
288   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
289     ConstantInt *Weight =
290         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
291     if (!Weight)
292       return false;
293     assert(Weight->getValue().getActiveBits() <= 32 &&
294            "Too many bits for uint32_t");
295     Weights.push_back(Weight->getZExtValue());
296     WeightSum += Weights.back();
297     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
298       UnreachableIdxs.push_back(i - 1);
299     else
300       ReachableIdxs.push_back(i - 1);
301   }
302   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
303 
304   // If the sum of weights does not fit in 32 bits, scale every weight down
305   // accordingly.
306   uint64_t ScalingFactor =
307       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
308 
309   if (ScalingFactor > 1) {
310     WeightSum = 0;
311     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
312       Weights[i] /= ScalingFactor;
313       WeightSum += Weights[i];
314     }
315   }
316   assert(WeightSum <= UINT32_MAX &&
317          "Expected weights to scale down to 32 bits");
318 
319   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
320     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
321       Weights[i] = 1;
322     WeightSum = TI->getNumSuccessors();
323   }
324 
325   // Set the probability.
326   SmallVector<BranchProbability, 2> BP;
327   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
328     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
329 
330   // Examine the metadata against unreachable heuristic.
331   // If the unreachable heuristic is more strong then we use it for this edge.
332   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
333     auto ToDistribute = BranchProbability::getZero();
334     auto UnreachableProb = UR_TAKEN_PROB;
335     for (auto i : UnreachableIdxs)
336       if (UnreachableProb < BP[i]) {
337         ToDistribute += BP[i] - UnreachableProb;
338         BP[i] = UnreachableProb;
339       }
340 
341     // If we modified the probability of some edges then we must distribute
342     // the difference between reachable blocks.
343     if (ToDistribute > BranchProbability::getZero()) {
344       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
345       for (auto i : ReachableIdxs)
346         BP[i] += PerEdge;
347     }
348   }
349 
350   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
351     setEdgeProbability(BB, i, BP[i]);
352 
353   return true;
354 }
355 
356 /// Calculate edge weights for edges leading to cold blocks.
357 ///
358 /// A cold block is one post-dominated by  a block with a call to a
359 /// cold function.  Those edges are unlikely to be taken, so we give
360 /// them relatively low weight.
361 ///
362 /// Return true if we could compute the weights for cold edges.
363 /// Return false, otherwise.
364 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
365   const Instruction *TI = BB->getTerminator();
366   (void) TI;
367   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
368   assert(!isa<InvokeInst>(TI) &&
369          "Invokes should have already been handled by calcInvokeHeuristics");
370 
371   // Determine which successors are post-dominated by a cold block.
372   SmallVector<unsigned, 4> ColdEdges;
373   SmallVector<unsigned, 4> NormalEdges;
374   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
375     if (PostDominatedByColdCall.count(*I))
376       ColdEdges.push_back(I.getSuccessorIndex());
377     else
378       NormalEdges.push_back(I.getSuccessorIndex());
379 
380   // Skip probabilities if no cold edges.
381   if (ColdEdges.empty())
382     return false;
383 
384   if (NormalEdges.empty()) {
385     BranchProbability Prob(1, ColdEdges.size());
386     for (unsigned SuccIdx : ColdEdges)
387       setEdgeProbability(BB, SuccIdx, Prob);
388     return true;
389   }
390 
391   auto ColdProb = BranchProbability::getBranchProbability(
392       CC_TAKEN_WEIGHT,
393       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
394   auto NormalProb = BranchProbability::getBranchProbability(
395       CC_NONTAKEN_WEIGHT,
396       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
397 
398   for (unsigned SuccIdx : ColdEdges)
399     setEdgeProbability(BB, SuccIdx, ColdProb);
400   for (unsigned SuccIdx : NormalEdges)
401     setEdgeProbability(BB, SuccIdx, NormalProb);
402 
403   return true;
404 }
405 
406 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
407 // between two pointer or pointer and NULL will fail.
408 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
409   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
410   if (!BI || !BI->isConditional())
411     return false;
412 
413   Value *Cond = BI->getCondition();
414   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
415   if (!CI || !CI->isEquality())
416     return false;
417 
418   Value *LHS = CI->getOperand(0);
419 
420   if (!LHS->getType()->isPointerTy())
421     return false;
422 
423   assert(CI->getOperand(1)->getType()->isPointerTy());
424 
425   // p != 0   ->   isProb = true
426   // p == 0   ->   isProb = false
427   // p != q   ->   isProb = true
428   // p == q   ->   isProb = false;
429   unsigned TakenIdx = 0, NonTakenIdx = 1;
430   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
431   if (!isProb)
432     std::swap(TakenIdx, NonTakenIdx);
433 
434   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
435                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
436   setEdgeProbability(BB, TakenIdx, TakenProb);
437   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
438   return true;
439 }
440 
441 static int getSCCNum(const BasicBlock *BB,
442                      const BranchProbabilityInfo::SccInfo &SccI) {
443   auto SccIt = SccI.SccNums.find(BB);
444   if (SccIt == SccI.SccNums.end())
445     return -1;
446   return SccIt->second;
447 }
448 
449 // Consider any block that is an entry point to the SCC as a header.
450 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
451                         BranchProbabilityInfo::SccInfo &SccI) {
452   assert(getSCCNum(BB, SccI) == SccNum);
453 
454   // Lazily compute the set of headers for a given SCC and cache the results
455   // in the SccHeaderMap.
456   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
457     SccI.SccHeaders.resize(SccNum + 1);
458   auto &HeaderMap = SccI.SccHeaders[SccNum];
459   bool Inserted;
460   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
461   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
462   if (Inserted) {
463     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
464                                  [&](const BasicBlock *Pred) {
465                                    return getSCCNum(Pred, SccI) != SccNum;
466                                  });
467     HeaderMapIt->second = IsHeader;
468     return IsHeader;
469   } else
470     return HeaderMapIt->second;
471 }
472 
473 // Compute the unlikely successors to the block BB in the loop L, specifically
474 // those that are unlikely because this is a loop, and add them to the
475 // UnlikelyBlocks set.
476 static void
477 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
478                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
479   // Sometimes in a loop we have a branch whose condition is made false by
480   // taking it. This is typically something like
481   //  int n = 0;
482   //  while (...) {
483   //    if (++n >= MAX) {
484   //      n = 0;
485   //    }
486   //  }
487   // In this sort of situation taking the branch means that at the very least it
488   // won't be taken again in the next iteration of the loop, so we should
489   // consider it less likely than a typical branch.
490   //
491   // We detect this by looking back through the graph of PHI nodes that sets the
492   // value that the condition depends on, and seeing if we can reach a successor
493   // block which can be determined to make the condition false.
494   //
495   // FIXME: We currently consider unlikely blocks to be half as likely as other
496   // blocks, but if we consider the example above the likelyhood is actually
497   // 1/MAX. We could therefore be more precise in how unlikely we consider
498   // blocks to be, but it would require more careful examination of the form
499   // of the comparison expression.
500   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
501   if (!BI || !BI->isConditional())
502     return;
503 
504   // Check if the branch is based on an instruction compared with a constant
505   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
506   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
507       !isa<Constant>(CI->getOperand(1)))
508     return;
509 
510   // Either the instruction must be a PHI, or a chain of operations involving
511   // constants that ends in a PHI which we can then collapse into a single value
512   // if the PHI value is known.
513   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
514   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
515   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
516   // Collect the instructions until we hit a PHI
517   SmallVector<BinaryOperator *, 1> InstChain;
518   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
519          isa<Constant>(CmpLHS->getOperand(1))) {
520     // Stop if the chain extends outside of the loop
521     if (!L->contains(CmpLHS))
522       return;
523     InstChain.push_back(cast<BinaryOperator>(CmpLHS));
524     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
525     if (CmpLHS)
526       CmpPHI = dyn_cast<PHINode>(CmpLHS);
527   }
528   if (!CmpPHI || !L->contains(CmpPHI))
529     return;
530 
531   // Trace the phi node to find all values that come from successors of BB
532   SmallPtrSet<PHINode*, 8> VisitedInsts;
533   SmallVector<PHINode*, 8> WorkList;
534   WorkList.push_back(CmpPHI);
535   VisitedInsts.insert(CmpPHI);
536   while (!WorkList.empty()) {
537     PHINode *P = WorkList.back();
538     WorkList.pop_back();
539     for (BasicBlock *B : P->blocks()) {
540       // Skip blocks that aren't part of the loop
541       if (!L->contains(B))
542         continue;
543       Value *V = P->getIncomingValueForBlock(B);
544       // If the source is a PHI add it to the work list if we haven't
545       // already visited it.
546       if (PHINode *PN = dyn_cast<PHINode>(V)) {
547         if (VisitedInsts.insert(PN).second)
548           WorkList.push_back(PN);
549         continue;
550       }
551       // If this incoming value is a constant and B is a successor of BB, then
552       // we can constant-evaluate the compare to see if it makes the branch be
553       // taken or not.
554       Constant *CmpLHSConst = dyn_cast<Constant>(V);
555       if (!CmpLHSConst ||
556           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
557         continue;
558       // First collapse InstChain
559       for (Instruction *I : llvm::reverse(InstChain)) {
560         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
561                                         cast<Constant>(I->getOperand(1)), true);
562         if (!CmpLHSConst)
563           break;
564       }
565       if (!CmpLHSConst)
566         continue;
567       // Now constant-evaluate the compare
568       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
569                                                   CmpLHSConst, CmpConst, true);
570       // If the result means we don't branch to the block then that block is
571       // unlikely.
572       if (Result &&
573           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
574            (Result->isOneValue() && B == BI->getSuccessor(1))))
575         UnlikelyBlocks.insert(B);
576     }
577   }
578 }
579 
580 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
581 // as taken, exiting edges as not-taken.
582 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
583                                                      const LoopInfo &LI,
584                                                      SccInfo &SccI) {
585   int SccNum;
586   Loop *L = LI.getLoopFor(BB);
587   if (!L) {
588     SccNum = getSCCNum(BB, SccI);
589     if (SccNum < 0)
590       return false;
591   }
592 
593   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
594   if (L)
595     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
596 
597   SmallVector<unsigned, 8> BackEdges;
598   SmallVector<unsigned, 8> ExitingEdges;
599   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
600   SmallVector<unsigned, 8> UnlikelyEdges;
601 
602   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
603     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
604     // irreducible loops.
605     if (L) {
606       if (UnlikelyBlocks.count(*I) != 0)
607         UnlikelyEdges.push_back(I.getSuccessorIndex());
608       else if (!L->contains(*I))
609         ExitingEdges.push_back(I.getSuccessorIndex());
610       else if (L->getHeader() == *I)
611         BackEdges.push_back(I.getSuccessorIndex());
612       else
613         InEdges.push_back(I.getSuccessorIndex());
614     } else {
615       if (getSCCNum(*I, SccI) != SccNum)
616         ExitingEdges.push_back(I.getSuccessorIndex());
617       else if (isSCCHeader(*I, SccNum, SccI))
618         BackEdges.push_back(I.getSuccessorIndex());
619       else
620         InEdges.push_back(I.getSuccessorIndex());
621     }
622   }
623 
624   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
625     return false;
626 
627   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
628   // normalize them so that they sum up to one.
629   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
630                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
631                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
632                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
633 
634   if (uint32_t numBackEdges = BackEdges.size()) {
635     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
636     auto Prob = TakenProb / numBackEdges;
637     for (unsigned SuccIdx : BackEdges)
638       setEdgeProbability(BB, SuccIdx, Prob);
639   }
640 
641   if (uint32_t numInEdges = InEdges.size()) {
642     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
643     auto Prob = TakenProb / numInEdges;
644     for (unsigned SuccIdx : InEdges)
645       setEdgeProbability(BB, SuccIdx, Prob);
646   }
647 
648   if (uint32_t numExitingEdges = ExitingEdges.size()) {
649     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
650                                                        Denom);
651     auto Prob = NotTakenProb / numExitingEdges;
652     for (unsigned SuccIdx : ExitingEdges)
653       setEdgeProbability(BB, SuccIdx, Prob);
654   }
655 
656   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
657     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
658                                                        Denom);
659     auto Prob = UnlikelyProb / numUnlikelyEdges;
660     for (unsigned SuccIdx : UnlikelyEdges)
661       setEdgeProbability(BB, SuccIdx, Prob);
662   }
663 
664   return true;
665 }
666 
667 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
668                                                const TargetLibraryInfo *TLI) {
669   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
670   if (!BI || !BI->isConditional())
671     return false;
672 
673   Value *Cond = BI->getCondition();
674   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
675   if (!CI)
676     return false;
677 
678   auto GetConstantInt = [](Value *V) {
679     if (auto *I = dyn_cast<BitCastInst>(V))
680       return dyn_cast<ConstantInt>(I->getOperand(0));
681     return dyn_cast<ConstantInt>(V);
682   };
683 
684   Value *RHS = CI->getOperand(1);
685   ConstantInt *CV = GetConstantInt(RHS);
686   if (!CV)
687     return false;
688 
689   // If the LHS is the result of AND'ing a value with a single bit bitmask,
690   // we don't have information about probabilities.
691   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
692     if (LHS->getOpcode() == Instruction::And)
693       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
694         if (AndRHS->getValue().isPowerOf2())
695           return false;
696 
697   // Check if the LHS is the return value of a library function
698   LibFunc Func = NumLibFuncs;
699   if (TLI)
700     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
701       if (Function *CalledFn = Call->getCalledFunction())
702         TLI->getLibFunc(*CalledFn, Func);
703 
704   bool isProb;
705   if (Func == LibFunc_strcasecmp ||
706       Func == LibFunc_strcmp ||
707       Func == LibFunc_strncasecmp ||
708       Func == LibFunc_strncmp ||
709       Func == LibFunc_memcmp) {
710     // strcmp and similar functions return zero, negative, or positive, if the
711     // first string is equal, less, or greater than the second. We consider it
712     // likely that the strings are not equal, so a comparison with zero is
713     // probably false, but also a comparison with any other number is also
714     // probably false given that what exactly is returned for nonzero values is
715     // not specified. Any kind of comparison other than equality we know
716     // nothing about.
717     switch (CI->getPredicate()) {
718     case CmpInst::ICMP_EQ:
719       isProb = false;
720       break;
721     case CmpInst::ICMP_NE:
722       isProb = true;
723       break;
724     default:
725       return false;
726     }
727   } else if (CV->isZero()) {
728     switch (CI->getPredicate()) {
729     case CmpInst::ICMP_EQ:
730       // X == 0   ->  Unlikely
731       isProb = false;
732       break;
733     case CmpInst::ICMP_NE:
734       // X != 0   ->  Likely
735       isProb = true;
736       break;
737     case CmpInst::ICMP_SLT:
738       // X < 0   ->  Unlikely
739       isProb = false;
740       break;
741     case CmpInst::ICMP_SGT:
742       // X > 0   ->  Likely
743       isProb = true;
744       break;
745     default:
746       return false;
747     }
748   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
749     // InstCombine canonicalizes X <= 0 into X < 1.
750     // X <= 0   ->  Unlikely
751     isProb = false;
752   } else if (CV->isMinusOne()) {
753     switch (CI->getPredicate()) {
754     case CmpInst::ICMP_EQ:
755       // X == -1  ->  Unlikely
756       isProb = false;
757       break;
758     case CmpInst::ICMP_NE:
759       // X != -1  ->  Likely
760       isProb = true;
761       break;
762     case CmpInst::ICMP_SGT:
763       // InstCombine canonicalizes X >= 0 into X > -1.
764       // X >= 0   ->  Likely
765       isProb = true;
766       break;
767     default:
768       return false;
769     }
770   } else {
771     return false;
772   }
773 
774   unsigned TakenIdx = 0, NonTakenIdx = 1;
775 
776   if (!isProb)
777     std::swap(TakenIdx, NonTakenIdx);
778 
779   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
780                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
781   setEdgeProbability(BB, TakenIdx, TakenProb);
782   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
783   return true;
784 }
785 
786 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
787   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
788   if (!BI || !BI->isConditional())
789     return false;
790 
791   Value *Cond = BI->getCondition();
792   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
793   if (!FCmp)
794     return false;
795 
796   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
797   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
798   bool isProb;
799   if (FCmp->isEquality()) {
800     // f1 == f2 -> Unlikely
801     // f1 != f2 -> Likely
802     isProb = !FCmp->isTrueWhenEqual();
803   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
804     // !isnan -> Likely
805     isProb = true;
806     TakenWeight = FPH_ORD_WEIGHT;
807     NontakenWeight = FPH_UNO_WEIGHT;
808   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
809     // isnan -> Unlikely
810     isProb = false;
811     TakenWeight = FPH_ORD_WEIGHT;
812     NontakenWeight = FPH_UNO_WEIGHT;
813   } else {
814     return false;
815   }
816 
817   unsigned TakenIdx = 0, NonTakenIdx = 1;
818 
819   if (!isProb)
820     std::swap(TakenIdx, NonTakenIdx);
821 
822   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
823   setEdgeProbability(BB, TakenIdx, TakenProb);
824   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
825   return true;
826 }
827 
828 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
829   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
830   if (!II)
831     return false;
832 
833   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
834                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
835   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
836   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
837   return true;
838 }
839 
840 void BranchProbabilityInfo::releaseMemory() {
841   Probs.clear();
842 }
843 
844 void BranchProbabilityInfo::print(raw_ostream &OS) const {
845   OS << "---- Branch Probabilities ----\n";
846   // We print the probabilities from the last function the analysis ran over,
847   // or the function it is currently running over.
848   assert(LastF && "Cannot print prior to running over a function");
849   for (const auto &BI : *LastF) {
850     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
851          ++SI) {
852       printEdgeProbability(OS << "  ", &BI, *SI);
853     }
854   }
855 }
856 
857 bool BranchProbabilityInfo::
858 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
859   // Hot probability is at least 4/5 = 80%
860   // FIXME: Compare against a static "hot" BranchProbability.
861   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
862 }
863 
864 const BasicBlock *
865 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
866   auto MaxProb = BranchProbability::getZero();
867   const BasicBlock *MaxSucc = nullptr;
868 
869   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
870     const BasicBlock *Succ = *I;
871     auto Prob = getEdgeProbability(BB, Succ);
872     if (Prob > MaxProb) {
873       MaxProb = Prob;
874       MaxSucc = Succ;
875     }
876   }
877 
878   // Hot probability is at least 4/5 = 80%
879   if (MaxProb > BranchProbability(4, 5))
880     return MaxSucc;
881 
882   return nullptr;
883 }
884 
885 /// Get the raw edge probability for the edge. If can't find it, return a
886 /// default probability 1/N where N is the number of successors. Here an edge is
887 /// specified using PredBlock and an
888 /// index to the successors.
889 BranchProbability
890 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
891                                           unsigned IndexInSuccessors) const {
892   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
893 
894   if (I != Probs.end())
895     return I->second;
896 
897   return {1, static_cast<uint32_t>(succ_size(Src))};
898 }
899 
900 BranchProbability
901 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
902                                           succ_const_iterator Dst) const {
903   return getEdgeProbability(Src, Dst.getSuccessorIndex());
904 }
905 
906 /// Get the raw edge probability calculated for the block pair. This returns the
907 /// sum of all raw edge probabilities from Src to Dst.
908 BranchProbability
909 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
910                                           const BasicBlock *Dst) const {
911   auto Prob = BranchProbability::getZero();
912   bool FoundProb = false;
913   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
914     if (*I == Dst) {
915       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
916       if (MapI != Probs.end()) {
917         FoundProb = true;
918         Prob += MapI->second;
919       }
920     }
921   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
922   return FoundProb ? Prob : BranchProbability(1, succ_num);
923 }
924 
925 /// Set the edge probability for a given edge specified by PredBlock and an
926 /// index to the successors.
927 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
928                                                unsigned IndexInSuccessors,
929                                                BranchProbability Prob) {
930   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
931   Handles.insert(BasicBlockCallbackVH(Src, this));
932   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
933                     << IndexInSuccessors << " successor probability to " << Prob
934                     << "\n");
935 }
936 
937 raw_ostream &
938 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
939                                             const BasicBlock *Src,
940                                             const BasicBlock *Dst) const {
941   const BranchProbability Prob = getEdgeProbability(Src, Dst);
942   OS << "edge " << Src->getName() << " -> " << Dst->getName()
943      << " probability is " << Prob
944      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
945 
946   return OS;
947 }
948 
949 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
950   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
951     auto Key = I->first;
952     if (Key.first == BB)
953       Probs.erase(Key);
954   }
955 }
956 
/// Compute branch probabilities for \p F.
///
/// Blocks are visited in post-order from the entry block, so blocks that are
/// not reachable from the entry are never visited. For every visited block
/// with at least two successors, the calc*Heuristics / calcMetadataWeights
/// helpers below are tried in the listed order. The first one that returns
/// true has set that block's edge probabilities, and the rest are skipped.
/// When none applies, no entry is recorded, so getEdgeProbability() falls back
/// to 1/N for that block's edges.
///
/// \p LI is passed to the loop heuristic and \p TLI to calcZeroHeuristics
/// (how TLI is used there is not visible here). Probs is not cleared by this
/// function; entries from earlier runs are only removed by releaseMemory()
/// or eraseBlock().
void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
                                      const TargetLibraryInfo *TLI) {
  LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
                    << " ----\n\n");
  LastF = &F; // Store the last function we ran on for printing.
  // Per-run scratch state: must be empty on entry, and is cleared again after
  // the walk below.
  assert(PostDominatedByUnreachable.empty());
  assert(PostDominatedByColdCall.empty());

  // Record SCC numbers of blocks in the CFG to identify irreducible loops.
  // FIXME: We could only calculate this if the CFG is known to be irreducible
  // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
  // SccNum is bumped for every SCC, including skipped single-block ones, so
  // the recorded numbers are unique but not necessarily consecutive.
  int SccNum = 0;
  SccInfo SccI;
  for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
       ++It, ++SccNum) {
    // Ignore single-block SCCs since they either aren't loops or LoopInfo will
    // catch them.
    const std::vector<const BasicBlock *> &Scc = *It;
    if (Scc.size() == 1)
      continue;

    LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
    for (auto *BB : Scc) {
      LLVM_DEBUG(dbgs() << " " << BB->getName());
      SccI.SccNums[BB] = SccNum;
    }
    LLVM_DEBUG(dbgs() << "\n");
  }

  // Walk the basic blocks in post-order so that we can build up state about
  // the successors of a block iteratively.
  for (auto BB : post_order(&F.getEntryBlock())) {
    LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
                      << "\n");
    // The post-dominance bookkeeping is updated for every visited block, even
    // those skipped just below, because blocks visited later build on it.
    updatePostDominatedByUnreachable(BB);
    updatePostDominatedByColdCall(BB);
    // A block with fewer than two successors has no probabilities to set.
    if (BB->getTerminator()->getNumSuccessors() < 2)
      continue;
    // Try each source of probabilities in priority order; stop at the first
    // one that handles this block.
    if (calcMetadataWeights(BB))
      continue;
    if (calcInvokeHeuristics(BB))
      continue;
    if (calcUnreachableHeuristics(BB))
      continue;
    if (calcColdCallHeuristics(BB))
      continue;
    if (calcLoopBranchHeuristics(BB, LI, SccI))
      continue;
    if (calcPointerHeuristics(BB))
      continue;
    if (calcZeroHeuristics(BB, TLI))
      continue;
    if (calcFloatingPointHeuristics(BB))
      continue;
  }

  PostDominatedByUnreachable.clear();
  PostDominatedByColdCall.clear();

  // Optional dump controlled by -print-bpi, optionally restricted to the
  // function named by -print-bpi-func-name.
  if (PrintBranchProb &&
      (PrintBranchProbFuncName.empty() ||
       F.getName().equals(PrintBranchProbFuncName))) {
    print(dbgs());
  }
}
1023 
/// Declare the analyses this legacy pass consumes (dominator tree, loop info,
/// target library info) and that it preserves every analysis.
void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // We require DT so it's available when LI is available. The LI updating code
  // asserts that DT is also present so if we don't make sure that we have DT
  // here, that assert will trigger.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  // This is an analysis: nothing is modified.
  AU.setPreservesAll();
}
1034 
1035 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1036   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1037   const TargetLibraryInfo &TLI =
1038       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1039   BPI.calculate(F, LI, &TLI);
1040   return false;
1041 }
1042 
1043 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1044 
1045 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1046                                              const Module *) const {
1047   BPI.print(OS);
1048 }
1049 
1050 AnalysisKey BranchProbabilityAnalysis::Key;
1051 BranchProbabilityInfo
1052 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1053   BranchProbabilityInfo BPI;
1054   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1055   return BPI;
1056 }
1057 
1058 PreservedAnalyses
1059 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1060   OS << "Printing analysis results of BPI for function "
1061      << "'" << F.getName() << "':"
1062      << "\n";
1063   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1064   return PreservedAnalyses::all();
1065 }
1066