1 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The LowerSwitch transformation rewrites switch instructions with a sequence 11 // of branches, which allows targets to get away with not implementing the 12 // switch instruction until it is convenient. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Scalar.h" 17 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 18 #include "llvm/Constants.h" 19 #include "llvm/Function.h" 20 #include "llvm/Instructions.h" 21 #include "llvm/LLVMContext.h" 22 #include "llvm/Pass.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/Support/Compiler.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/raw_ostream.h" 27 #include <algorithm> 28 using namespace llvm; 29 30 namespace { 31 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 32 /// instructions. 33 class LowerSwitch : public FunctionPass { 34 public: 35 static char ID; // Pass identification, replacement for typeid 36 LowerSwitch() : FunctionPass(ID) {} 37 38 virtual bool runOnFunction(Function &F); 39 40 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 41 // This is a cluster of orthogonal Transforms 42 AU.addPreserved<UnifyFunctionExitNodes>(); 43 AU.addPreserved("mem2reg"); 44 AU.addPreservedID(LowerInvokePassID); 45 } 46 47 struct CaseRange { 48 Constant* Low; 49 Constant* High; 50 BasicBlock* BB; 51 52 CaseRange(Constant *low = 0, Constant *high = 0, BasicBlock *bb = 0) : 53 Low(low), High(high), BB(bb) { } 54 }; 55 56 typedef std::vector<CaseRange> CaseVector; 57 typedef std::vector<CaseRange>::iterator CaseItr; 58 private: 59 void processSwitchInst(SwitchInst *SI); 60 61 BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val, 62 BasicBlock* OrigBlock, BasicBlock* Default); 63 BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val, 64 BasicBlock* OrigBlock, BasicBlock* Default); 65 unsigned Clusterify(CaseVector& Cases, SwitchInst *SI); 66 }; 67 68 /// The comparison function for sorting the switch case values in the vector. 69 /// WARNING: Case ranges should be disjoint! 70 struct CaseCmp { 71 bool operator () (const LowerSwitch::CaseRange& C1, 72 const LowerSwitch::CaseRange& C2) { 73 74 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 75 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 76 return CI1->getValue().slt(CI2->getValue()); 77 } 78 }; 79 } 80 81 char LowerSwitch::ID = 0; 82 INITIALIZE_PASS(LowerSwitch, "lowerswitch", 83 "Lower SwitchInst's to branches", false, false); 84 85 // Publically exposed interface to pass... 86 char &llvm::LowerSwitchID = LowerSwitch::ID; 87 // createLowerSwitchPass - Interface to this file... 88 FunctionPass *llvm::createLowerSwitchPass() { 89 return new LowerSwitch(); 90 } 91 92 bool LowerSwitch::runOnFunction(Function &F) { 93 bool Changed = false; 94 95 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 96 BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 97 98 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 99 Changed = true; 100 processSwitchInst(SI); 101 } 102 } 103 104 return Changed; 105 } 106 107 // operator<< - Used for debugging purposes. 108 // 109 static raw_ostream& operator<<(raw_ostream &O, 110 const LowerSwitch::CaseVector &C) ATTRIBUTE_USED; 111 static raw_ostream& operator<<(raw_ostream &O, 112 const LowerSwitch::CaseVector &C) { 113 O << "["; 114 115 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 116 E = C.end(); B != E; ) { 117 O << *B->Low << " -" << *B->High; 118 if (++B != E) O << ", "; 119 } 120 121 return O << "]"; 122 } 123 124 // switchConvert - Convert the switch statement into a binary lookup of 125 // the case values. The function recursively builds this tree. 126 // 127 BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 128 Value* Val, BasicBlock* OrigBlock, 129 BasicBlock* Default) 130 { 131 unsigned Size = End - Begin; 132 133 if (Size == 1) 134 return newLeafBlock(*Begin, Val, OrigBlock, Default); 135 136 unsigned Mid = Size / 2; 137 std::vector<CaseRange> LHS(Begin, Begin + Mid); 138 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 139 std::vector<CaseRange> RHS(Begin + Mid, End); 140 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 141 142 CaseRange& Pivot = *(Begin + Mid); 143 DEBUG(dbgs() << "Pivot ==> " 144 << cast<ConstantInt>(Pivot.Low)->getValue() << " -" 145 << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 146 147 BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, 148 OrigBlock, Default); 149 BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val, 150 OrigBlock, Default); 151 152 // Create a new node that checks if the value is < pivot. Go to the 153 // left branch if it is and right branch if not. 154 Function* F = OrigBlock->getParent(); 155 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 156 Function::iterator FI = OrigBlock; 157 F->getBasicBlockList().insert(++FI, NewNode); 158 159 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 160 Val, Pivot.Low, "Pivot"); 161 NewNode->getInstList().push_back(Comp); 162 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 163 return NewNode; 164 } 165 166 // newLeafBlock - Create a new leaf block for the binary lookup tree. It 167 // checks if the switch's value == the case's value. If not, then it 168 // jumps to the default branch. At this point in the tree, the value 169 // can't be another valid case value, so the jump to the "default" branch 170 // is warranted. 171 // 172 BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 173 BasicBlock* OrigBlock, 174 BasicBlock* Default) 175 { 176 Function* F = OrigBlock->getParent(); 177 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 178 Function::iterator FI = OrigBlock; 179 F->getBasicBlockList().insert(++FI, NewLeaf); 180 181 // Emit comparison 182 ICmpInst* Comp = NULL; 183 if (Leaf.Low == Leaf.High) { 184 // Make the seteq instruction... 185 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 186 Leaf.Low, "SwitchLeaf"); 187 } else { 188 // Make range comparison 189 if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 190 // Val >= Min && Val <= Hi --> Val <= Hi 191 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 192 "SwitchLeaf"); 193 } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 194 // Val >= 0 && Val <= Hi --> Val <=u Hi 195 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 196 "SwitchLeaf"); 197 } else { 198 // Emit V-Lo <=u Hi-Lo 199 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 200 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 201 Val->getName()+".off", 202 NewLeaf); 203 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 204 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 205 "SwitchLeaf"); 206 } 207 } 208 209 // Make the conditional branch... 210 BasicBlock* Succ = Leaf.BB; 211 BranchInst::Create(Succ, Default, Comp, NewLeaf); 212 213 // If there were any PHI nodes in this successor, rewrite one entry 214 // from OrigBlock to come from NewLeaf. 215 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 216 PHINode* PN = cast<PHINode>(I); 217 // Remove all but one incoming entries from the cluster 218 uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 219 cast<ConstantInt>(Leaf.Low)->getSExtValue(); 220 for (uint64_t j = 0; j < Range; ++j) { 221 PN->removeIncomingValue(OrigBlock); 222 } 223 224 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 225 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 226 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 227 } 228 229 return NewLeaf; 230 } 231 232 // Clusterify - Transform simple list of Cases into list of CaseRange's 233 unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 234 unsigned numCmps = 0; 235 236 // Start with "simple" cases 237 for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) 238 Cases.push_back(CaseRange(SI->getSuccessorValue(i), 239 SI->getSuccessorValue(i), 240 SI->getSuccessor(i))); 241 std::sort(Cases.begin(), Cases.end(), CaseCmp()); 242 243 // Merge case into clusters 244 if (Cases.size()>=2) 245 for (CaseItr I=Cases.begin(), J=llvm::next(Cases.begin()); J!=Cases.end(); ) { 246 int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 247 int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 248 BasicBlock* nextBB = J->BB; 249 BasicBlock* currentBB = I->BB; 250 251 // If the two neighboring cases go to the same destination, merge them 252 // into a single case. 253 if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 254 I->High = J->High; 255 J = Cases.erase(J); 256 } else { 257 I = J++; 258 } 259 } 260 261 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 262 if (I->Low != I->High) 263 // A range counts double, since it requires two compares. 264 ++numCmps; 265 } 266 267 return numCmps; 268 } 269 270 // processSwitchInst - Replace the specified switch instruction with a sequence 271 // of chained if-then insts in a balanced binary search. 272 // 273 void LowerSwitch::processSwitchInst(SwitchInst *SI) { 274 BasicBlock *CurBlock = SI->getParent(); 275 BasicBlock *OrigBlock = CurBlock; 276 Function *F = CurBlock->getParent(); 277 Value *Val = SI->getOperand(0); // The value we are switching on... 278 BasicBlock* Default = SI->getDefaultDest(); 279 280 // If there is only the default destination, don't bother with the code below. 281 if (SI->getNumOperands() == 2) { 282 BranchInst::Create(SI->getDefaultDest(), CurBlock); 283 CurBlock->getInstList().erase(SI); 284 return; 285 } 286 287 // Create a new, empty default block so that the new hierarchy of 288 // if-then statements go to this and the PHI nodes are happy. 289 BasicBlock* NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 290 F->getBasicBlockList().insert(Default, NewDefault); 291 292 BranchInst::Create(Default, NewDefault); 293 294 // If there is an entry in any PHI nodes for the default edge, make sure 295 // to update them as well. 296 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 297 PHINode *PN = cast<PHINode>(I); 298 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 299 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 300 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 301 } 302 303 // Prepare cases vector. 304 CaseVector Cases; 305 unsigned numCmps = Clusterify(Cases, SI); 306 307 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 308 << ". Total compares: " << numCmps << "\n"); 309 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 310 (void)numCmps; 311 312 BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, 313 OrigBlock, NewDefault); 314 315 // Branch to our shiny new if-then stuff... 316 BranchInst::Create(SwitchBlock, OrigBlock); 317 318 // We are now done with the switch instruction, delete it. 319 CurBlock->getInstList().erase(SI); 320 } 321