1ca7c307dSSotiris Apostolakis //===--- SelectOptimize.cpp - Convert select to branches if profitable ---===// 2ca7c307dSSotiris Apostolakis // 3ca7c307dSSotiris Apostolakis // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ca7c307dSSotiris Apostolakis // See https://llvm.org/LICENSE.txt for license information. 5ca7c307dSSotiris Apostolakis // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ca7c307dSSotiris Apostolakis // 7ca7c307dSSotiris Apostolakis //===----------------------------------------------------------------------===// 8ca7c307dSSotiris Apostolakis // 9ca7c307dSSotiris Apostolakis // This pass converts selects to conditional jumps when profitable. 10ca7c307dSSotiris Apostolakis // 11ca7c307dSSotiris Apostolakis //===----------------------------------------------------------------------===// 12ca7c307dSSotiris Apostolakis 13*97c3ef5cSSotiris Apostolakis #include "llvm/ADT/SmallVector.h" 14*97c3ef5cSSotiris Apostolakis #include "llvm/ADT/Statistic.h" 15*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/BlockFrequencyInfo.h" 16*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/BranchProbabilityInfo.h" 17*97c3ef5cSSotiris Apostolakis #include "llvm/Analysis/LoopInfo.h" 18ca7c307dSSotiris Apostolakis #include "llvm/CodeGen/Passes.h" 19*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetLowering.h" 20*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetPassConfig.h" 21*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetSchedule.h" 22*97c3ef5cSSotiris Apostolakis #include "llvm/CodeGen/TargetSubtargetInfo.h" 23*97c3ef5cSSotiris Apostolakis #include "llvm/IR/BasicBlock.h" 24ca7c307dSSotiris Apostolakis #include "llvm/IR/Function.h" 25*97c3ef5cSSotiris Apostolakis #include "llvm/IR/IRBuilder.h" 26*97c3ef5cSSotiris Apostolakis #include "llvm/IR/Instruction.h" 27ca7c307dSSotiris Apostolakis #include "llvm/InitializePasses.h" 28ca7c307dSSotiris Apostolakis #include "llvm/Pass.h" 29*97c3ef5cSSotiris Apostolakis #include "llvm/Target/TargetMachine.h" 30ca7c307dSSotiris Apostolakis 31ca7c307dSSotiris Apostolakis using namespace llvm; 32ca7c307dSSotiris Apostolakis 33*97c3ef5cSSotiris Apostolakis #define DEBUG_TYPE "select-optimize" 34*97c3ef5cSSotiris Apostolakis 35*97c3ef5cSSotiris Apostolakis STATISTIC(NumSelectsConverted, "Number of selects converted"); 36*97c3ef5cSSotiris Apostolakis 37ca7c307dSSotiris Apostolakis namespace { 38ca7c307dSSotiris Apostolakis 39ca7c307dSSotiris Apostolakis class SelectOptimize : public FunctionPass { 40*97c3ef5cSSotiris Apostolakis const TargetMachine *TM = nullptr; 41*97c3ef5cSSotiris Apostolakis const TargetSubtargetInfo *TSI; 42*97c3ef5cSSotiris Apostolakis const TargetLowering *TLI = nullptr; 43*97c3ef5cSSotiris Apostolakis const LoopInfo *LI; 44*97c3ef5cSSotiris Apostolakis std::unique_ptr<BlockFrequencyInfo> BFI; 45*97c3ef5cSSotiris Apostolakis std::unique_ptr<BranchProbabilityInfo> BPI; 46*97c3ef5cSSotiris Apostolakis 47ca7c307dSSotiris Apostolakis public: 48ca7c307dSSotiris Apostolakis static char ID; 49ca7c307dSSotiris Apostolakis SelectOptimize() : FunctionPass(ID) { 50ca7c307dSSotiris Apostolakis initializeSelectOptimizePass(*PassRegistry::getPassRegistry()); 51ca7c307dSSotiris Apostolakis } 52ca7c307dSSotiris Apostolakis 53ca7c307dSSotiris Apostolakis bool runOnFunction(Function &F) override; 54ca7c307dSSotiris Apostolakis 55*97c3ef5cSSotiris Apostolakis void getAnalysisUsage(AnalysisUsage &AU) const override { 56*97c3ef5cSSotiris Apostolakis AU.addRequired<TargetPassConfig>(); 57*97c3ef5cSSotiris Apostolakis AU.addRequired<LoopInfoWrapperPass>(); 58*97c3ef5cSSotiris Apostolakis } 59*97c3ef5cSSotiris Apostolakis 60*97c3ef5cSSotiris Apostolakis private: 61*97c3ef5cSSotiris Apostolakis // Select groups consist of consecutive select instructions with the same 62*97c3ef5cSSotiris Apostolakis // condition. 63*97c3ef5cSSotiris Apostolakis using SelectGroup = SmallVector<SelectInst *, 2>; 64*97c3ef5cSSotiris Apostolakis using SelectGroups = SmallVector<SelectGroup, 2>; 65*97c3ef5cSSotiris Apostolakis 66*97c3ef5cSSotiris Apostolakis bool optimizeSelects(Function &F); 67*97c3ef5cSSotiris Apostolakis void convertProfitableSIGroups(SelectGroups &ProfSIGroups); 68*97c3ef5cSSotiris Apostolakis void collectSelectGroups(BasicBlock &BB, SelectGroups &SIGroups); 69*97c3ef5cSSotiris Apostolakis bool isSelectKindSupported(SelectInst *SI); 70ca7c307dSSotiris Apostolakis }; 71ca7c307dSSotiris Apostolakis } // namespace 72ca7c307dSSotiris Apostolakis 73ca7c307dSSotiris Apostolakis char SelectOptimize::ID = 0; 74*97c3ef5cSSotiris Apostolakis 75*97c3ef5cSSotiris Apostolakis INITIALIZE_PASS_BEGIN(SelectOptimize, DEBUG_TYPE, "Optimize selects", false, 76*97c3ef5cSSotiris Apostolakis false) 77*97c3ef5cSSotiris Apostolakis INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 78*97c3ef5cSSotiris Apostolakis INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 79*97c3ef5cSSotiris Apostolakis INITIALIZE_PASS_END(SelectOptimize, DEBUG_TYPE, "Optimize selects", false, 80ca7c307dSSotiris Apostolakis false) 81ca7c307dSSotiris Apostolakis 82ca7c307dSSotiris Apostolakis FunctionPass *llvm::createSelectOptimizePass() { return new SelectOptimize(); } 83ca7c307dSSotiris Apostolakis 84ca7c307dSSotiris Apostolakis bool SelectOptimize::runOnFunction(Function &F) { 85*97c3ef5cSSotiris Apostolakis TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 86*97c3ef5cSSotiris Apostolakis TSI = TM->getSubtargetImpl(F); 87*97c3ef5cSSotiris Apostolakis TLI = TSI->getTargetLowering(); 88*97c3ef5cSSotiris Apostolakis LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 89*97c3ef5cSSotiris Apostolakis BPI.reset(new BranchProbabilityInfo(F, *LI)); 90*97c3ef5cSSotiris Apostolakis BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI)); 91*97c3ef5cSSotiris Apostolakis 92*97c3ef5cSSotiris Apostolakis return optimizeSelects(F); 93*97c3ef5cSSotiris Apostolakis } 94*97c3ef5cSSotiris Apostolakis 95*97c3ef5cSSotiris Apostolakis bool SelectOptimize::optimizeSelects(Function &F) { 96*97c3ef5cSSotiris Apostolakis // Collect all the select groups. 97*97c3ef5cSSotiris Apostolakis SelectGroups SIGroups; 98*97c3ef5cSSotiris Apostolakis for (BasicBlock &BB : F) { 99*97c3ef5cSSotiris Apostolakis collectSelectGroups(BB, SIGroups); 100*97c3ef5cSSotiris Apostolakis } 101*97c3ef5cSSotiris Apostolakis 102*97c3ef5cSSotiris Apostolakis // Determine for which select groups it is profitable converting to branches. 103*97c3ef5cSSotiris Apostolakis SelectGroups ProfSIGroups; 104*97c3ef5cSSotiris Apostolakis // For now assume that all select groups can be profitably converted to 105*97c3ef5cSSotiris Apostolakis // branches. 106*97c3ef5cSSotiris Apostolakis for (SelectGroup &ASI : SIGroups) { 107*97c3ef5cSSotiris Apostolakis ProfSIGroups.push_back(ASI); 108*97c3ef5cSSotiris Apostolakis } 109*97c3ef5cSSotiris Apostolakis 110*97c3ef5cSSotiris Apostolakis // Convert to branches the select groups that were deemed 111*97c3ef5cSSotiris Apostolakis // profitable-to-convert. 112*97c3ef5cSSotiris Apostolakis convertProfitableSIGroups(ProfSIGroups); 113*97c3ef5cSSotiris Apostolakis 114*97c3ef5cSSotiris Apostolakis // Code modified if at least one select group was converted. 115*97c3ef5cSSotiris Apostolakis return !ProfSIGroups.empty(); 116*97c3ef5cSSotiris Apostolakis } 117*97c3ef5cSSotiris Apostolakis 118*97c3ef5cSSotiris Apostolakis /// If \p isTrue is true, return the true value of \p SI, otherwise return 119*97c3ef5cSSotiris Apostolakis /// false value of \p SI. If the true/false value of \p SI is defined by any 120*97c3ef5cSSotiris Apostolakis /// select instructions in \p Selects, look through the defining select 121*97c3ef5cSSotiris Apostolakis /// instruction until the true/false value is not defined in \p Selects. 122*97c3ef5cSSotiris Apostolakis static Value * 123*97c3ef5cSSotiris Apostolakis getTrueOrFalseValue(SelectInst *SI, bool isTrue, 124*97c3ef5cSSotiris Apostolakis const SmallPtrSet<const Instruction *, 2> &Selects) { 125*97c3ef5cSSotiris Apostolakis Value *V = nullptr; 126*97c3ef5cSSotiris Apostolakis for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI); 127*97c3ef5cSSotiris Apostolakis DefSI = dyn_cast<SelectInst>(V)) { 128*97c3ef5cSSotiris Apostolakis assert(DefSI->getCondition() == SI->getCondition() && 129*97c3ef5cSSotiris Apostolakis "The condition of DefSI does not match with SI"); 130*97c3ef5cSSotiris Apostolakis V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue()); 131*97c3ef5cSSotiris Apostolakis } 132*97c3ef5cSSotiris Apostolakis assert(V && "Failed to get select true/false value"); 133*97c3ef5cSSotiris Apostolakis return V; 134*97c3ef5cSSotiris Apostolakis } 135*97c3ef5cSSotiris Apostolakis 136*97c3ef5cSSotiris Apostolakis void SelectOptimize::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { 137*97c3ef5cSSotiris Apostolakis for (SelectGroup &ASI : ProfSIGroups) { 138*97c3ef5cSSotiris Apostolakis // TODO: eliminate the redundancy of logic transforming selects to branches 139*97c3ef5cSSotiris Apostolakis // by removing CodeGenPrepare::optimizeSelectInst and optimizing here 140*97c3ef5cSSotiris Apostolakis // selects for all cases (with and without profile information). 141*97c3ef5cSSotiris Apostolakis 142*97c3ef5cSSotiris Apostolakis // Transform a sequence like this: 143*97c3ef5cSSotiris Apostolakis // start: 144*97c3ef5cSSotiris Apostolakis // %cmp = cmp uge i32 %a, %b 145*97c3ef5cSSotiris Apostolakis // %sel = select i1 %cmp, i32 %c, i32 %d 146*97c3ef5cSSotiris Apostolakis // 147*97c3ef5cSSotiris Apostolakis // Into: 148*97c3ef5cSSotiris Apostolakis // start: 149*97c3ef5cSSotiris Apostolakis // %cmp = cmp uge i32 %a, %b 150*97c3ef5cSSotiris Apostolakis // %cmp.frozen = freeze %cmp 151*97c3ef5cSSotiris Apostolakis // br i1 %cmp.frozen, label %select.end, label %select.false 152*97c3ef5cSSotiris Apostolakis // select.false: 153*97c3ef5cSSotiris Apostolakis // br label %select.end 154*97c3ef5cSSotiris Apostolakis // select.end: 155*97c3ef5cSSotiris Apostolakis // %sel = phi i32 [ %c, %start ], [ %d, %select.false ] 156*97c3ef5cSSotiris Apostolakis // 157*97c3ef5cSSotiris Apostolakis // %cmp should be frozen, otherwise it may introduce undefined behavior. 158*97c3ef5cSSotiris Apostolakis 159*97c3ef5cSSotiris Apostolakis // We split the block containing the select(s) into two blocks. 160*97c3ef5cSSotiris Apostolakis SelectInst *SI = ASI.front(); 161*97c3ef5cSSotiris Apostolakis SelectInst *LastSI = ASI.back(); 162*97c3ef5cSSotiris Apostolakis BasicBlock *StartBlock = SI->getParent(); 163*97c3ef5cSSotiris Apostolakis BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI)); 164*97c3ef5cSSotiris Apostolakis BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end"); 165*97c3ef5cSSotiris Apostolakis BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency()); 166*97c3ef5cSSotiris Apostolakis // Delete the unconditional branch that was just created by the split. 167*97c3ef5cSSotiris Apostolakis StartBlock->getTerminator()->eraseFromParent(); 168*97c3ef5cSSotiris Apostolakis 169*97c3ef5cSSotiris Apostolakis // Move any debug/pseudo instructions that were in-between the select 170*97c3ef5cSSotiris Apostolakis // group to the newly-created end block. 171*97c3ef5cSSotiris Apostolakis SmallVector<Instruction *, 2> DebugPseudoINS; 172*97c3ef5cSSotiris Apostolakis auto DIt = SI->getIterator(); 173*97c3ef5cSSotiris Apostolakis while (&*DIt != LastSI) { 174*97c3ef5cSSotiris Apostolakis if (DIt->isDebugOrPseudoInst()) 175*97c3ef5cSSotiris Apostolakis DebugPseudoINS.push_back(&*DIt); 176*97c3ef5cSSotiris Apostolakis DIt++; 177*97c3ef5cSSotiris Apostolakis } 178*97c3ef5cSSotiris Apostolakis for (auto DI : DebugPseudoINS) { 179*97c3ef5cSSotiris Apostolakis DI->moveBefore(&*EndBlock->getFirstInsertionPt()); 180*97c3ef5cSSotiris Apostolakis } 181*97c3ef5cSSotiris Apostolakis 182*97c3ef5cSSotiris Apostolakis // These are the new basic blocks for the conditional branch. 183*97c3ef5cSSotiris Apostolakis // For now, no instruction sinking to the true/false blocks. 184*97c3ef5cSSotiris Apostolakis // Thus both True and False blocks will be empty. 185*97c3ef5cSSotiris Apostolakis BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr; 186*97c3ef5cSSotiris Apostolakis 187*97c3ef5cSSotiris Apostolakis // Use the 'false' side for a new input value to the PHI. 188*97c3ef5cSSotiris Apostolakis FalseBlock = BasicBlock::Create(SI->getContext(), "select.false", 189*97c3ef5cSSotiris Apostolakis EndBlock->getParent(), EndBlock); 190*97c3ef5cSSotiris Apostolakis auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock); 191*97c3ef5cSSotiris Apostolakis FalseBranch->setDebugLoc(SI->getDebugLoc()); 192*97c3ef5cSSotiris Apostolakis 193*97c3ef5cSSotiris Apostolakis // For the 'true' side the path originates from the start block from the 194*97c3ef5cSSotiris Apostolakis // point view of the new PHI. 195*97c3ef5cSSotiris Apostolakis TrueBlock = StartBlock; 196*97c3ef5cSSotiris Apostolakis 197*97c3ef5cSSotiris Apostolakis // Insert the real conditional branch based on the original condition. 198*97c3ef5cSSotiris Apostolakis BasicBlock *TT, *FT; 199*97c3ef5cSSotiris Apostolakis TT = EndBlock; 200*97c3ef5cSSotiris Apostolakis FT = FalseBlock; 201*97c3ef5cSSotiris Apostolakis IRBuilder<> IB(SI); 202*97c3ef5cSSotiris Apostolakis auto *CondFr = 203*97c3ef5cSSotiris Apostolakis IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen"); 204*97c3ef5cSSotiris Apostolakis IB.CreateCondBr(CondFr, TT, FT, SI); 205*97c3ef5cSSotiris Apostolakis 206*97c3ef5cSSotiris Apostolakis SmallPtrSet<const Instruction *, 2> INS; 207*97c3ef5cSSotiris Apostolakis INS.insert(ASI.begin(), ASI.end()); 208*97c3ef5cSSotiris Apostolakis // Use reverse iterator because later select may use the value of the 209*97c3ef5cSSotiris Apostolakis // earlier select, and we need to propagate value through earlier select 210*97c3ef5cSSotiris Apostolakis // to get the PHI operand. 211*97c3ef5cSSotiris Apostolakis for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) { 212*97c3ef5cSSotiris Apostolakis SelectInst *SI = *It; 213*97c3ef5cSSotiris Apostolakis // The select itself is replaced with a PHI Node. 214*97c3ef5cSSotiris Apostolakis PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front()); 215*97c3ef5cSSotiris Apostolakis PN->takeName(SI); 216*97c3ef5cSSotiris Apostolakis PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock); 217*97c3ef5cSSotiris Apostolakis PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock); 218*97c3ef5cSSotiris Apostolakis PN->setDebugLoc(SI->getDebugLoc()); 219*97c3ef5cSSotiris Apostolakis 220*97c3ef5cSSotiris Apostolakis SI->replaceAllUsesWith(PN); 221*97c3ef5cSSotiris Apostolakis SI->eraseFromParent(); 222*97c3ef5cSSotiris Apostolakis INS.erase(SI); 223*97c3ef5cSSotiris Apostolakis ++NumSelectsConverted; 224*97c3ef5cSSotiris Apostolakis } 225*97c3ef5cSSotiris Apostolakis } 226*97c3ef5cSSotiris Apostolakis } 227*97c3ef5cSSotiris Apostolakis 228*97c3ef5cSSotiris Apostolakis void SelectOptimize::collectSelectGroups(BasicBlock &BB, 229*97c3ef5cSSotiris Apostolakis SelectGroups &SIGroups) { 230*97c3ef5cSSotiris Apostolakis BasicBlock::iterator BBIt = BB.begin(); 231*97c3ef5cSSotiris Apostolakis while (BBIt != BB.end()) { 232*97c3ef5cSSotiris Apostolakis Instruction *I = &*BBIt++; 233*97c3ef5cSSotiris Apostolakis if (SelectInst *SI = dyn_cast<SelectInst>(I)) { 234*97c3ef5cSSotiris Apostolakis SelectGroup SIGroup; 235*97c3ef5cSSotiris Apostolakis SIGroup.push_back(SI); 236*97c3ef5cSSotiris Apostolakis while (BBIt != BB.end()) { 237*97c3ef5cSSotiris Apostolakis Instruction *NI = &*BBIt; 238*97c3ef5cSSotiris Apostolakis SelectInst *NSI = dyn_cast<SelectInst>(NI); 239*97c3ef5cSSotiris Apostolakis if (NSI && SI->getCondition() == NSI->getCondition()) { 240*97c3ef5cSSotiris Apostolakis SIGroup.push_back(NSI); 241*97c3ef5cSSotiris Apostolakis } else if (!NI->isDebugOrPseudoInst()) { 242*97c3ef5cSSotiris Apostolakis // Debug/pseudo instructions should be skipped and not prevent the 243*97c3ef5cSSotiris Apostolakis // formation of a select group. 244*97c3ef5cSSotiris Apostolakis break; 245*97c3ef5cSSotiris Apostolakis } 246*97c3ef5cSSotiris Apostolakis ++BBIt; 247*97c3ef5cSSotiris Apostolakis } 248*97c3ef5cSSotiris Apostolakis 249*97c3ef5cSSotiris Apostolakis // If the select type is not supported, no point optimizing it. 250*97c3ef5cSSotiris Apostolakis // Instruction selection will take care of it. 251*97c3ef5cSSotiris Apostolakis if (!isSelectKindSupported(SI)) 252*97c3ef5cSSotiris Apostolakis continue; 253*97c3ef5cSSotiris Apostolakis 254*97c3ef5cSSotiris Apostolakis SIGroups.push_back(SIGroup); 255*97c3ef5cSSotiris Apostolakis } 256*97c3ef5cSSotiris Apostolakis } 257*97c3ef5cSSotiris Apostolakis } 258*97c3ef5cSSotiris Apostolakis 259*97c3ef5cSSotiris Apostolakis bool SelectOptimize::isSelectKindSupported(SelectInst *SI) { 260*97c3ef5cSSotiris Apostolakis bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1); 261*97c3ef5cSSotiris Apostolakis if (VectorCond) 262*97c3ef5cSSotiris Apostolakis return false; 263*97c3ef5cSSotiris Apostolakis TargetLowering::SelectSupportKind SelectKind; 264*97c3ef5cSSotiris Apostolakis if (SI->getType()->isVectorTy()) 265*97c3ef5cSSotiris Apostolakis SelectKind = TargetLowering::ScalarCondVectorVal; 266*97c3ef5cSSotiris Apostolakis else 267*97c3ef5cSSotiris Apostolakis SelectKind = TargetLowering::ScalarValSelect; 268*97c3ef5cSSotiris Apostolakis return TLI->isSelectSupported(SelectKind); 269ca7c307dSSotiris Apostolakis } 270