1 //===--- SelectOptimize.cpp - Convert select to branches if profitable ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass converts selects to conditional jumps when profitable.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/ADT/SmallVector.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/BlockFrequencyInfo.h"
16 #include "llvm/Analysis/BranchProbabilityInfo.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
19 #include "llvm/Analysis/ProfileSummaryInfo.h"
20 #include "llvm/Analysis/TargetTransformInfo.h"
21 #include "llvm/CodeGen/Passes.h"
22 #include "llvm/CodeGen/TargetLowering.h"
23 #include "llvm/CodeGen/TargetPassConfig.h"
24 #include "llvm/CodeGen/TargetSchedule.h"
25 #include "llvm/CodeGen/TargetSubtargetInfo.h"
26 #include "llvm/IR/BasicBlock.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instruction.h"
31 #include "llvm/InitializePasses.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Target/TargetMachine.h"
34 #include "llvm/Transforms/Utils/SizeOpts.h"
35 #include <algorithm>
36 #include <memory>
37 #include <queue>
38 #include <stack>
39 #include <string>
40 
41 using namespace llvm;
42 
43 #define DEBUG_TYPE "select-optimize"
44 
// Pass statistics. All counters below are kept per select *group* (a run of
// selects sharing one condition), except NumSelectsConverted, which is bumped
// once for every individual select instruction that gets rewritten.
STATISTIC(NumSelectOptAnalyzed,
          "Number of select groups considered for conversion to branch");
STATISTIC(NumSelectConvertedExpColdOperand,
          "Number of select groups converted due to expensive cold operand");
STATISTIC(NumSelectConvertedHighPred,
          "Number of select groups converted due to high-predictability");
STATISTIC(NumSelectUnPred,
          "Number of select groups not converted due to unpredictability");
STATISTIC(NumSelectColdBB,
          "Number of select groups not converted due to cold basic block");
STATISTIC(NumSelectsConverted, "Number of selects converted");
56 
// Percentage (0-100) consumed by hasExpensiveColdOperand(): the less likely
// side of a select is treated as cold when its share of the total branch
// weight is strictly below this value.
static cl::opt<unsigned> ColdOperandThreshold(
    "cold-operand-threshold",
    cl::desc("Maximum frequency of path for an operand to be considered cold."),
    cl::init(20), cl::Hidden);

// Scales TargetTransformInfo::TCC_Expensive in hasExpensiveColdOperand(): a
// cold operand whose weight-adjusted slice cost reaches
// (multiplier * TCC_Expensive) is considered expensive.
static cl::opt<unsigned> ColdOperandMaxCostMultiplier(
    "cold-operand-max-cost-multiplier",
    cl::desc("Maximum cost multiplier of TCC_expensive for the dependence "
             "slice of a cold operand to be considered inexpensive."),
    cl::init(1), cl::Hidden);
