1 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The LowerSwitch transformation rewrites switch instructions with a sequence 11 // of branches, which allows targets to get away with not implementing the 12 // switch instruction until it is convenient. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/IR/BasicBlock.h" 21 #include "llvm/IR/CFG.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/Function.h" 24 #include "llvm/IR/InstrTypes.h" 25 #include "llvm/IR/Instructions.h" 26 #include "llvm/IR/Value.h" 27 #include "llvm/Pass.h" 28 #include "llvm/Support/Casting.h" 29 #include "llvm/Support/Compiler.h" 30 #include "llvm/Support/Debug.h" 31 #include "llvm/Support/raw_ostream.h" 32 #include "llvm/Transforms/Utils.h" 33 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 34 #include <algorithm> 35 #include <cassert> 36 #include <cstdint> 37 #include <iterator> 38 #include <limits> 39 #include <vector> 40 41 using namespace llvm; 42 43 #define DEBUG_TYPE "lower-switch" 44 45 namespace { 46 47 struct IntRange { 48 int64_t Low, High; 49 }; 50 51 } // end anonymous namespace 52 53 // Return true iff R is covered by Ranges. 54 static bool IsInRanges(const IntRange &R, 55 const std::vector<IntRange> &Ranges) { 56 // Note: Ranges must be sorted, non-overlapping and non-adjacent. 57 58 // Find the first range whose High field is >= R.High, 59 // then check if the Low field is <= R.Low. If so, we 60 // have a Range that covers R. 61 auto I = std::lower_bound( 62 Ranges.begin(), Ranges.end(), R, 63 [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); 64 return I != Ranges.end() && I->Low <= R.Low; 65 } 66 67 namespace { 68 69 /// Replace all SwitchInst instructions with chained branch instructions. 70 class LowerSwitch : public FunctionPass { 71 public: 72 // Pass identification, replacement for typeid 73 static char ID; 74 75 LowerSwitch() : FunctionPass(ID) { 76 initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 77 } 78 79 bool runOnFunction(Function &F) override; 80 81 struct CaseRange { 82 ConstantInt* Low; 83 ConstantInt* High; 84 BasicBlock* BB; 85 86 CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) 87 : Low(low), High(high), BB(bb) {} 88 }; 89 90 using CaseVector = std::vector<CaseRange>; 91 using CaseItr = std::vector<CaseRange>::iterator; 92 93 private: 94 void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList); 95 96 BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 97 ConstantInt *LowerBound, ConstantInt *UpperBound, 98 Value *Val, BasicBlock *Predecessor, 99 BasicBlock *OrigBlock, BasicBlock *Default, 100 const std::vector<IntRange> &UnreachableRanges); 101 BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, 102 BasicBlock *Default); 103 unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 104 }; 105 106 /// The comparison function for sorting the switch case values in the vector. 107 /// WARNING: Case ranges should be disjoint! 108 struct CaseCmp { 109 bool operator()(const LowerSwitch::CaseRange& C1, 110 const LowerSwitch::CaseRange& C2) { 111 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 112 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 113 return CI1->getValue().slt(CI2->getValue()); 114 } 115 }; 116 117 } // end anonymous namespace 118 119 char LowerSwitch::ID = 0; 120 121 // Publicly exposed interface to pass... 122 char &llvm::LowerSwitchID = LowerSwitch::ID; 123 124 INITIALIZE_PASS(LowerSwitch, "lowerswitch", 125 "Lower SwitchInst's to branches", false, false) 126 127 // createLowerSwitchPass - Interface to this file... 128 FunctionPass *llvm::createLowerSwitchPass() { 129 return new LowerSwitch(); 130 } 131 132 bool LowerSwitch::runOnFunction(Function &F) { 133 bool Changed = false; 134 SmallPtrSet<BasicBlock*, 8> DeleteList; 135 136 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 137 BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks 138 139 // If the block is a dead Default block that will be deleted later, don't 140 // waste time processing it. 141 if (DeleteList.count(Cur)) 142 continue; 143 144 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 145 Changed = true; 146 processSwitchInst(SI, DeleteList); 147 } 148 } 149 150 for (BasicBlock* BB: DeleteList) { 151 DeleteDeadBlock(BB); 152 } 153 154 return Changed; 155 } 156 157 /// Used for debugging purposes. 158 LLVM_ATTRIBUTE_USED 159 static raw_ostream &operator<<(raw_ostream &O, 160 const LowerSwitch::CaseVector &C) { 161 O << "["; 162 163 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 164 E = C.end(); B != E; ) { 165 O << *B->Low << " -" << *B->High; 166 if (++B != E) O << ", "; 167 } 168 169 return O << "]"; 170 } 171 172 /// Update the first occurrence of the "switch statement" BB in the PHI 173 /// node with the "new" BB. The other occurrences will: 174 /// 175 /// 1) Be updated by subsequent calls to this function. Switch statements may 176 /// have more than one outcoming edge into the same BB if they all have the same 177 /// value. When the switch statement is converted these incoming edges are now 178 /// coming from multiple BBs. 179 /// 2) Removed if subsequent incoming values now share the same case, i.e., 180 /// multiple outcome edges are condensed into one. This is necessary to keep the 181 /// number of phi values equal to the number of branches to SuccBB. 182 static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, 183 unsigned NumMergedCases) { 184 for (BasicBlock::iterator I = SuccBB->begin(), 185 IE = SuccBB->getFirstNonPHI()->getIterator(); 186 I != IE; ++I) { 187 PHINode *PN = cast<PHINode>(I); 188 189 // Only update the first occurrence. 190 unsigned Idx = 0, E = PN->getNumIncomingValues(); 191 unsigned LocalNumMergedCases = NumMergedCases; 192 for (; Idx != E; ++Idx) { 193 if (PN->getIncomingBlock(Idx) == OrigBB) { 194 PN->setIncomingBlock(Idx, NewBB); 195 break; 196 } 197 } 198 199 // Remove additional occurrences coming from condensed cases and keep the 200 // number of incoming values equal to the number of branches to SuccBB. 201 SmallVector<unsigned, 8> Indices; 202 for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) 203 if (PN->getIncomingBlock(Idx) == OrigBB) { 204 Indices.push_back(Idx); 205 LocalNumMergedCases--; 206 } 207 // Remove incoming values in the reverse order to prevent invalidating 208 // *successive* index. 209 for (unsigned III : llvm::reverse(Indices)) 210 PN->removeIncomingValue(III); 211 } 212 } 213 214 /// Convert the switch statement into a binary lookup of the case values. 215 /// The function recursively builds this tree. LowerBound and UpperBound are 216 /// used to keep track of the bounds for Val that have already been checked by 217 /// a block emitted by one of the previous calls to switchConvert in the call 218 /// stack. 219 BasicBlock * 220 LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, 221 ConstantInt *UpperBound, Value *Val, 222 BasicBlock *Predecessor, BasicBlock *OrigBlock, 223 BasicBlock *Default, 224 const std::vector<IntRange> &UnreachableRanges) { 225 unsigned Size = End - Begin; 226 227 if (Size == 1) { 228 // Check if the Case Range is perfectly squeezed in between 229 // already checked Upper and Lower bounds. If it is then we can avoid 230 // emitting the code that checks if the value actually falls in the range 231 // because the bounds already tell us so. 232 if (Begin->Low == LowerBound && Begin->High == UpperBound) { 233 unsigned NumMergedCases = 0; 234 if (LowerBound && UpperBound) 235 NumMergedCases = 236 UpperBound->getSExtValue() - LowerBound->getSExtValue(); 237 fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); 238 return Begin->BB; 239 } 240 return newLeafBlock(*Begin, Val, OrigBlock, Default); 241 } 242 243 unsigned Mid = Size / 2; 244 std::vector<CaseRange> LHS(Begin, Begin + Mid); 245 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 246 std::vector<CaseRange> RHS(Begin + Mid, End); 247 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 248 249 CaseRange &Pivot = *(Begin + Mid); 250 DEBUG(dbgs() << "Pivot ==> " 251 << Pivot.Low->getValue() 252 << " -" << Pivot.High->getValue() << "\n"); 253 254 // NewLowerBound here should never be the integer minimal value. 255 // This is because it is computed from a case range that is never 256 // the smallest, so there is always a case range that has at least 257 // a smaller value. 258 ConstantInt *NewLowerBound = Pivot.Low; 259 260 // Because NewLowerBound is never the smallest representable integer 261 // it is safe here to subtract one. 262 ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 263 NewLowerBound->getValue() - 1); 264 265 if (!UnreachableRanges.empty()) { 266 // Check if the gap between LHS's highest and NewLowerBound is unreachable. 267 int64_t GapLow = LHS.back().High->getSExtValue() + 1; 268 int64_t GapHigh = NewLowerBound->getSExtValue() - 1; 269 IntRange Gap = { GapLow, GapHigh }; 270 if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) 271 NewUpperBound = LHS.back().High; 272 } 273 274 DEBUG(dbgs() << "LHS Bounds ==> "; 275 if (LowerBound) { 276 dbgs() << LowerBound->getSExtValue(); 277 } else { 278 dbgs() << "NONE"; 279 } 280 dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; 281 dbgs() << "RHS Bounds ==> "; 282 dbgs() << NewLowerBound->getSExtValue() << " - "; 283 if (UpperBound) { 284 dbgs() << UpperBound->getSExtValue() << "\n"; 285 } else { 286 dbgs() << "NONE\n"; 287 }); 288 289 // Create a new node that checks if the value is < pivot. Go to the 290 // left branch if it is and right branch if not. 291 Function* F = OrigBlock->getParent(); 292 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 293 294 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 295 Val, Pivot.Low, "Pivot"); 296 297 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 298 NewUpperBound, Val, NewNode, OrigBlock, 299 Default, UnreachableRanges); 300 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 301 UpperBound, Val, NewNode, OrigBlock, 302 Default, UnreachableRanges); 303 304 F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode); 305 NewNode->getInstList().push_back(Comp); 306 307 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 308 return NewNode; 309 } 310 311 /// Create a new leaf block for the binary lookup tree. It checks if the 312 /// switch's value == the case's value. If not, then it jumps to the default 313 /// branch. At this point in the tree, the value can't be another valid case 314 /// value, so the jump to the "default" branch is warranted. 315 BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 316 BasicBlock* OrigBlock, 317 BasicBlock* Default) { 318 Function* F = OrigBlock->getParent(); 319 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 320 F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); 321 322 // Emit comparison 323 ICmpInst* Comp = nullptr; 324 if (Leaf.Low == Leaf.High) { 325 // Make the seteq instruction... 326 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 327 Leaf.Low, "SwitchLeaf"); 328 } else { 329 // Make range comparison 330 if (Leaf.Low->isMinValue(true /*isSigned*/)) { 331 // Val >= Min && Val <= Hi --> Val <= Hi 332 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 333 "SwitchLeaf"); 334 } else if (Leaf.Low->isZero()) { 335 // Val >= 0 && Val <= Hi --> Val <=u Hi 336 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 337 "SwitchLeaf"); 338 } else { 339 // Emit V-Lo <=u Hi-Lo 340 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 341 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 342 Val->getName()+".off", 343 NewLeaf); 344 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 345 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 346 "SwitchLeaf"); 347 } 348 } 349 350 // Make the conditional branch... 351 BasicBlock* Succ = Leaf.BB; 352 BranchInst::Create(Succ, Default, Comp, NewLeaf); 353 354 // If there were any PHI nodes in this successor, rewrite one entry 355 // from OrigBlock to come from NewLeaf. 356 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 357 PHINode* PN = cast<PHINode>(I); 358 // Remove all but one incoming entries from the cluster 359 uint64_t Range = Leaf.High->getSExtValue() - 360 Leaf.Low->getSExtValue(); 361 for (uint64_t j = 0; j < Range; ++j) { 362 PN->removeIncomingValue(OrigBlock); 363 } 364 365 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 366 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 367 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 368 } 369 370 return NewLeaf; 371 } 372 373 /// Transform simple list of Cases into list of CaseRange's. 374 unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 375 unsigned numCmps = 0; 376 377 // Start with "simple" cases 378 for (auto Case : SI->cases()) 379 Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), 380 Case.getCaseSuccessor())); 381 382 llvm::sort(Cases.begin(), Cases.end(), CaseCmp()); 383 384 // Merge case into clusters 385 if (Cases.size() >= 2) { 386 CaseItr I = Cases.begin(); 387 for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { 388 int64_t nextValue = J->Low->getSExtValue(); 389 int64_t currentValue = I->High->getSExtValue(); 390 BasicBlock* nextBB = J->BB; 391 BasicBlock* currentBB = I->BB; 392 393 // If the two neighboring cases go to the same destination, merge them 394 // into a single case. 395 assert(nextValue > currentValue && "Cases should be strictly ascending"); 396 if ((nextValue == currentValue + 1) && (currentBB == nextBB)) { 397 I->High = J->High; 398 // FIXME: Combine branch weights. 399 } else if (++I != J) { 400 *I = *J; 401 } 402 } 403 Cases.erase(std::next(I), Cases.end()); 404 } 405 406 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 407 if (I->Low != I->High) 408 // A range counts double, since it requires two compares. 409 ++numCmps; 410 } 411 412 return numCmps; 413 } 414 415 /// Replace the specified switch instruction with a sequence of chained if-then 416 /// insts in a balanced binary search. 417 void LowerSwitch::processSwitchInst(SwitchInst *SI, 418 SmallPtrSetImpl<BasicBlock*> &DeleteList) { 419 BasicBlock *CurBlock = SI->getParent(); 420 BasicBlock *OrigBlock = CurBlock; 421 Function *F = CurBlock->getParent(); 422 Value *Val = SI->getCondition(); // The value we are switching on... 423 BasicBlock* Default = SI->getDefaultDest(); 424 425 // Don't handle unreachable blocks. If there are successors with phis, this 426 // would leave them behind with missing predecessors. 427 if ((CurBlock != &F->getEntryBlock() && pred_empty(CurBlock)) || 428 CurBlock->getSinglePredecessor() == CurBlock) { 429 DeleteList.insert(CurBlock); 430 return; 431 } 432 433 // If there is only the default destination, just branch. 434 if (!SI->getNumCases()) { 435 BranchInst::Create(Default, CurBlock); 436 SI->eraseFromParent(); 437 return; 438 } 439 440 // Prepare cases vector. 441 CaseVector Cases; 442 unsigned numCmps = Clusterify(Cases, SI); 443 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 444 << ". Total compares: " << numCmps << "\n"); 445 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 446 (void)numCmps; 447 448 ConstantInt *LowerBound = nullptr; 449 ConstantInt *UpperBound = nullptr; 450 std::vector<IntRange> UnreachableRanges; 451 452 if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { 453 // Make the bounds tightly fitted around the case value range, because we 454 // know that the value passed to the switch must be exactly one of the case 455 // values. 456 assert(!Cases.empty()); 457 LowerBound = Cases.front().Low; 458 UpperBound = Cases.back().High; 459 460 DenseMap<BasicBlock *, unsigned> Popularity; 461 unsigned MaxPop = 0; 462 BasicBlock *PopSucc = nullptr; 463 464 IntRange R = {std::numeric_limits<int64_t>::min(), 465 std::numeric_limits<int64_t>::max()}; 466 UnreachableRanges.push_back(R); 467 for (const auto &I : Cases) { 468 int64_t Low = I.Low->getSExtValue(); 469 int64_t High = I.High->getSExtValue(); 470 471 IntRange &LastRange = UnreachableRanges.back(); 472 if (LastRange.Low == Low) { 473 // There is nothing left of the previous range. 474 UnreachableRanges.pop_back(); 475 } else { 476 // Terminate the previous range. 477 assert(Low > LastRange.Low); 478 LastRange.High = Low - 1; 479 } 480 if (High != std::numeric_limits<int64_t>::max()) { 481 IntRange R = { High + 1, std::numeric_limits<int64_t>::max() }; 482 UnreachableRanges.push_back(R); 483 } 484 485 // Count popularity. 486 int64_t N = High - Low + 1; 487 unsigned &Pop = Popularity[I.BB]; 488 if ((Pop += N) > MaxPop) { 489 MaxPop = Pop; 490 PopSucc = I.BB; 491 } 492 } 493 #ifndef NDEBUG 494 /* UnreachableRanges should be sorted and the ranges non-adjacent. */ 495 for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); 496 I != E; ++I) { 497 assert(I->Low <= I->High); 498 auto Next = I + 1; 499 if (Next != E) { 500 assert(Next->Low > I->High); 501 } 502 } 503 #endif 504 505 // Use the most popular block as the new default, reducing the number of 506 // cases. 507 assert(MaxPop > 0 && PopSucc); 508 Default = PopSucc; 509 Cases.erase( 510 llvm::remove_if( 511 Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), 512 Cases.end()); 513 514 // If there are no cases left, just branch. 515 if (Cases.empty()) { 516 BranchInst::Create(Default, CurBlock); 517 SI->eraseFromParent(); 518 return; 519 } 520 } 521 522 // Create a new, empty default block so that the new hierarchy of 523 // if-then statements go to this and the PHI nodes are happy. 524 BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 525 F->getBasicBlockList().insert(Default->getIterator(), NewDefault); 526 BranchInst::Create(Default, NewDefault); 527 528 // If there is an entry in any PHI nodes for the default edge, make sure 529 // to update them as well. 530 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 531 PHINode *PN = cast<PHINode>(I); 532 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 533 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 534 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 535 } 536 537 BasicBlock *SwitchBlock = 538 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 539 OrigBlock, OrigBlock, NewDefault, UnreachableRanges); 540 541 // Branch to our shiny new if-then stuff... 542 BranchInst::Create(SwitchBlock, OrigBlock); 543 544 // We are now done with the switch instruction, delete it. 545 BasicBlock *OldDefault = SI->getDefaultDest(); 546 CurBlock->getInstList().erase(SI); 547 548 // If the Default block has no more predecessors just add it to DeleteList. 549 if (pred_begin(OldDefault) == pred_end(OldDefault)) 550 DeleteList.insert(OldDefault); 551 } 552