//===- CallSiteSplitting.cpp ----------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements a transformation that tries to split a call-site to pass
// more constrained arguments if its argument is predicated in the control flow
// so that we can expose better context to the later passes (e.g., inliner, jump
// threading, or IPA-CP based function cloning, etc.).
// As of now we support two cases :
//
// 1) Try to split a call-site with constrained arguments, if any constraints
// on any argument can be found by following the single predecessors of the
// call site's predecessors. Currently this pass only handles call-sites with 2
// predecessors. For example, in the code below, we try to split the call-site
// since we can predicate the argument(ptr) based on the OR condition.
21 // 22 // Split from : 23 // if (!ptr || c) 24 // callee(ptr); 25 // to : 26 // if (!ptr) 27 // callee(null) // set the known constant value 28 // else if (c) 29 // callee(nonnull ptr) // set non-null attribute in the argument 30 // 31 // 2) We can also split a call-site based on constant incoming values of a PHI 32 // For example, 33 // from : 34 // Header: 35 // %c = icmp eq i32 %i1, %i2 36 // br i1 %c, label %Tail, label %TBB 37 // TBB: 38 // br label Tail% 39 // Tail: 40 // %p = phi i32 [ 0, %Header], [ 1, %TBB] 41 // call void @bar(i32 %p) 42 // to 43 // Header: 44 // %c = icmp eq i32 %i1, %i2 45 // br i1 %c, label %Tail-split0, label %TBB 46 // TBB: 47 // br label %Tail-split1 48 // Tail-split0: 49 // call void @bar(i32 0) 50 // br label %Tail 51 // Tail-split1: 52 // call void @bar(i32 1) 53 // br label %Tail 54 // Tail: 55 // %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ] 56 // 57 //===----------------------------------------------------------------------===// 58 59 #include "llvm/Transforms/Scalar/CallSiteSplitting.h" 60 #include "llvm/ADT/Statistic.h" 61 #include "llvm/Analysis/TargetLibraryInfo.h" 62 #include "llvm/IR/IntrinsicInst.h" 63 #include "llvm/IR/PatternMatch.h" 64 #include "llvm/Support/Debug.h" 65 #include "llvm/Transforms/Scalar.h" 66 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 67 #include "llvm/Transforms/Utils/Local.h" 68 69 using namespace llvm; 70 using namespace PatternMatch; 71 72 #define DEBUG_TYPE "callsite-splitting" 73 74 STATISTIC(NumCallSiteSplit, "Number of call-site split"); 75 76 static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI, 77 Value *Op) { 78 CallSite CS(NewCallI); 79 unsigned ArgNo = 0; 80 for (auto &I : CS.args()) { 81 if (&*I == Op) 82 CS.addParamAttr(ArgNo, Attribute::NonNull); 83 ++ArgNo; 84 } 85 } 86 87 static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI, 88 Value *Op, Constant *ConstValue) { 89 CallSite CS(NewCallI); 90 unsigned ArgNo = 0; 91 for (auto 
&I : CS.args()) { 92 if (&*I == Op) 93 CS.setArgument(ArgNo, ConstValue); 94 ++ArgNo; 95 } 96 } 97 98 static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) { 99 assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand."); 100 Value *Op0 = Cmp->getOperand(0); 101 unsigned ArgNo = 0; 102 for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; 103 ++I, ++ArgNo) { 104 // Don't consider constant or arguments that are already known non-null. 105 if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull)) 106 continue; 107 108 if (*I == Op0) 109 return true; 110 } 111 return false; 112 } 113 114 /// If From has a conditional jump to To, add the condition to Conditions, 115 /// if it is relevant to any argument at CS. 116 static void 117 recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To, 118 SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { 119 auto *BI = dyn_cast<BranchInst>(From->getTerminator()); 120 if (!BI || !BI->isConditional()) 121 return; 122 123 CmpInst::Predicate Pred; 124 Value *Cond = BI->getCondition(); 125 if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) 126 return; 127 128 ICmpInst *Cmp = cast<ICmpInst>(Cond); 129 if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) 130 if (isCondRelevantToAnyCallArgument(Cmp, CS)) 131 Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To 132 ? Pred 133 : Cmp->getInversePredicate()}); 134 } 135 136 /// Record ICmp conditions relevant to any argument in CS following Pred's 137 /// single successors. If there are conflicting conditions along a path, like 138 /// x == 1 and x == 0, the first condition will be used. 
139 static void 140 recordConditions(const CallSite &CS, BasicBlock *Pred, 141 SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { 142 recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); 143 BasicBlock *From = Pred; 144 BasicBlock *To = Pred; 145 SmallPtrSet<BasicBlock *, 4> Visited; 146 while (!Visited.count(From->getSinglePredecessor()) && 147 (From = From->getSinglePredecessor())) { 148 recordCondition(CS, From, To, Conditions); 149 Visited.insert(From); 150 To = From; 151 } 152 } 153 154 static Instruction * 155 addConditions(CallSite &CS, 156 SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { 157 if (Conditions.empty()) 158 return nullptr; 159 160 Instruction *NewCI = CS.getInstruction()->clone(); 161 for (auto &Cond : Conditions) { 162 Value *Arg = Cond.first->getOperand(0); 163 Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1)); 164 if (Cond.second == ICmpInst::ICMP_EQ) 165 setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal); 166 else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { 167 assert(Cond.second == ICmpInst::ICMP_NE); 168 addNonNullAttribute(CS.getInstruction(), NewCI, Arg); 169 } 170 } 171 return NewCI; 172 } 173 174 static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) { 175 SmallVector<BasicBlock *, 2> Preds(predecessors((BB))); 176 assert(Preds.size() == 2 && "Expected exactly 2 predecessors!"); 177 return Preds; 178 } 179 180 static bool canSplitCallSite(CallSite CS) { 181 // FIXME: As of now we handle only CallInst. InvokeInst could be handled 182 // without too much effort. 183 Instruction *Instr = CS.getInstruction(); 184 if (!isa<CallInst>(Instr)) 185 return false; 186 187 // Allow splitting a call-site only when there is no instruction before the 188 // call-site in the basic block. Based on this constraint, we only clone the 189 // call instruction, and we do not move a call-site across any other 190 // instruction. 
191 BasicBlock *CallSiteBB = Instr->getParent(); 192 if (Instr != CallSiteBB->getFirstNonPHIOrDbg()) 193 return false; 194 195 // Need 2 predecessors and cannot split an edge from an IndirectBrInst. 196 SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB)); 197 if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) || 198 isa<IndirectBrInst>(Preds[1]->getTerminator())) 199 return false; 200 201 return CallSiteBB->canSplitPredecessors(); 202 } 203 204 static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before, 205 Value *V) { 206 Instruction *Copy = I->clone(); 207 Copy->setName(I->getName()); 208 Copy->insertBefore(Before); 209 if (V) 210 Copy->setOperand(0, V); 211 return Copy; 212 } 213 214 /// Copy mandatory `musttail` return sequence that follows original `CI`, and 215 /// link it up to `NewCI` value instead: 216 /// 217 /// * (optional) `bitcast NewCI to ...` 218 /// * `ret bitcast or NewCI` 219 /// 220 /// Insert this sequence right before `SplitBB`'s terminator, which will be 221 /// cleaned up later in `splitCallSite` below. 222 static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI, 223 Instruction *NewCI) { 224 bool IsVoid = SplitBB->getParent()->getReturnType()->isVoidTy(); 225 auto II = std::next(CI->getIterator()); 226 227 BitCastInst *BCI = dyn_cast<BitCastInst>(&*II); 228 if (BCI) 229 ++II; 230 231 ReturnInst *RI = dyn_cast<ReturnInst>(&*II); 232 assert(RI && "`musttail` call must be followed by `ret` instruction"); 233 234 TerminatorInst *TI = SplitBB->getTerminator(); 235 Value *V = NewCI; 236 if (BCI) 237 V = cloneInstForMustTail(BCI, TI, V); 238 cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V); 239 240 // FIXME: remove TI here, `DuplicateInstructionsInSplitBetween` has a bug 241 // that prevents doing this now. 242 } 243 244 /// Return true if the CS is split into its new predecessors which are directly 245 /// hooked to each of its original predecessors pointed by PredBB1 and PredBB2. 
/// CallInst1 and CallInst2 will be the new call-sites placed in the new
/// predecessors split for PredBB1 and PredBB2, respectively.
/// For example, in the IR below with an OR condition, the call-site can
/// be split. Assuming PredBB1=Header and PredBB2=TBB, CallInst1 will be the
/// call-site placed between Header and Tail, and CallInst2 will be the
/// call-site between TBB and Tail.
///
/// From :
///
///   Header:
///     %c = icmp eq i32* %a, null
///     br i1 %c %Tail, %TBB
///   TBB:
///     %c2 = icmp eq i32* %b, null
///     br i1 %c2 %Tail, %End
///   Tail:
///     %ca = call i1 @callee (i32* %a, i32* %b)
///
///  to :
///
///   Header:                          // PredBB1 is Header
///     %c = icmp eq i32* %a, null
///     br i1 %c %Tail-split1, %TBB
///   TBB:                             // PredBB2 is TBB
///     %c2 = icmp eq i32* %b, null
///     br i1 %c2 %Tail-split2, %End
///   Tail-split1:
///     %ca1 = call @callee (i32* null, i32* %b)         // CallInst1
///     br %Tail
///   Tail-split2:
///     %ca2 = call @callee (i32* nonnull %a, i32* null) // CallInst2
///     br %Tail
///   Tail:
///     %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2]
///
/// Note that in case any arguments at the call-site are constrained by its
/// predecessors, new call-sites with more constrained arguments will be
/// created by addConditions(), called from tryToSplitOnPredicatedArgument().
284 static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2, 285 Instruction *CallInst1, Instruction *CallInst2) { 286 Instruction *Instr = CS.getInstruction(); 287 BasicBlock *TailBB = Instr->getParent(); 288 bool IsMustTailCall = CS.isMustTailCall(); 289 assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site"); 290 291 BasicBlock *SplitBlock1 = 292 SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split"); 293 BasicBlock *SplitBlock2 = 294 SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split"); 295 296 assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split."); 297 298 if (!CallInst1) 299 CallInst1 = Instr->clone(); 300 if (!CallInst2) 301 CallInst2 = Instr->clone(); 302 303 CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt()); 304 CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt()); 305 306 CallSite CS1(CallInst1); 307 CallSite CS2(CallInst2); 308 309 // Handle PHIs used as arguments in the call-site. 310 for (PHINode &PN : TailBB->phis()) { 311 unsigned ArgNo = 0; 312 for (auto &CI : CS.args()) { 313 if (&*CI == &PN) { 314 CS1.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock1)); 315 CS2.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock2)); 316 } 317 ++ArgNo; 318 } 319 } 320 // Clone and place bitcast and return instructions before `TI` 321 if (IsMustTailCall) { 322 copyMustTailReturn(SplitBlock1, CS.getInstruction(), CallInst1); 323 copyMustTailReturn(SplitBlock2, CS.getInstruction(), CallInst2); 324 } 325 326 // Replace users of the original call with a PHI mering call-sites split. 
327 if (!IsMustTailCall && Instr->getNumUses()) { 328 PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call", 329 TailBB->getFirstNonPHI()); 330 PN->addIncoming(CallInst1, SplitBlock1); 331 PN->addIncoming(CallInst2, SplitBlock2); 332 Instr->replaceAllUsesWith(PN); 333 } 334 DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); 335 DEBUG(dbgs() << " " << *CallInst1 << " in " << SplitBlock1->getName() 336 << "\n"); 337 DEBUG(dbgs() << " " << *CallInst2 << " in " << SplitBlock2->getName() 338 << "\n"); 339 340 NumCallSiteSplit++; 341 342 // FIXME: remove TI in `copyMustTailReturn` 343 if (IsMustTailCall) { 344 // Remove superfluous `br` terminators from the end of the Split blocks 345 // NOTE: Removing terminator removes the SplitBlock from the TailBB's 346 // predecessors. Therefore we must get complete list of Splits before 347 // attempting removal. 348 SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB))); 349 assert(Splits.size() == 2 && "Expected exactly 2 splits!"); 350 for (unsigned i = 0; i < Splits.size(); i++) 351 Splits[i]->getTerminator()->eraseFromParent(); 352 353 // Erase the tail block once done with musttail patching 354 TailBB->eraseFromParent(); 355 return; 356 } 357 Instr->eraseFromParent(); 358 } 359 360 // Return true if the call-site has an argument which is a PHI with only 361 // constant incoming values. 
362 static bool isPredicatedOnPHI(CallSite CS) { 363 Instruction *Instr = CS.getInstruction(); 364 BasicBlock *Parent = Instr->getParent(); 365 if (Instr != Parent->getFirstNonPHIOrDbg()) 366 return false; 367 368 for (auto &BI : *Parent) { 369 if (PHINode *PN = dyn_cast<PHINode>(&BI)) { 370 for (auto &I : CS.args()) 371 if (&*I == PN) { 372 assert(PN->getNumIncomingValues() == 2 && 373 "Unexpected number of incoming values"); 374 if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1)) 375 return false; 376 if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) 377 continue; 378 if (isa<Constant>(PN->getIncomingValue(0)) && 379 isa<Constant>(PN->getIncomingValue(1))) 380 return true; 381 } 382 } 383 break; 384 } 385 return false; 386 } 387 388 static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { 389 if (!isPredicatedOnPHI(CS)) 390 return false; 391 392 auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); 393 splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr); 394 return true; 395 } 396 397 static bool tryToSplitOnPredicatedArgument(CallSite CS) { 398 auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); 399 if (Preds[0] == Preds[1]) 400 return false; 401 402 SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2; 403 recordConditions(CS, Preds[0], C1); 404 recordConditions(CS, Preds[1], C2); 405 406 Instruction *CallInst1 = addConditions(CS, C1); 407 Instruction *CallInst2 = addConditions(CS, C2); 408 if (!CallInst1 && !CallInst2) 409 return false; 410 411 splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1); 412 return true; 413 } 414 415 static bool tryToSplitCallSite(CallSite CS) { 416 if (!CS.arg_size() || !canSplitCallSite(CS)) 417 return false; 418 return tryToSplitOnPredicatedArgument(CS) || 419 tryToSplitOnPHIPredicatedArgument(CS); 420 } 421 422 static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) { 423 bool Changed = false; 424 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) 
{ 425 BasicBlock &BB = *BI++; 426 for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { 427 Instruction *I = &*II++; 428 CallSite CS(cast<Value>(I)); 429 if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI)) 430 continue; 431 432 Function *Callee = CS.getCalledFunction(); 433 if (!Callee || Callee->isDeclaration()) 434 continue; 435 436 // Successful musttail call-site splits result in erased CI and erased BB. 437 // Check if such path is possible before attempting the splitting. 438 bool IsMustTail = CS.isMustTailCall(); 439 440 Changed |= tryToSplitCallSite(CS); 441 442 // There're no interesting instructions after this. The call site 443 // itself might have been erased on splitting. 444 if (IsMustTail) 445 break; 446 } 447 } 448 return Changed; 449 } 450 451 namespace { 452 struct CallSiteSplittingLegacyPass : public FunctionPass { 453 static char ID; 454 CallSiteSplittingLegacyPass() : FunctionPass(ID) { 455 initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); 456 } 457 458 void getAnalysisUsage(AnalysisUsage &AU) const override { 459 AU.addRequired<TargetLibraryInfoWrapperPass>(); 460 FunctionPass::getAnalysisUsage(AU); 461 } 462 463 bool runOnFunction(Function &F) override { 464 if (skipFunction(F)) 465 return false; 466 467 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 468 return doCallSiteSplitting(F, TLI); 469 } 470 }; 471 } // namespace 472 473 char CallSiteSplittingLegacyPass::ID = 0; 474 INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", 475 "Call-site splitting", false, false) 476 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 477 INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", 478 "Call-site splitting", false, false) 479 FunctionPass *llvm::createCallSiteSplittingPass() { 480 return new CallSiteSplittingLegacyPass(); 481 } 482 483 PreservedAnalyses CallSiteSplittingPass::run(Function &F, 484 FunctionAnalysisManager 
&AM) { 485 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); 486 487 if (!doCallSiteSplitting(F, TLI)) 488 return PreservedAnalyses::all(); 489 PreservedAnalyses PA; 490 return PA; 491 } 492