67 
68 namespace {
69 
70 class SelectOptimize : public FunctionPass {
71   const TargetMachine *TM = nullptr;
72   const TargetSubtargetInfo *TSI;
73   const TargetLowering *TLI = nullptr;
74   const TargetTransformInfo *TTI = nullptr;
75   const LoopInfo *LI;
76   DominatorTree *DT;
77   std::unique_ptr<BlockFrequencyInfo> BFI;
78   std::unique_ptr<BranchProbabilityInfo> BPI;
79   ProfileSummaryInfo *PSI;
80   OptimizationRemarkEmitter *ORE;
81 
82 public:
83   static char ID;
84 
85   SelectOptimize() : FunctionPass(ID) {
86     initializeSelectOptimizePass(*PassRegistry::getPassRegistry());
87   }
88 
89   bool runOnFunction(Function &F) override;
90 
91   void getAnalysisUsage(AnalysisUsage &AU) const override {
92     AU.addRequired<ProfileSummaryInfoWrapperPass>();
93     AU.addRequired<TargetPassConfig>();
94     AU.addRequired<TargetTransformInfoWrapperPass>();
95     AU.addRequired<DominatorTreeWrapperPass>();
96     AU.addRequired<LoopInfoWrapperPass>();
97     AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
98   }
99 
100 private:
101   // Select groups consist of consecutive select instructions with the same
102   // condition.
103   using SelectGroup = SmallVector<SelectInst *, 2>;
104   using SelectGroups = SmallVector<SelectGroup, 2>;
105 
106   // Converts select instructions of a function to conditional jumps when deemed
107   // profitable. Returns true if at least one select was converted.
108   bool optimizeSelects(Function &F);
109 
110   // Heuristics for determining which select instructions can be profitably
111   // conveted to branches. Separate heuristics for selects in inner-most loops
112   // and the rest of code regions (base heuristics for non-inner-most loop
113   // regions).
114   void optimizeSelectsBase(Function &F, SelectGroups &ProfSIGroups);
115   void optimizeSelectsInnerLoops(Function &F, SelectGroups &ProfSIGroups);
116 
117   // Converts to branches the select groups that were deemed
118   // profitable-to-convert.
119   void convertProfitableSIGroups(SelectGroups &ProfSIGroups);
120 
121   // Splits selects of a given basic block into select groups.
122   void collectSelectGroups(BasicBlock &BB, SelectGroups &SIGroups);
123 
124   // Determines for which select groups it is profitable converting to branches
125   // (base heuristics).
126   void findProfitableSIGroupsBase(SelectGroups &SIGroups,
127                                   SelectGroups &ProfSIGroups);
128   // Determines if a select group should be converted to a branch (base
129   // heuristics).
130   bool isConvertToBranchProfitableBase(const SmallVector<SelectInst *, 2> &ASI);
131 
132   // Returns true if there are expensive instructions in the cold value
133   // operand's (if any) dependence slice of any of the selects of the given
134   // group.
135   bool hasExpensiveColdOperand(const SmallVector<SelectInst *, 2> &ASI);
136 
137   // For a given source instruction, collect its backwards dependence slice
138   // consisting of instructions exclusively computed for producing the operands
139   // of the source instruction.
140   void getExclBackwardsSlice(Instruction *I,
141                              SmallVector<Instruction *, 2> &Slice);
142 
143   // Returns true if the condition of the select is highly predictable.
144   bool isSelectHighlyPredictable(const SelectInst *SI);
145 
146   // Returns true if the target architecture supports lowering a given select.
147   bool isSelectKindSupported(SelectInst *SI);
148 };
149 } // namespace
150 
// Static pass identifier; it is handed to the FunctionPass base-class
// constructor by SelectOptimize().
char SelectOptimize::ID = 0;

