1ca7c307dSSotiris Apostolakis //===--- SelectOptimize.cpp - Convert select to branches if profitable ---===//
2ca7c307dSSotiris Apostolakis //
3ca7c307dSSotiris Apostolakis // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ca7c307dSSotiris Apostolakis // See https://llvm.org/LICENSE.txt for license information.
5ca7c307dSSotiris Apostolakis // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ca7c307dSSotiris Apostolakis //
7ca7c307dSSotiris Apostolakis //===----------------------------------------------------------------------===//
8ca7c307dSSotiris Apostolakis //
9ca7c307dSSotiris Apostolakis // This pass converts selects to conditional jumps when profitable.
10ca7c307dSSotiris Apostolakis //
11ca7c307dSSotiris Apostolakis //===----------------------------------------------------------------------===//
12ca7c307dSSotiris Apostolakis 
13*97c3ef5cSSotiris Apostolakis #include "llvm/ADT/SmallVector.h"
14*97c3ef5cSSotiris Apostolakis #include "llvm/ADT/Statistic.h"
15*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/BlockFrequencyInfo.h"
16*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/BranchProbabilityInfo.h"
17*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/LoopInfo.h"
18ca7c307dSSotiris Apostolakis #include "llvm/CodeGen/Passes.h"
19*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetLowering.h"
20*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetPassConfig.h"
21*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetSchedule.h"
22*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetSubtargetInfo.h"
23*97c3ef5cSSotiris Apostolakis #include "llvm/IR/BasicBlock.h"
24ca7c307dSSotiris Apostolakis #include "llvm/IR/Function.h"
25*97c3ef5cSSotiris Apostolakis #include "llvm/IR/IRBuilder.h"
26*97c3ef5cSSotiris Apostolakis #include "llvm/IR/Instruction.h"
27ca7c307dSSotiris Apostolakis #include "llvm/InitializePasses.h"
28ca7c307dSSotiris Apostolakis #include "llvm/Pass.h"
29*97c3ef5cSSotiris Apostolakis #include "llvm/Target/TargetMachine.h"
30ca7c307dSSotiris Apostolakis 
31ca7c307dSSotiris Apostolakis using namespace llvm;
32ca7c307dSSotiris Apostolakis 
// Component name for this file's debug/statistics output; it is also passed
// as the pass argument to INITIALIZE_PASS_* below.
#define DEBUG_TYPE "select-optimize"

// Incremented once for every select instruction that convertProfitableSIGroups
// replaces with a PHI node.
STATISTIC(NumSelectsConverted, "Number of selects converted");
36*97c3ef5cSSotiris Apostolakis 
37ca7c307dSSotiris Apostolakis namespace {
38ca7c307dSSotiris Apostolakis 
39ca7c307dSSotiris Apostolakis class SelectOptimize : public FunctionPass {
40*97c3ef5cSSotiris Apostolakis   const TargetMachine *TM = nullptr;
41*97c3ef5cSSotiris Apostolakis   const TargetSubtargetInfo *TSI;
42*97c3ef5cSSotiris Apostolakis   const TargetLowering *TLI = nullptr;
43*97c3ef5cSSotiris Apostolakis   const LoopInfo *LI;
44*97c3ef5cSSotiris Apostolakis   std::unique_ptr<BlockFrequencyInfo> BFI;
45*97c3ef5cSSotiris Apostolakis   std::unique_ptr<BranchProbabilityInfo> BPI;
46*97c3ef5cSSotiris Apostolakis 
47ca7c307dSSotiris Apostolakis public:
48ca7c307dSSotiris Apostolakis   static char ID;
49ca7c307dSSotiris Apostolakis   SelectOptimize() : FunctionPass(ID) {
50ca7c307dSSotiris Apostolakis     initializeSelectOptimizePass(*PassRegistry::getPassRegistry());
51ca7c307dSSotiris Apostolakis   }
52ca7c307dSSotiris Apostolakis 
53ca7c307dSSotiris Apostolakis   bool runOnFunction(Function &F) override;
54ca7c307dSSotiris Apostolakis 
55*97c3ef5cSSotiris Apostolakis   void getAnalysisUsage(AnalysisUsage &AU) const override {
56*97c3ef5cSSotiris Apostolakis     AU.addRequired<TargetPassConfig>();
57*97c3ef5cSSotiris Apostolakis     AU.addRequired<LoopInfoWrapperPass>();
58*97c3ef5cSSotiris Apostolakis   }
59*97c3ef5cSSotiris Apostolakis 
60*97c3ef5cSSotiris Apostolakis private:
61*97c3ef5cSSotiris Apostolakis   // Select groups consist of consecutive select instructions with the same
62*97c3ef5cSSotiris Apostolakis   // condition.
63*97c3ef5cSSotiris Apostolakis   using SelectGroup = SmallVector<SelectInst *, 2>;
64*97c3ef5cSSotiris Apostolakis   using SelectGroups = SmallVector<SelectGroup, 2>;
65*97c3ef5cSSotiris Apostolakis 
66*97c3ef5cSSotiris Apostolakis   bool optimizeSelects(Function &F);
67*97c3ef5cSSotiris Apostolakis   void convertProfitableSIGroups(SelectGroups &ProfSIGroups);
68*97c3ef5cSSotiris Apostolakis   void collectSelectGroups(BasicBlock &BB, SelectGroups &SIGroups);
69*97c3ef5cSSotiris Apostolakis   bool isSelectKindSupported(SelectInst *SI);
70ca7c307dSSotiris Apostolakis };
71ca7c307dSSotiris Apostolakis } // namespace
72ca7c307dSSotiris Apostolakis 
// The address of ID, not its value, is what identifies this pass to the
// legacy pass manager.
char SelectOptimize::ID = 0;

