1 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The LowerSwitch transformation rewrites switch instructions with a sequence 11 // of branches, which allows targets to get away with not implementing the 12 // switch instruction until it is convenient. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Scalar.h" 17 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/LLVMContext.h" 23 #include "llvm/IR/CFG.h" 24 #include "llvm/Pass.h" 25 #include "llvm/Support/Compiler.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 29 #include <algorithm> 30 using namespace llvm; 31 32 #define DEBUG_TYPE "lower-switch" 33 34 namespace { 35 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 36 /// instructions. 37 class LowerSwitch : public FunctionPass { 38 public: 39 static char ID; // Pass identification, replacement for typeid 40 LowerSwitch() : FunctionPass(ID) { 41 initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 42 } 43 44 bool runOnFunction(Function &F) override; 45 46 void getAnalysisUsage(AnalysisUsage &AU) const override { 47 // This is a cluster of orthogonal Transforms 48 AU.addPreserved<UnifyFunctionExitNodes>(); 49 AU.addPreserved("mem2reg"); 50 AU.addPreservedID(LowerInvokePassID); 51 } 52 53 struct CaseRange { 54 Constant* Low; 55 Constant* High; 56 BasicBlock* BB; 57 58 CaseRange(Constant *low = nullptr, Constant *high = nullptr, 59 BasicBlock *bb = nullptr) : 60 Low(low), High(high), BB(bb) { } 61 }; 62 63 typedef std::vector<CaseRange> CaseVector; 64 typedef std::vector<CaseRange>::iterator CaseItr; 65 private: 66 void processSwitchInst(SwitchInst *SI); 67 68 BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 69 ConstantInt *LowerBound, ConstantInt *UpperBound, 70 Value *Val, BasicBlock *Predecessor, 71 BasicBlock *OrigBlock, BasicBlock *Default); 72 BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, 73 BasicBlock *Default); 74 unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 75 }; 76 77 /// The comparison function for sorting the switch case values in the vector. 78 /// WARNING: Case ranges should be disjoint! 79 struct CaseCmp { 80 bool operator () (const LowerSwitch::CaseRange& C1, 81 const LowerSwitch::CaseRange& C2) { 82 83 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 84 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 85 return CI1->getValue().slt(CI2->getValue()); 86 } 87 }; 88 } 89 90 char LowerSwitch::ID = 0; 91 INITIALIZE_PASS(LowerSwitch, "lowerswitch", 92 "Lower SwitchInst's to branches", false, false) 93 94 // Publicly exposed interface to pass... 95 char &llvm::LowerSwitchID = LowerSwitch::ID; 96 // createLowerSwitchPass - Interface to this file... 97 FunctionPass *llvm::createLowerSwitchPass() { 98 return new LowerSwitch(); 99 } 100 101 bool LowerSwitch::runOnFunction(Function &F) { 102 bool Changed = false; 103 104 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 105 BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 106 107 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 108 Changed = true; 109 processSwitchInst(SI); 110 } 111 } 112 113 return Changed; 114 } 115 116 // operator<< - Used for debugging purposes. 117 // 118 static raw_ostream& operator<<(raw_ostream &O, 119 const LowerSwitch::CaseVector &C) 120 LLVM_ATTRIBUTE_USED; 121 static raw_ostream& operator<<(raw_ostream &O, 122 const LowerSwitch::CaseVector &C) { 123 O << "["; 124 125 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 126 E = C.end(); B != E; ) { 127 O << *B->Low << " -" << *B->High; 128 if (++B != E) O << ", "; 129 } 130 131 return O << "]"; 132 } 133 134 // \brief Update the first occurrence of the "switch statement" BB in the PHI 135 // node with the "new" BB. The other occurrences will: 136 // 137 // 1) Be updated by subsequent calls to this function. Switch statements may 138 // have more than one outcoming edge into the same BB if they all have the same 139 // value. When the switch statement is converted these incoming edges are now 140 // coming from multiple BBs. 141 // 2) Removed if subsequent incoming values now share the same case, i.e., 142 // multiple outcome edges are condensed into one. This is necessary to keep the 143 // number of phi values equal to the number of branches to SuccBB. 144 static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, 145 unsigned NumMergedCases) { 146 for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI(); 147 I != IE; ++I) { 148 PHINode *PN = cast<PHINode>(I); 149 150 // Only update the first occurence. 151 unsigned Idx = 0, E = PN->getNumIncomingValues(); 152 unsigned LocalNumMergedCases = NumMergedCases; 153 for (; Idx != E; ++Idx) { 154 if (PN->getIncomingBlock(Idx) == OrigBB) { 155 PN->setIncomingBlock(Idx, NewBB); 156 break; 157 } 158 } 159 160 // Remove additional occurences coming from condensed cases and keep the 161 // number of incoming values equal to the number of branches to SuccBB. 162 for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) 163 if (PN->getIncomingBlock(Idx) == OrigBB) { 164 PN->removeIncomingValue(Idx); 165 LocalNumMergedCases--; 166 } 167 } 168 } 169 170 // switchConvert - Convert the switch statement into a binary lookup of 171 // the case values. The function recursively builds this tree. 172 // LowerBound and UpperBound are used to keep track of the bounds for Val 173 // that have already been checked by a block emitted by one of the previous 174 // calls to switchConvert in the call stack. 175 BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 176 ConstantInt *LowerBound, 177 ConstantInt *UpperBound, Value *Val, 178 BasicBlock *Predecessor, 179 BasicBlock *OrigBlock, 180 BasicBlock *Default) { 181 unsigned Size = End - Begin; 182 183 if (Size == 1) { 184 // Check if the Case Range is perfectly squeezed in between 185 // already checked Upper and Lower bounds. If it is then we can avoid 186 // emitting the code that checks if the value actually falls in the range 187 // because the bounds already tell us so. 188 if (Begin->Low == LowerBound && Begin->High == UpperBound) { 189 unsigned NumMergedCases = 0; 190 if (LowerBound && UpperBound) 191 NumMergedCases = 192 UpperBound->getSExtValue() - LowerBound->getSExtValue(); 193 fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); 194 return Begin->BB; 195 } 196 return newLeafBlock(*Begin, Val, OrigBlock, Default); 197 } 198 199 unsigned Mid = Size / 2; 200 std::vector<CaseRange> LHS(Begin, Begin + Mid); 201 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 202 std::vector<CaseRange> RHS(Begin + Mid, End); 203 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 204 205 CaseRange &Pivot = *(Begin + Mid); 206 DEBUG(dbgs() << "Pivot ==> " 207 << cast<ConstantInt>(Pivot.Low)->getValue() 208 << " -" << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 209 210 // NewLowerBound here should never be the integer minimal value. 211 // This is because it is computed from a case range that is never 212 // the smallest, so there is always a case range that has at least 213 // a smaller value. 214 ConstantInt *NewLowerBound = cast<ConstantInt>(Pivot.Low); 215 ConstantInt *NewUpperBound; 216 217 // If we don't have a Default block then it means that we can never 218 // have a value outside of a case range, so set the UpperBound to the highest 219 // value in the LHS part of the case ranges. 220 if (Default != nullptr) { 221 // Because NewLowerBound is never the smallest representable integer 222 // it is safe here to subtract one. 223 NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 224 NewLowerBound->getValue() - 1); 225 } else { 226 CaseItr LastLHS = LHS.begin() + LHS.size() - 1; 227 NewUpperBound = cast<ConstantInt>(LastLHS->High); 228 } 229 230 DEBUG(dbgs() << "LHS Bounds ==> "; 231 if (LowerBound) { 232 dbgs() << cast<ConstantInt>(LowerBound)->getSExtValue(); 233 } else { 234 dbgs() << "NONE"; 235 } 236 dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; 237 dbgs() << "RHS Bounds ==> "; 238 dbgs() << NewLowerBound->getSExtValue() << " - "; 239 if (UpperBound) { 240 dbgs() << cast<ConstantInt>(UpperBound)->getSExtValue() << "\n"; 241 } else { 242 dbgs() << "NONE\n"; 243 }); 244 245 // Create a new node that checks if the value is < pivot. Go to the 246 // left branch if it is and right branch if not. 247 Function* F = OrigBlock->getParent(); 248 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 249 250 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 251 Val, Pivot.Low, "Pivot"); 252 253 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 254 NewUpperBound, Val, NewNode, OrigBlock, 255 Default); 256 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 257 UpperBound, Val, NewNode, OrigBlock, 258 Default); 259 260 Function::iterator FI = OrigBlock; 261 F->getBasicBlockList().insert(++FI, NewNode); 262 NewNode->getInstList().push_back(Comp); 263 264 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 265 return NewNode; 266 } 267 268 // newLeafBlock - Create a new leaf block for the binary lookup tree. It 269 // checks if the switch's value == the case's value. If not, then it 270 // jumps to the default branch. At this point in the tree, the value 271 // can't be another valid case value, so the jump to the "default" branch 272 // is warranted. 273 // 274 BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 275 BasicBlock* OrigBlock, 276 BasicBlock* Default) 277 { 278 Function* F = OrigBlock->getParent(); 279 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 280 Function::iterator FI = OrigBlock; 281 F->getBasicBlockList().insert(++FI, NewLeaf); 282 283 // Emit comparison 284 ICmpInst* Comp = nullptr; 285 if (Leaf.Low == Leaf.High) { 286 // Make the seteq instruction... 287 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 288 Leaf.Low, "SwitchLeaf"); 289 } else { 290 // Make range comparison 291 if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 292 // Val >= Min && Val <= Hi --> Val <= Hi 293 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 294 "SwitchLeaf"); 295 } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 296 // Val >= 0 && Val <= Hi --> Val <=u Hi 297 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 298 "SwitchLeaf"); 299 } else { 300 // Emit V-Lo <=u Hi-Lo 301 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 302 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 303 Val->getName()+".off", 304 NewLeaf); 305 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 306 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 307 "SwitchLeaf"); 308 } 309 } 310 311 // Make the conditional branch... 312 BasicBlock* Succ = Leaf.BB; 313 BranchInst::Create(Succ, Default, Comp, NewLeaf); 314 315 // If there were any PHI nodes in this successor, rewrite one entry 316 // from OrigBlock to come from NewLeaf. 317 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 318 PHINode* PN = cast<PHINode>(I); 319 // Remove all but one incoming entries from the cluster 320 uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 321 cast<ConstantInt>(Leaf.Low)->getSExtValue(); 322 for (uint64_t j = 0; j < Range; ++j) { 323 PN->removeIncomingValue(OrigBlock); 324 } 325 326 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 327 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 328 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 329 } 330 331 return NewLeaf; 332 } 333 334 // Clusterify - Transform simple list of Cases into list of CaseRange's 335 unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 336 unsigned numCmps = 0; 337 338 // Start with "simple" cases 339 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) 340 Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), 341 i.getCaseSuccessor())); 342 343 std::sort(Cases.begin(), Cases.end(), CaseCmp()); 344 345 // Merge case into clusters 346 if (Cases.size()>=2) 347 for (CaseItr I = Cases.begin(), J = std::next(Cases.begin()); 348 J != Cases.end();) { 349 int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 350 int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 351 BasicBlock* nextBB = J->BB; 352 BasicBlock* currentBB = I->BB; 353 354 // If the two neighboring cases go to the same destination, merge them 355 // into a single case. 356 if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 357 I->High = J->High; 358 J = Cases.erase(J); 359 } else { 360 I = J++; 361 } 362 } 363 364 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 365 if (I->Low != I->High) 366 // A range counts double, since it requires two compares. 367 ++numCmps; 368 } 369 370 return numCmps; 371 } 372 373 // processSwitchInst - Replace the specified switch instruction with a sequence 374 // of chained if-then insts in a balanced binary search. 375 // 376 void LowerSwitch::processSwitchInst(SwitchInst *SI) { 377 BasicBlock *CurBlock = SI->getParent(); 378 BasicBlock *OrigBlock = CurBlock; 379 Function *F = CurBlock->getParent(); 380 Value *Val = SI->getCondition(); // The value we are switching on... 381 BasicBlock* Default = SI->getDefaultDest(); 382 383 // If there is only the default destination, don't bother with the code below. 384 if (!SI->getNumCases()) { 385 BranchInst::Create(SI->getDefaultDest(), CurBlock); 386 CurBlock->getInstList().erase(SI); 387 return; 388 } 389 390 const bool DefaultIsUnreachable = 391 Default->size() == 1 && isa<UnreachableInst>(Default->getTerminator()); 392 // Create a new, empty default block so that the new hierarchy of 393 // if-then statements go to this and the PHI nodes are happy. 394 // if the default block is set as an unreachable we avoid creating one 395 // because will never be a valid target. 396 BasicBlock *NewDefault = nullptr; 397 if (!DefaultIsUnreachable) { 398 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 399 F->getBasicBlockList().insert(Default, NewDefault); 400 401 BranchInst::Create(Default, NewDefault); 402 } 403 // If there is an entry in any PHI nodes for the default edge, make sure 404 // to update them as well. 405 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 406 PHINode *PN = cast<PHINode>(I); 407 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 408 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 409 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 410 } 411 412 // Prepare cases vector. 413 CaseVector Cases; 414 unsigned numCmps = Clusterify(Cases, SI); 415 416 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 417 << ". Total compares: " << numCmps << "\n"); 418 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 419 (void)numCmps; 420 421 ConstantInt *UpperBound = nullptr; 422 ConstantInt *LowerBound = nullptr; 423 424 // Optimize the condition where Default is an unreachable block. In this case 425 // we can make the bounds tightly fitted around the case value ranges, 426 // because we know that the value passed to the switch should always be 427 // exactly one of the case values. 428 if (DefaultIsUnreachable) { 429 CaseItr LastCase = Cases.begin() + Cases.size() - 1; 430 UpperBound = cast<ConstantInt>(LastCase->High); 431 LowerBound = cast<ConstantInt>(Cases.begin()->Low); 432 } 433 BasicBlock *SwitchBlock = 434 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 435 OrigBlock, OrigBlock, NewDefault); 436 437 // Branch to our shiny new if-then stuff... 438 BranchInst::Create(SwitchBlock, OrigBlock); 439 440 // We are now done with the switch instruction, delete it. 441 CurBlock->getInstList().erase(SI); 442 443 pred_iterator PI = pred_begin(Default), E = pred_end(Default); 444 // If the Default block has no more predecessors just remove it 445 if (PI == E) { 446 DeleteDeadBlock(Default); 447 } 448 } 449