1 //===--- SelectOptimize.cpp - Convert select to branches if profitable ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass converts selects to conditional jumps when profitable. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/ADT/SmallVector.h" 14 #include "llvm/ADT/Statistic.h" 15 #include "llvm/Analysis/BlockFrequencyInfo.h" 16 #include "llvm/Analysis/BranchProbabilityInfo.h" 17 #include "llvm/Analysis/LoopInfo.h" 18 #include "llvm/CodeGen/Passes.h" 19 #include "llvm/CodeGen/TargetLowering.h" 20 #include "llvm/CodeGen/TargetPassConfig.h" 21 #include "llvm/CodeGen/TargetSchedule.h" 22 #include "llvm/CodeGen/TargetSubtargetInfo.h" 23 #include "llvm/IR/BasicBlock.h" 24 #include "llvm/IR/Function.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Instruction.h" 27 #include "llvm/InitializePasses.h" 28 #include "llvm/Pass.h" 29 #include "llvm/Target/TargetMachine.h" 30 31 using namespace llvm; 32 33 #define DEBUG_TYPE "select-optimize" 34 35 STATISTIC(NumSelectsConverted, "Number of selects converted"); 36 37 namespace { 38 39 class SelectOptimize : public FunctionPass { 40 const TargetMachine *TM = nullptr; 41 const TargetSubtargetInfo *TSI; 42 const TargetLowering *TLI = nullptr; 43 const LoopInfo *LI; 44 std::unique_ptr<BlockFrequencyInfo> BFI; 45 std::unique_ptr<BranchProbabilityInfo> BPI; 46 47 public: 48 static char ID; 49 SelectOptimize() : FunctionPass(ID) { 50 initializeSelectOptimizePass(*PassRegistry::getPassRegistry()); 51 } 52 53 bool runOnFunction(Function &F) override; 54 55 void getAnalysisUsage(AnalysisUsage &AU) const override { 56 AU.addRequired<TargetPassConfig>(); 57 AU.addRequired<LoopInfoWrapperPass>(); 58 } 59 60 private: 61 // Select groups consist of consecutive select instructions with the same 62 // condition. 63 using SelectGroup = SmallVector<SelectInst *, 2>; 64 using SelectGroups = SmallVector<SelectGroup, 2>; 65 66 bool optimizeSelects(Function &F); 67 void convertProfitableSIGroups(SelectGroups &ProfSIGroups); 68 void collectSelectGroups(BasicBlock &BB, SelectGroups &SIGroups); 69 bool isSelectKindSupported(SelectInst *SI); 70 }; 71 } // namespace 72 73 char SelectOptimize::ID = 0; 74 75 INITIALIZE_PASS_BEGIN(SelectOptimize, DEBUG_TYPE, "Optimize selects", false, 76 false) 77 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 78 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 79 INITIALIZE_PASS_END(SelectOptimize, DEBUG_TYPE, "Optimize selects", false, 80 false) 81 82 FunctionPass *llvm::createSelectOptimizePass() { return new SelectOptimize(); } 83 84 bool SelectOptimize::runOnFunction(Function &F) { 85 TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 86 TSI = TM->getSubtargetImpl(F); 87 TLI = TSI->getTargetLowering(); 88 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 89 BPI.reset(new BranchProbabilityInfo(F, *LI)); 90 BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI)); 91 92 return optimizeSelects(F); 93 } 94 95 bool SelectOptimize::optimizeSelects(Function &F) { 96 // Collect all the select groups. 97 SelectGroups SIGroups; 98 for (BasicBlock &BB : F) { 99 collectSelectGroups(BB, SIGroups); 100 } 101 102 // Determine for which select groups it is profitable converting to branches. 103 SelectGroups ProfSIGroups; 104 // For now assume that all select groups can be profitably converted to 105 // branches. 106 for (SelectGroup &ASI : SIGroups) { 107 ProfSIGroups.push_back(ASI); 108 } 109 110 // Convert to branches the select groups that were deemed 111 // profitable-to-convert. 112 convertProfitableSIGroups(ProfSIGroups); 113 114 // Code modified if at least one select group was converted. 115 return !ProfSIGroups.empty(); 116 } 117 118 /// If \p isTrue is true, return the true value of \p SI, otherwise return 119 /// false value of \p SI. If the true/false value of \p SI is defined by any 120 /// select instructions in \p Selects, look through the defining select 121 /// instruction until the true/false value is not defined in \p Selects. 122 static Value * 123 getTrueOrFalseValue(SelectInst *SI, bool isTrue, 124 const SmallPtrSet<const Instruction *, 2> &Selects) { 125 Value *V = nullptr; 126 for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI); 127 DefSI = dyn_cast<SelectInst>(V)) { 128 assert(DefSI->getCondition() == SI->getCondition() && 129 "The condition of DefSI does not match with SI"); 130 V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue()); 131 } 132 assert(V && "Failed to get select true/false value"); 133 return V; 134 } 135 136 void SelectOptimize::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { 137 for (SelectGroup &ASI : ProfSIGroups) { 138 // TODO: eliminate the redundancy of logic transforming selects to branches 139 // by removing CodeGenPrepare::optimizeSelectInst and optimizing here 140 // selects for all cases (with and without profile information). 141 142 // Transform a sequence like this: 143 // start: 144 // %cmp = cmp uge i32 %a, %b 145 // %sel = select i1 %cmp, i32 %c, i32 %d 146 // 147 // Into: 148 // start: 149 // %cmp = cmp uge i32 %a, %b 150 // %cmp.frozen = freeze %cmp 151 // br i1 %cmp.frozen, label %select.end, label %select.false 152 // select.false: 153 // br label %select.end 154 // select.end: 155 // %sel = phi i32 [ %c, %start ], [ %d, %select.false ] 156 // 157 // %cmp should be frozen, otherwise it may introduce undefined behavior. 158 159 // We split the block containing the select(s) into two blocks. 160 SelectInst *SI = ASI.front(); 161 SelectInst *LastSI = ASI.back(); 162 BasicBlock *StartBlock = SI->getParent(); 163 BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI)); 164 BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end"); 165 BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency()); 166 // Delete the unconditional branch that was just created by the split. 167 StartBlock->getTerminator()->eraseFromParent(); 168 169 // Move any debug/pseudo instructions that were in-between the select 170 // group to the newly-created end block. 171 SmallVector<Instruction *, 2> DebugPseudoINS; 172 auto DIt = SI->getIterator(); 173 while (&*DIt != LastSI) { 174 if (DIt->isDebugOrPseudoInst()) 175 DebugPseudoINS.push_back(&*DIt); 176 DIt++; 177 } 178 for (auto DI : DebugPseudoINS) { 179 DI->moveBefore(&*EndBlock->getFirstInsertionPt()); 180 } 181 182 // These are the new basic blocks for the conditional branch. 183 // For now, no instruction sinking to the true/false blocks. 184 // Thus both True and False blocks will be empty. 185 BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr; 186 187 // Use the 'false' side for a new input value to the PHI. 188 FalseBlock = BasicBlock::Create(SI->getContext(), "select.false", 189 EndBlock->getParent(), EndBlock); 190 auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock); 191 FalseBranch->setDebugLoc(SI->getDebugLoc()); 192 193 // For the 'true' side the path originates from the start block from the 194 // point view of the new PHI. 195 TrueBlock = StartBlock; 196 197 // Insert the real conditional branch based on the original condition. 198 BasicBlock *TT, *FT; 199 TT = EndBlock; 200 FT = FalseBlock; 201 IRBuilder<> IB(SI); 202 auto *CondFr = 203 IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen"); 204 IB.CreateCondBr(CondFr, TT, FT, SI); 205 206 SmallPtrSet<const Instruction *, 2> INS; 207 INS.insert(ASI.begin(), ASI.end()); 208 // Use reverse iterator because later select may use the value of the 209 // earlier select, and we need to propagate value through earlier select 210 // to get the PHI operand. 211 for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) { 212 SelectInst *SI = *It; 213 // The select itself is replaced with a PHI Node. 214 PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front()); 215 PN->takeName(SI); 216 PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock); 217 PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock); 218 PN->setDebugLoc(SI->getDebugLoc()); 219 220 SI->replaceAllUsesWith(PN); 221 SI->eraseFromParent(); 222 INS.erase(SI); 223 ++NumSelectsConverted; 224 } 225 } 226 } 227 228 void SelectOptimize::collectSelectGroups(BasicBlock &BB, 229 SelectGroups &SIGroups) { 230 BasicBlock::iterator BBIt = BB.begin(); 231 while (BBIt != BB.end()) { 232 Instruction *I = &*BBIt++; 233 if (SelectInst *SI = dyn_cast<SelectInst>(I)) { 234 SelectGroup SIGroup; 235 SIGroup.push_back(SI); 236 while (BBIt != BB.end()) { 237 Instruction *NI = &*BBIt; 238 SelectInst *NSI = dyn_cast<SelectInst>(NI); 239 if (NSI && SI->getCondition() == NSI->getCondition()) { 240 SIGroup.push_back(NSI); 241 } else if (!NI->isDebugOrPseudoInst()) { 242 // Debug/pseudo instructions should be skipped and not prevent the 243 // formation of a select group. 244 break; 245 } 246 ++BBIt; 247 } 248 249 // If the select type is not supported, no point optimizing it. 250 // Instruction selection will take care of it. 251 if (!isSelectKindSupported(SI)) 252 continue; 253 254 SIGroups.push_back(SIGroup); 255 } 256 } 257 } 258 259 bool SelectOptimize::isSelectKindSupported(SelectInst *SI) { 260 bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1); 261 if (VectorCond) 262 return false; 263 TargetLowering::SelectSupportKind SelectKind; 264 if (SI->getType()->isVectorTy()) 265 SelectKind = TargetLowering::ScalarCondVectorVal; 266 else 267 SelectKind = TargetLowering::ScalarValSelect; 268 return TLI->isSelectSupported(SelectKind); 269 } 270