1 //===--- SelectOptimize.cpp - Convert select to branches if profitable ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass converts selects to conditional jumps when profitable.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/ADT/SmallVector.h"
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/BlockFrequencyInfo.h"
16 #include "llvm/Analysis/BranchProbabilityInfo.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/CodeGen/TargetSchedule.h"
22 #include "llvm/CodeGen/TargetSubtargetInfo.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Target/TargetMachine.h"
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "select-optimize"
34 
35 STATISTIC(NumSelectsConverted, "Number of selects converted");
36 
37 namespace {
38 
39 class SelectOptimize : public FunctionPass {
40   const TargetMachine *TM = nullptr;
41   const TargetSubtargetInfo *TSI;
42   const TargetLowering *TLI = nullptr;
43   const LoopInfo *LI;
44   std::unique_ptr<BlockFrequencyInfo> BFI;
45   std::unique_ptr<BranchProbabilityInfo> BPI;
46 
47 public:
48   static char ID;
49   SelectOptimize() : FunctionPass(ID) {
50     initializeSelectOptimizePass(*PassRegistry::getPassRegistry());
51   }
52 
53   bool runOnFunction(Function &F) override;
54 
55   void getAnalysisUsage(AnalysisUsage &AU) const override {
56     AU.addRequired<TargetPassConfig>();
57     AU.addRequired<LoopInfoWrapperPass>();
58   }
59 
60 private:
61   // Select groups consist of consecutive select instructions with the same
62   // condition.
63   using SelectGroup = SmallVector<SelectInst *, 2>;
64   using SelectGroups = SmallVector<SelectGroup, 2>;
65 
66   bool optimizeSelects(Function &F);
67   void convertProfitableSIGroups(SelectGroups &ProfSIGroups);
68   void collectSelectGroups(BasicBlock &BB, SelectGroups &SIGroups);
69   bool isSelectKindSupported(SelectInst *SI);
70 };
71 } // namespace
72 
// Definition of the static pass identifier declared in the class; it is the
// value handed to the FunctionPass constructor.
char SelectOptimize::ID = 0;