// Register the pass under the DEBUG_TYPE argument and declare the analyses it
// depends on (the same ones requested in getAnalysisUsage), so they are
// initialized together with it. The trailing `false, false` arguments mark
// the pass as neither CFG-only nor an analysis pass.
INITIALIZE_PASS_BEGIN(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(SelectOptimize, DEBUG_TYPE, "Optimize selects", false,
                    false)
81ca7c307dSSotiris Apostolakis 
82ca7c307dSSotiris Apostolakis FunctionPass *llvm::createSelectOptimizePass() { return new SelectOptimize(); }
83ca7c307dSSotiris Apostolakis 
84ca7c307dSSotiris Apostolakis bool SelectOptimize::runOnFunction(Function &F) {
85*97c3ef5cSSotiris Apostolakis   TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
86*97c3ef5cSSotiris Apostolakis   TSI = TM->getSubtargetImpl(F);
87*97c3ef5cSSotiris Apostolakis   TLI = TSI->getTargetLowering();
88*97c3ef5cSSotiris Apostolakis   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
89*97c3ef5cSSotiris Apostolakis   BPI.reset(new BranchProbabilityInfo(F, *LI));
90*97c3ef5cSSotiris Apostolakis   BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI));
91*97c3ef5cSSotiris Apostolakis 
92*97c3ef5cSSotiris Apostolakis   return optimizeSelects(F);
93*97c3ef5cSSotiris Apostolakis }
94*97c3ef5cSSotiris Apostolakis 
95*97c3ef5cSSotiris Apostolakis bool SelectOptimize::optimizeSelects(Function &F) {
96*97c3ef5cSSotiris Apostolakis   // Collect all the select groups.
97*97c3ef5cSSotiris Apostolakis   SelectGroups SIGroups;
98*97c3ef5cSSotiris Apostolakis   for (BasicBlock &BB : F) {
99*97c3ef5cSSotiris Apostolakis     collectSelectGroups(BB, SIGroups);
100*97c3ef5cSSotiris Apostolakis   }
101*97c3ef5cSSotiris Apostolakis 
102*97c3ef5cSSotiris Apostolakis   // Determine for which select groups it is profitable converting to branches.
103*97c3ef5cSSotiris Apostolakis   SelectGroups ProfSIGroups;
104*97c3ef5cSSotiris Apostolakis   // For now assume that all select groups can be profitably converted to
105*97c3ef5cSSotiris Apostolakis   // branches.
106*97c3ef5cSSotiris Apostolakis   for (SelectGroup &ASI : SIGroups) {
107*97c3ef5cSSotiris Apostolakis     ProfSIGroups.push_back(ASI);
108*97c3ef5cSSotiris Apostolakis   }
109*97c3ef5cSSotiris Apostolakis 
110*97c3ef5cSSotiris Apostolakis   // Convert to branches the select groups that were deemed
111*97c3ef5cSSotiris Apostolakis   // profitable-to-convert.
112*97c3ef5cSSotiris Apostolakis   convertProfitableSIGroups(ProfSIGroups);
113*97c3ef5cSSotiris Apostolakis 
114*97c3ef5cSSotiris Apostolakis   // Code modified if at least one select group was converted.
115*97c3ef5cSSotiris Apostolakis   return !ProfSIGroups.empty();
116*97c3ef5cSSotiris Apostolakis }
117*97c3ef5cSSotiris Apostolakis 
118*97c3ef5cSSotiris Apostolakis /// If \p isTrue is true, return the true value of \p SI, otherwise return
119*97c3ef5cSSotiris Apostolakis /// false value of \p SI. If the true/false value of \p SI is defined by any
120*97c3ef5cSSotiris Apostolakis /// select instructions in \p Selects, look through the defining select
121*97c3ef5cSSotiris Apostolakis /// instruction until the true/false value is not defined in \p Selects.
122*97c3ef5cSSotiris Apostolakis static Value *
123*97c3ef5cSSotiris Apostolakis getTrueOrFalseValue(SelectInst *SI, bool isTrue,
124*97c3ef5cSSotiris Apostolakis                     const SmallPtrSet<const Instruction *, 2> &Selects) {
125*97c3ef5cSSotiris Apostolakis   Value *V = nullptr;
126*97c3ef5cSSotiris Apostolakis   for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
127*97c3ef5cSSotiris Apostolakis        DefSI = dyn_cast<SelectInst>(V)) {
128*97c3ef5cSSotiris Apostolakis     assert(DefSI->getCondition() == SI->getCondition() &&
129*97c3ef5cSSotiris Apostolakis            "The condition of DefSI does not match with SI");
130*97c3ef5cSSotiris Apostolakis     V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
131*97c3ef5cSSotiris Apostolakis   }
132*97c3ef5cSSotiris Apostolakis   assert(V && "Failed to get select true/false value");
133*97c3ef5cSSotiris Apostolakis   return V;
134*97c3ef5cSSotiris Apostolakis }
135*97c3ef5cSSotiris Apostolakis 
136*97c3ef5cSSotiris Apostolakis void SelectOptimize::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
137*97c3ef5cSSotiris Apostolakis   for (SelectGroup &ASI : ProfSIGroups) {
138*97c3ef5cSSotiris Apostolakis     // TODO: eliminate the redundancy of logic transforming selects to branches
139*97c3ef5cSSotiris Apostolakis     // by removing CodeGenPrepare::optimizeSelectInst and optimizing here
140*97c3ef5cSSotiris Apostolakis     // selects for all cases (with and without profile information).
141*97c3ef5cSSotiris Apostolakis 
142*97c3ef5cSSotiris Apostolakis     // Transform a sequence like this:
143*97c3ef5cSSotiris Apostolakis     //    start:
144*97c3ef5cSSotiris Apostolakis     //       %cmp = cmp uge i32 %a, %b
145*97c3ef5cSSotiris Apostolakis     //       %sel = select i1 %cmp, i32 %c, i32 %d
146*97c3ef5cSSotiris Apostolakis     //
147*97c3ef5cSSotiris Apostolakis     // Into:
148*97c3ef5cSSotiris Apostolakis     //    start:
149*97c3ef5cSSotiris Apostolakis     //       %cmp = cmp uge i32 %a, %b
150*97c3ef5cSSotiris Apostolakis     //       %cmp.frozen = freeze %cmp
151*97c3ef5cSSotiris Apostolakis     //       br i1 %cmp.frozen, label %select.end, label %select.false
152*97c3ef5cSSotiris Apostolakis     //    select.false:
153*97c3ef5cSSotiris Apostolakis     //       br label %select.end
154*97c3ef5cSSotiris Apostolakis     //    select.end:
155*97c3ef5cSSotiris Apostolakis     //       %sel = phi i32 [ %c, %start ], [ %d, %select.false ]
156*97c3ef5cSSotiris Apostolakis     //
157*97c3ef5cSSotiris Apostolakis     // %cmp should be frozen, otherwise it may introduce undefined behavior.
158*97c3ef5cSSotiris Apostolakis 
159*97c3ef5cSSotiris Apostolakis     // We split the block containing the select(s) into two blocks.
160*97c3ef5cSSotiris Apostolakis     SelectInst *SI = ASI.front();
161*97c3ef5cSSotiris Apostolakis     SelectInst *LastSI = ASI.back();
162*97c3ef5cSSotiris Apostolakis     BasicBlock *StartBlock = SI->getParent();
163*97c3ef5cSSotiris Apostolakis     BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
164*97c3ef5cSSotiris Apostolakis     BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
165*97c3ef5cSSotiris Apostolakis     BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency());
166*97c3ef5cSSotiris Apostolakis     // Delete the unconditional branch that was just created by the split.
167*97c3ef5cSSotiris Apostolakis     StartBlock->getTerminator()->eraseFromParent();
168*97c3ef5cSSotiris Apostolakis 
169*97c3ef5cSSotiris Apostolakis     // Move any debug/pseudo instructions that were in-between the select
170*97c3ef5cSSotiris Apostolakis     // group to the newly-created end block.
171*97c3ef5cSSotiris Apostolakis     SmallVector<Instruction *, 2> DebugPseudoINS;
172*97c3ef5cSSotiris Apostolakis     auto DIt = SI->getIterator();
173*97c3ef5cSSotiris Apostolakis     while (&*DIt != LastSI) {
174*97c3ef5cSSotiris Apostolakis       if (DIt->isDebugOrPseudoInst())
175*97c3ef5cSSotiris Apostolakis         DebugPseudoINS.push_back(&*DIt);
176*97c3ef5cSSotiris Apostolakis       DIt++;
177*97c3ef5cSSotiris Apostolakis     }
178*97c3ef5cSSotiris Apostolakis     for (auto DI : DebugPseudoINS) {
179*97c3ef5cSSotiris Apostolakis       DI->moveBefore(&*EndBlock->getFirstInsertionPt());
180*97c3ef5cSSotiris Apostolakis     }
181*97c3ef5cSSotiris Apostolakis 
182*97c3ef5cSSotiris Apostolakis     // These are the new basic blocks for the conditional branch.
183*97c3ef5cSSotiris Apostolakis     // For now, no instruction sinking to the true/false blocks.
184*97c3ef5cSSotiris Apostolakis     // Thus both True and False blocks will be empty.
185*97c3ef5cSSotiris Apostolakis     BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr;
186*97c3ef5cSSotiris Apostolakis 
187*97c3ef5cSSotiris Apostolakis     // Use the 'false' side for a new input value to the PHI.
188*97c3ef5cSSotiris Apostolakis     FalseBlock = BasicBlock::Create(SI->getContext(), "select.false",
189*97c3ef5cSSotiris Apostolakis                                     EndBlock->getParent(), EndBlock);
190*97c3ef5cSSotiris Apostolakis     auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
191*97c3ef5cSSotiris Apostolakis     FalseBranch->setDebugLoc(SI->getDebugLoc());
192*97c3ef5cSSotiris Apostolakis 
193*97c3ef5cSSotiris Apostolakis     // For the 'true' side the path originates from the start block from the
194*97c3ef5cSSotiris Apostolakis     // point view of the new PHI.
195*97c3ef5cSSotiris Apostolakis     TrueBlock = StartBlock;
196*97c3ef5cSSotiris Apostolakis 
197*97c3ef5cSSotiris Apostolakis     // Insert the real conditional branch based on the original condition.
198*97c3ef5cSSotiris Apostolakis     BasicBlock *TT, *FT;
199*97c3ef5cSSotiris Apostolakis     TT = EndBlock;
200*97c3ef5cSSotiris Apostolakis     FT = FalseBlock;
201*97c3ef5cSSotiris Apostolakis     IRBuilder<> IB(SI);
202*97c3ef5cSSotiris Apostolakis     auto *CondFr =
203*97c3ef5cSSotiris Apostolakis         IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen");
204*97c3ef5cSSotiris Apostolakis     IB.CreateCondBr(CondFr, TT, FT, SI);
205*97c3ef5cSSotiris Apostolakis 
206*97c3ef5cSSotiris Apostolakis     SmallPtrSet<const Instruction *, 2> INS;
207*97c3ef5cSSotiris Apostolakis     INS.insert(ASI.begin(), ASI.end());
208*97c3ef5cSSotiris Apostolakis     // Use reverse iterator because later select may use the value of the
209*97c3ef5cSSotiris Apostolakis     // earlier select, and we need to propagate value through earlier select
210*97c3ef5cSSotiris Apostolakis     // to get the PHI operand.
211*97c3ef5cSSotiris Apostolakis     for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
212*97c3ef5cSSotiris Apostolakis       SelectInst *SI = *It;
213*97c3ef5cSSotiris Apostolakis       // The select itself is replaced with a PHI Node.
214*97c3ef5cSSotiris Apostolakis       PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
215*97c3ef5cSSotiris Apostolakis       PN->takeName(SI);
216*97c3ef5cSSotiris Apostolakis       PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
217*97c3ef5cSSotiris Apostolakis       PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
218*97c3ef5cSSotiris Apostolakis       PN->setDebugLoc(SI->getDebugLoc());
219*97c3ef5cSSotiris Apostolakis 
220*97c3ef5cSSotiris Apostolakis       SI->replaceAllUsesWith(PN);
221*97c3ef5cSSotiris Apostolakis       SI->eraseFromParent();
222*97c3ef5cSSotiris Apostolakis       INS.erase(SI);
223*97c3ef5cSSotiris Apostolakis       ++NumSelectsConverted;
224*97c3ef5cSSotiris Apostolakis     }
225*97c3ef5cSSotiris Apostolakis   }
226*97c3ef5cSSotiris Apostolakis }
227*97c3ef5cSSotiris Apostolakis 
228*97c3ef5cSSotiris Apostolakis void SelectOptimize::collectSelectGroups(BasicBlock &BB,
229*97c3ef5cSSotiris Apostolakis                                          SelectGroups &SIGroups) {
230*97c3ef5cSSotiris Apostolakis   BasicBlock::iterator BBIt = BB.begin();
231*97c3ef5cSSotiris Apostolakis   while (BBIt != BB.end()) {
232*97c3ef5cSSotiris Apostolakis     Instruction *I = &*BBIt++;
233*97c3ef5cSSotiris Apostolakis     if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
234*97c3ef5cSSotiris Apostolakis       SelectGroup SIGroup;
235*97c3ef5cSSotiris Apostolakis       SIGroup.push_back(SI);
236*97c3ef5cSSotiris Apostolakis       while (BBIt != BB.end()) {
237*97c3ef5cSSotiris Apostolakis         Instruction *NI = &*BBIt;
238*97c3ef5cSSotiris Apostolakis         SelectInst *NSI = dyn_cast<SelectInst>(NI);
239*97c3ef5cSSotiris Apostolakis         if (NSI && SI->getCondition() == NSI->getCondition()) {
240*97c3ef5cSSotiris Apostolakis           SIGroup.push_back(NSI);
241*97c3ef5cSSotiris Apostolakis         } else if (!NI->isDebugOrPseudoInst()) {
242*97c3ef5cSSotiris Apostolakis           // Debug/pseudo instructions should be skipped and not prevent the
243*97c3ef5cSSotiris Apostolakis           // formation of a select group.
244*97c3ef5cSSotiris Apostolakis           break;
245*97c3ef5cSSotiris Apostolakis         }
246*97c3ef5cSSotiris Apostolakis         ++BBIt;
247*97c3ef5cSSotiris Apostolakis       }
248*97c3ef5cSSotiris Apostolakis 
249*97c3ef5cSSotiris Apostolakis       // If the select type is not supported, no point optimizing it.
250*97c3ef5cSSotiris Apostolakis       // Instruction selection will take care of it.
251*97c3ef5cSSotiris Apostolakis       if (!isSelectKindSupported(SI))
252*97c3ef5cSSotiris Apostolakis         continue;
253*97c3ef5cSSotiris Apostolakis 
254*97c3ef5cSSotiris Apostolakis       SIGroups.push_back(SIGroup);
255*97c3ef5cSSotiris Apostolakis     }
256*97c3ef5cSSotiris Apostolakis   }
257*97c3ef5cSSotiris Apostolakis }
258*97c3ef5cSSotiris Apostolakis 
259*97c3ef5cSSotiris Apostolakis bool SelectOptimize::isSelectKindSupported(SelectInst *SI) {
260*97c3ef5cSSotiris Apostolakis   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
261*97c3ef5cSSotiris Apostolakis   if (VectorCond)
262*97c3ef5cSSotiris Apostolakis     return false;
263*97c3ef5cSSotiris Apostolakis   TargetLowering::SelectSupportKind SelectKind;
264*97c3ef5cSSotiris Apostolakis   if (SI->getType()->isVectorTy())
265*97c3ef5cSSotiris Apostolakis     SelectKind = TargetLowering::ScalarCondVectorVal;
266*97c3ef5cSSotiris Apostolakis   else
267*97c3ef5cSSotiris Apostolakis     SelectKind = TargetLowering::ScalarValSelect;
268*97c3ef5cSSotiris Apostolakis   return TLI->isSelectSupported(SelectKind);
269ca7c307dSSotiris Apostolakis }
270