// Pass registration. The dependency list mirrors the analyses requested in
// SelectOptimize::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                    false)
163 
164 FunctionPass *llvm::createSelectOptimizePass() { return new SelectOptimize(); }
165 
166 bool SelectOptimize::runOnFunction(Function &F) {
167   TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
168   TSI = TM->getSubtargetImpl(F);
169   TLI = TSI->getTargetLowering();
170 
171   // If none of the select types is supported then skip this pass.
172   // This is an optimization pass. Legality issues will be handled by
173   // instruction selection.
174   if (!TLI->isSelectSupported(TargetLowering::ScalarValSelect) &&
175       !TLI->isSelectSupported(TargetLowering::ScalarCondVectorVal) &&
176       !TLI->isSelectSupported(TargetLowering::VectorMaskSelect))
177     return false;
178 
179   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
180   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
181   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
182   BPI.reset(new BranchProbabilityInfo(F, *LI));
183   BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI));
184   PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
185   ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
186 
187   // When optimizing for size, selects are preferable over branches.
188   if (F.hasOptSize() || llvm::shouldOptimizeForSize(&F, PSI, BFI.get()))
189     return false;
190 
191   return optimizeSelects(F);
192 }
193 
194 bool SelectOptimize::optimizeSelects(Function &F) {
195   // Determine for which select groups it is profitable converting to branches.
196   SelectGroups ProfSIGroups;
197   // Base heuristics apply only to non-loops and outer loops.
198   optimizeSelectsBase(F, ProfSIGroups);
199   // Separate heuristics for inner-most loops.
200   optimizeSelectsInnerLoops(F, ProfSIGroups);
201 
202   // Convert to branches the select groups that were deemed
203   // profitable-to-convert.
204   convertProfitableSIGroups(ProfSIGroups);
205 
206   // Code modified if at least one select group was converted.
207   return !ProfSIGroups.empty();
208 }
209 
210 void SelectOptimize::optimizeSelectsBase(Function &F,
211                                          SelectGroups &ProfSIGroups) {
212   // Collect all the select groups.
213   SelectGroups SIGroups;
214   for (BasicBlock &BB : F) {
215     // Base heuristics apply only to non-loops and outer loops.
216     Loop *L = LI->getLoopFor(&BB);
217     if (L && L->isInnermost())
218       continue;
219     collectSelectGroups(BB, SIGroups);
220   }
221 
222   // Determine for which select groups it is profitable converting to branches.
223   findProfitableSIGroupsBase(SIGroups, ProfSIGroups);
224 }
225 
226 void SelectOptimize::optimizeSelectsInnerLoops(Function &F,
227                                                SelectGroups &ProfSIGroups) {}
228 
229 /// If \p isTrue is true, return the true value of \p SI, otherwise return
230 /// false value of \p SI. If the true/false value of \p SI is defined by any
231 /// select instructions in \p Selects, look through the defining select
232 /// instruction until the true/false value is not defined in \p Selects.
233 static Value *
234 getTrueOrFalseValue(SelectInst *SI, bool isTrue,
235                     const SmallPtrSet<const Instruction *, 2> &Selects) {
236   Value *V = nullptr;
237   for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
238        DefSI = dyn_cast<SelectInst>(V)) {
239     assert(DefSI->getCondition() == SI->getCondition() &&
240            "The condition of DefSI does not match with SI");
241     V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
242   }
243   assert(V && "Failed to get select true/false value");
244   return V;
245 }
246 
/// Rewrites every select group in \p ProfSIGroups into a conditional branch
/// plus PHI nodes. For each group the containing block is split right after
/// the last select, an empty "select.false" block is inserted on the false
/// path, and each select is replaced by a PHI in the "select.end" block.
void SelectOptimize::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
  for (SelectGroup &ASI : ProfSIGroups) {
    // TODO: eliminate the redundancy of logic transforming selects to branches
    // by removing CodeGenPrepare::optimizeSelectInst and optimizing here
    // selects for all cases (with and without profile information).

    // Transform a sequence like this:
    //    start:
    //       %cmp = cmp uge i32 %a, %b
    //       %sel = select i1 %cmp, i32 %c, i32 %d
    //
    // Into:
    //    start:
    //       %cmp = cmp uge i32 %a, %b
    //       %cmp.frozen = freeze %cmp
    //       br i1 %cmp.frozen, label %select.end, label %select.false
    //    select.false:
    //       br label %select.end
    //    select.end:
    //       %sel = phi i32 [ %c, %start ], [ %d, %select.false ]
    //
    // %cmp should be frozen, otherwise it may introduce undefined behavior.
    // (Note: the code below derives the freeze's name from the first select of
    // the group, i.e. "<select name>.frozen".)

    // We split the block containing the select(s) into two blocks.
    SelectInst *SI = ASI.front();
    SelectInst *LastSI = ASI.back();
    BasicBlock *StartBlock = SI->getParent();
    // Split right after the last select of the group.
    BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
    BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
    // The end block is given the same frequency as the start block.
    BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency());
    // Delete the unconditional branch that was just created by the split.
    StartBlock->getTerminator()->eraseFromParent();

    // Move any debug/pseudo instructions that were in-between the select
    // group to the newly-created end block.
    // They are collected first and moved afterwards so that the iteration
    // over the start block is not disturbed by the moves.
    SmallVector<Instruction *, 2> DebugPseudoINS;
    auto DIt = SI->getIterator();
    while (&*DIt != LastSI) {
      if (DIt->isDebugOrPseudoInst())
        DebugPseudoINS.push_back(&*DIt);
      DIt++;
    }
    for (auto DI : DebugPseudoINS) {
      DI->moveBefore(&*EndBlock->getFirstInsertionPt());
    }

    // These are the new basic blocks for the conditional branch.
    // For now, no instruction sinking to the true/false blocks.
    // Thus both True and False blocks will be empty.
    BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr;

    // Use the 'false' side for a new input value to the PHI.
    FalseBlock = BasicBlock::Create(SI->getContext(), "select.false",
                                    EndBlock->getParent(), EndBlock);
    auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
    FalseBranch->setDebugLoc(SI->getDebugLoc());

    // For the 'true' side the path originates from the start block from the
    // point of view of the new PHI.
    TrueBlock = StartBlock;

    // Insert the real conditional branch based on the original condition.
    // The builder is positioned at the first select, so the freeze and the
    // branch are emitted in front of the group's selects.
    BasicBlock *TT, *FT;
    TT = EndBlock;
    FT = FalseBlock;
    IRBuilder<> IB(SI);
    auto *CondFr =
        IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen");
    // NOTE(review): the trailing SI argument is presumably the metadata
    // source for the new branch (per the IRBuilder API) -- confirm.
    IB.CreateCondBr(CondFr, TT, FT, SI);

    // Set of selects of this group that have not been replaced yet; used to
    // look through selects that feed other selects of the same group.
    SmallPtrSet<const Instruction *, 2> INS;
    INS.insert(ASI.begin(), ASI.end());
    // Use reverse iterator because later select may use the value of the
    // earlier select, and we need to propagate value through earlier select
    // to get the PHI operand.
    for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
      // Note: this SI shadows the outer SI (the first select of the group).
      SelectInst *SI = *It;
      // The select itself is replaced with a PHI Node.
      PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
      PN->takeName(SI);
      PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
      PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
      PN->setDebugLoc(SI->getDebugLoc());

      SI->replaceAllUsesWith(PN);
      SI->eraseFromParent();
      INS.erase(SI);
      ++NumSelectsConverted;
    }
  }
}
338 
339 void SelectOptimize::collectSelectGroups(BasicBlock &BB,
340                                          SelectGroups &SIGroups) {
341   BasicBlock::iterator BBIt = BB.begin();
342   while (BBIt != BB.end()) {
343     Instruction *I = &*BBIt++;
344     if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
345       SelectGroup SIGroup;
346       SIGroup.push_back(SI);
347       while (BBIt != BB.end()) {
348         Instruction *NI = &*BBIt;
349         SelectInst *NSI = dyn_cast<SelectInst>(NI);
350         if (NSI && SI->getCondition() == NSI->getCondition()) {
351           SIGroup.push_back(NSI);
352         } else if (!NI->isDebugOrPseudoInst()) {
353           // Debug/pseudo instructions should be skipped and not prevent the
354           // formation of a select group.
355           break;
356         }
357         ++BBIt;
358       }
359 
360       // If the select type is not supported, no point optimizing it.
361       // Instruction selection will take care of it.
362       if (!isSelectKindSupported(SI))
363         continue;
364 
365       SIGroups.push_back(SIGroup);
366     }
367   }
368 }
369 
370 void SelectOptimize::findProfitableSIGroupsBase(SelectGroups &SIGroups,
371                                                 SelectGroups &ProfSIGroups) {
372   for (SelectGroup &ASI : SIGroups) {
373     ++NumSelectOptAnalyzed;
374     if (isConvertToBranchProfitableBase(ASI))
375       ProfSIGroups.push_back(ASI);
376   }
377 }
378 
379 bool SelectOptimize::isConvertToBranchProfitableBase(
380     const SmallVector<SelectInst *, 2> &ASI) {
381   SelectInst *SI = ASI.front();
382   OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", SI);
383   OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", SI);
384 
385   // Skip cold basic blocks. Better to optimize for size for cold blocks.
386   if (PSI->isColdBlock(SI->getParent(), BFI.get())) {
387     ++NumSelectColdBB;
388     ORmiss << "Not converted to branch because of cold basic block. ";
389     ORE->emit(ORmiss);
390     return false;
391   }
392 
393   // If unpredictable, branch form is less profitable.
394   if (SI->getMetadata(LLVMContext::MD_unpredictable)) {
395     ++NumSelectUnPred;
396     ORmiss << "Not converted to branch because of unpredictable branch. ";
397     ORE->emit(ORmiss);
398     return false;
399   }
400 
401   // If highly predictable, branch form is more profitable, unless a
402   // predictable select is inexpensive in the target architecture.
403   if (isSelectHighlyPredictable(SI) && TLI->isPredictableSelectExpensive()) {
404     ++NumSelectConvertedHighPred;
405     OR << "Converted to branch because of highly predictable branch. ";
406     ORE->emit(OR);
407     return true;
408   }
409 
410   // Look for expensive instructions in the cold operand's (if any) dependence
411   // slice of any of the selects in the group.
412   if (hasExpensiveColdOperand(ASI)) {
413     ++NumSelectConvertedExpColdOperand;
414     OR << "Converted to branch because of expensive cold operand.";
415     ORE->emit(OR);
416     return true;
417   }
418 
419   ORmiss << "Not profitable to convert to branch (base heuristic).";
420   ORE->emit(ORmiss);
421   return false;
422 }
423 
424 static InstructionCost divideNearest(InstructionCost Numerator,
425                                      uint64_t Denominator) {
426   return (Numerator + (Denominator / 2)) / Denominator;
427 }
428 
429 bool SelectOptimize::hasExpensiveColdOperand(
430     const SmallVector<SelectInst *, 2> &ASI) {
431   bool ColdOperand = false;
432   uint64_t TrueWeight, FalseWeight, TotalWeight;
433   if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) {
434     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
435     TotalWeight = TrueWeight + FalseWeight;
436     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
437     ColdOperand = TotalWeight * ColdOperandThreshold > 100 * MinWeight;
438   } else if (PSI->hasProfileSummary()) {
439     OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", ASI.front());
440     ORmiss << "Profile data available but missing branch-weights metadata for "
441               "select instruction. ";
442     ORE->emit(ORmiss);
443   }
444   if (!ColdOperand)
445     return false;
446   // Check if the cold path's dependence slice is expensive for any of the
447   // selects of the group.
448   for (SelectInst *SI : ASI) {
449     Instruction *ColdI = nullptr;
450     uint64_t HotWeight;
451     if (TrueWeight < FalseWeight) {
452       ColdI = dyn_cast<Instruction>(SI->getTrueValue());
453       HotWeight = FalseWeight;
454     } else {
455       ColdI = dyn_cast<Instruction>(SI->getFalseValue());
456       HotWeight = TrueWeight;
457     }
458     if (ColdI) {
459       SmallVector<Instruction *, 2> ColdSlice;
460       getExclBackwardsSlice(ColdI, ColdSlice);
461       InstructionCost SliceCost = 0;
462       for (auto *ColdII : ColdSlice) {
463         SliceCost +=
464             TTI->getInstructionCost(ColdII, TargetTransformInfo::TCK_Latency);
465       }
466       // The colder the cold value operand of the select is the more expensive
467       // the cmov becomes for computing the cold value operand every time. Thus,
468       // the colder the cold operand is the more its cost counts.
469       // Get nearest integer cost adjusted for coldness.
470       InstructionCost AdjSliceCost =
471           divideNearest(SliceCost * HotWeight, TotalWeight);
472       if (AdjSliceCost >=
473           ColdOperandMaxCostMultiplier * TargetTransformInfo::TCC_Expensive)
474         return true;
475     }
476   }
477   return false;
478 }
479 
480 // For a given source instruction, collect its backwards dependence slice
481 // consisting of instructions exclusively computed for the purpose of producing
482 // the operands of the source instruction. As an approximation
483 // (sufficiently-accurate in practice), we populate this set with the
484 // instructions of the backwards dependence slice that only have one-use and
485 // form an one-use chain that leads to the source instruction.
486 void SelectOptimize::getExclBackwardsSlice(
487     Instruction *I, SmallVector<Instruction *, 2> &Slice) {
488   SmallPtrSet<Instruction *, 2> Visited;
489   std::queue<Instruction *> Worklist;
490   Worklist.push(I);
491   while (!Worklist.empty()) {
492     Instruction *II = Worklist.front();
493     Worklist.pop();
494 
495     // Avoid cycles.
496     if (Visited.count(II))
497       continue;
498     Visited.insert(II);
499 
500     if (!II->hasOneUse())
501       continue;
502 
503     // Avoid considering instructions with less frequency than the source
504     // instruction (i.e., avoid colder code regions of the dependence slice).
505     if (BFI->getBlockFreq(II->getParent()) < BFI->getBlockFreq(I->getParent()))
506       continue;
507 
508     // Eligible one-use instruction added to the dependence slice.
509     Slice.push_back(II);
510 
511     // Explore all the operands of the current instruction to expand the slice.
512     for (unsigned k = 0; k < II->getNumOperands(); ++k)
513       if (auto *OpI = dyn_cast<Instruction>(II->getOperand(k)))
514         Worklist.push(OpI);
515   }
516 }
517 
518 bool SelectOptimize::isSelectHighlyPredictable(const SelectInst *SI) {
519   uint64_t TrueWeight, FalseWeight;
520   if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
521     uint64_t Max = std::max(TrueWeight, FalseWeight);
522     uint64_t Sum = TrueWeight + FalseWeight;
523     if (Sum != 0) {
524       auto Probability = BranchProbability::getBranchProbability(Max, Sum);
525       if (Probability > TTI->getPredictableBranchThreshold())
526         return true;
527     }
528   }
529   return false;
530 }
531 
532 bool SelectOptimize::isSelectKindSupported(SelectInst *SI) {
533   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
534   if (VectorCond)
535     return false;
536   TargetLowering::SelectSupportKind SelectKind;
537   if (SI->getType()->isVectorTy())
538     SelectKind = TargetLowering::ScalarCondVectorVal;
539   else
540     SelectKind = TargetLowering::ScalarValSelect;
541   return TLI->isSelectSupported(SelectKind);
542 }
543