// Pass registration. The pass is registered under the DEBUG_TYPE string
// ("select-optimize") with the description "Optimize selects". The dependency
// list mirrors the analyses requested in getAnalysisUsage.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG-only" and "is-analysis" flags of the INITIALIZE_PASS macros -- confirm
// against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                    false)
81 
82 FunctionPass *llvm::createSelectOptimizePass() { return new SelectOptimize(); }
83 
84 bool SelectOptimize::runOnFunction(Function &F) {
85   TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
86   TSI = TM->getSubtargetImpl(F);
87   TLI = TSI->getTargetLowering();
88   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
89   BPI.reset(new BranchProbabilityInfo(F, *LI));
90   BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI));
91 
92   return optimizeSelects(F);
93 }
94 
95 bool SelectOptimize::optimizeSelects(Function &F) {
96   // Collect all the select groups.
97   SelectGroups SIGroups;
98   for (BasicBlock &BB : F) {
99     collectSelectGroups(BB, SIGroups);
100   }
101 
102   // Determine for which select groups it is profitable converting to branches.
103   SelectGroups ProfSIGroups;
104   // For now assume that all select groups can be profitably converted to
105   // branches.
106   for (SelectGroup &ASI : SIGroups) {
107     ProfSIGroups.push_back(ASI);
108   }
109 
110   // Convert to branches the select groups that were deemed
111   // profitable-to-convert.
112   convertProfitableSIGroups(ProfSIGroups);
113 
114   // Code modified if at least one select group was converted.
115   return !ProfSIGroups.empty();
116 }
117 
118 /// If \p isTrue is true, return the true value of \p SI, otherwise return
119 /// false value of \p SI. If the true/false value of \p SI is defined by any
120 /// select instructions in \p Selects, look through the defining select
121 /// instruction until the true/false value is not defined in \p Selects.
122 static Value *
123 getTrueOrFalseValue(SelectInst *SI, bool isTrue,
124                     const SmallPtrSet<const Instruction *, 2> &Selects) {
125   Value *V = nullptr;
126   for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
127        DefSI = dyn_cast<SelectInst>(V)) {
128     assert(DefSI->getCondition() == SI->getCondition() &&
129            "The condition of DefSI does not match with SI");
130     V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
131   }
132   assert(V && "Failed to get select true/false value");
133   return V;
134 }
135 
136 void SelectOptimize::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
137   for (SelectGroup &ASI : ProfSIGroups) {
138     // TODO: eliminate the redundancy of logic transforming selects to branches
139     // by removing CodeGenPrepare::optimizeSelectInst and optimizing here
140     // selects for all cases (with and without profile information).
141 
142     // Transform a sequence like this:
143     //    start:
144     //       %cmp = cmp uge i32 %a, %b
145     //       %sel = select i1 %cmp, i32 %c, i32 %d
146     //
147     // Into:
148     //    start:
149     //       %cmp = cmp uge i32 %a, %b
150     //       %cmp.frozen = freeze %cmp
151     //       br i1 %cmp.frozen, label %select.end, label %select.false
152     //    select.false:
153     //       br label %select.end
154     //    select.end:
155     //       %sel = phi i32 [ %c, %start ], [ %d, %select.false ]
156     //
157     // %cmp should be frozen, otherwise it may introduce undefined behavior.
158 
159     // We split the block containing the select(s) into two blocks.
160     SelectInst *SI = ASI.front();
161     SelectInst *LastSI = ASI.back();
162     BasicBlock *StartBlock = SI->getParent();
163     BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
164     BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
165     BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency());
166     // Delete the unconditional branch that was just created by the split.
167     StartBlock->getTerminator()->eraseFromParent();
168 
169     // Move any debug/pseudo instructions that were in-between the select
170     // group to the newly-created end block.
171     SmallVector<Instruction *, 2> DebugPseudoINS;
172     auto DIt = SI->getIterator();
173     while (&*DIt != LastSI) {
174       if (DIt->isDebugOrPseudoInst())
175         DebugPseudoINS.push_back(&*DIt);
176       DIt++;
177     }
178     for (auto DI : DebugPseudoINS) {
179       DI->moveBefore(&*EndBlock->getFirstInsertionPt());
180     }
181 
182     // These are the new basic blocks for the conditional branch.
183     // For now, no instruction sinking to the true/false blocks.
184     // Thus both True and False blocks will be empty.
185     BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr;
186 
187     // Use the 'false' side for a new input value to the PHI.
188     FalseBlock = BasicBlock::Create(SI->getContext(), "select.false",
189                                     EndBlock->getParent(), EndBlock);
190     auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
191     FalseBranch->setDebugLoc(SI->getDebugLoc());
192 
193     // For the 'true' side the path originates from the start block from the
194     // point view of the new PHI.
195     TrueBlock = StartBlock;
196 
197     // Insert the real conditional branch based on the original condition.
198     BasicBlock *TT, *FT;
199     TT = EndBlock;
200     FT = FalseBlock;
201     IRBuilder<> IB(SI);
202     auto *CondFr =
203         IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen");
204     IB.CreateCondBr(CondFr, TT, FT, SI);
205 
206     SmallPtrSet<const Instruction *, 2> INS;
207     INS.insert(ASI.begin(), ASI.end());
208     // Use reverse iterator because later select may use the value of the
209     // earlier select, and we need to propagate value through earlier select
210     // to get the PHI operand.
211     for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
212       SelectInst *SI = *It;
213       // The select itself is replaced with a PHI Node.
214       PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
215       PN->takeName(SI);
216       PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
217       PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
218       PN->setDebugLoc(SI->getDebugLoc());
219 
220       SI->replaceAllUsesWith(PN);
221       SI->eraseFromParent();
222       INS.erase(SI);
223       ++NumSelectsConverted;
224     }
225   }
226 }
227 
228 void SelectOptimize::collectSelectGroups(BasicBlock &BB,
229                                          SelectGroups &SIGroups) {
230   BasicBlock::iterator BBIt = BB.begin();
231   while (BBIt != BB.end()) {
232     Instruction *I = &*BBIt++;
233     if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
234       SelectGroup SIGroup;
235       SIGroup.push_back(SI);
236       while (BBIt != BB.end()) {
237         Instruction *NI = &*BBIt;
238         SelectInst *NSI = dyn_cast<SelectInst>(NI);
239         if (NSI && SI->getCondition() == NSI->getCondition()) {
240           SIGroup.push_back(NSI);
241         } else if (!NI->isDebugOrPseudoInst()) {
242           // Debug/pseudo instructions should be skipped and not prevent the
243           // formation of a select group.
244           break;
245         }
246         ++BBIt;
247       }
248 
249       // If the select type is not supported, no point optimizing it.
250       // Instruction selection will take care of it.
251       if (!isSelectKindSupported(SI))
252         continue;
253 
254       SIGroups.push_back(SIGroup);
255     }
256   }
257 }
258 
259 bool SelectOptimize::isSelectKindSupported(SelectInst *SI) {
260   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
261   if (VectorCond)
262     return false;
263   TargetLowering::SelectSupportKind SelectKind;
264   if (SI->getType()->isVectorTy())
265     SelectKind = TargetLowering::ScalarCondVectorVal;
266   else
267     SelectKind = TargetLowering::ScalarValSelect;
268   return TLI->isSelectSupported(SelectKind);
269 }
270