1 //===- CallSiteSplitting.cpp ----------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a transformation that tries to split a call-site to pass
11 // more constrained arguments if its argument is predicated in the control flow
12 // so that we can expose better context to the later passes (e.g, inliner, jump
13 // threading, or IPA-CP based function cloning, etc.).
14 // As of now we support two cases :
15 //
16 // 1) If a call site is dominated by an OR condition and if any of its arguments
17 // are predicated on this OR condition, try to split the condition with more
18 // constrained arguments. For example, in the code below, we try to split the
19 // call site since we can predicate the argument(ptr) based on the OR condition.
20 //
21 // Split from :
22 //   if (!ptr || c)
23 //     callee(ptr);
24 // to :
25 //   if (!ptr)
26 //     callee(null)         // set the known constant value
27 //   else if (c)
28 //     callee(nonnull ptr)  // set non-null attribute in the argument
29 //
30 // 2) We can also split a call-site based on constant incoming values of a PHI
31 // For example,
32 // from :
33 //   Header:
34 //    %c = icmp eq i32 %i1, %i2
35 //    br i1 %c, label %Tail, label %TBB
36 //   TBB:
//    br label %Tail
38 //   Tail:
39 //    %p = phi i32 [ 0, %Header], [ 1, %TBB]
40 //    call void @bar(i32 %p)
41 // to
42 //   Header:
43 //    %c = icmp eq i32 %i1, %i2
44 //    br i1 %c, label %Tail-split0, label %TBB
45 //   TBB:
46 //    br label %Tail-split1
47 //   Tail-split0:
48 //    call void @bar(i32 0)
49 //    br label %Tail
50 //   Tail-split1:
51 //    call void @bar(i32 1)
52 //    br label %Tail
53 //   Tail:
54 //    %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ]
55 //
56 //===----------------------------------------------------------------------===//
57 
58 #include "llvm/Transforms/Scalar/CallSiteSplitting.h"
59 #include "llvm/ADT/Statistic.h"
60 #include "llvm/Analysis/TargetLibraryInfo.h"
61 #include "llvm/IR/IntrinsicInst.h"
62 #include "llvm/IR/PatternMatch.h"
63 #include "llvm/Support/Debug.h"
64 #include "llvm/Transforms/Scalar.h"
65 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
66 #include "llvm/Transforms/Utils/Local.h"
67 
68 using namespace llvm;
69 using namespace PatternMatch;
70 
71 #define DEBUG_TYPE "callsite-splitting"
72 
73 STATISTIC(NumCallSiteSplit, "Number of call-site split");
74 
75 static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
76                                 Value *Op) {
77   if (!NewCallI)
78     NewCallI = CallI->clone();
79   CallSite CS(NewCallI);
80   unsigned ArgNo = 0;
81   for (auto &I : CS.args()) {
82     if (&*I == Op)
83       CS.addParamAttr(ArgNo, Attribute::NonNull);
84     ++ArgNo;
85   }
86 }
87 
88 static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI,
89                                   Value *Op, Constant *ConstValue) {
90   if (!NewCallI)
91     NewCallI = CallI->clone();
92   CallSite CS(NewCallI);
93   unsigned ArgNo = 0;
94   for (auto &I : CS.args()) {
95     if (&*I == Op)
96       CS.setArgument(ArgNo, ConstValue);
97     ++ArgNo;
98   }
99 }
100 
101 static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
102   assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
103   Value *Op0 = Cmp->getOperand(0);
104   unsigned ArgNo = 0;
105   for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
106        ++I, ++ArgNo) {
107     // Don't consider constant or arguments that are already known non-null.
108     if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull))
109       continue;
110 
111     if (*I == Op0)
112       return true;
113   }
114   return false;
115 }
116 
117 static SmallVector<BranchInst *, 2>
118 findOrCondRelevantToCallArgument(CallSite CS) {
119   SmallVector<BranchInst *, 2> BranchInsts;
120   for (auto PredBB : predecessors(CS.getInstruction()->getParent())) {
121     auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator());
122     if (!PBI || !PBI->isConditional())
123       continue;
124 
125     CmpInst::Predicate Pred;
126     Value *Cond = PBI->getCondition();
127     if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
128       continue;
129     ICmpInst *Cmp = cast<ICmpInst>(Cond);
130     if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
131       if (isCondRelevantToAnyCallArgument(Cmp, CS))
132         BranchInsts.push_back(PBI);
133   }
134   return BranchInsts;
135 }
136 
137 static bool tryCreateCallSitesOnOrPredicatedArgument(
138     CallSite CS, Instruction *&NewCSTakenFromHeader,
139     Instruction *&NewCSTakenFromNextCond, BasicBlock *HeaderBB) {
140   auto BranchInsts = findOrCondRelevantToCallArgument(CS);
141   assert(BranchInsts.size() <= 2 &&
142          "Unexpected number of blocks in the OR predicated condition");
143   Instruction *Instr = CS.getInstruction();
144   BasicBlock *CallSiteBB = Instr->getParent();
145   TerminatorInst *HeaderTI = HeaderBB->getTerminator();
146   bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0);
147 
148   for (auto *PBI : BranchInsts) {
149     assert(isa<ICmpInst>(PBI->getCondition()) &&
150            "Unexpected condition in a conditional branch.");
151     ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition());
152     Value *Arg = Cmp->getOperand(0);
153     assert(isa<Constant>(Cmp->getOperand(1)) &&
154            "Expected op1 to be a constant.");
155     Constant *ConstVal = cast<Constant>(Cmp->getOperand(1));
156     CmpInst::Predicate Pred = Cmp->getPredicate();
157 
158     if (PBI->getParent() == HeaderBB) {
159       Instruction *&CallTakenFromHeader =
160           IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond;
161       Instruction *&CallUntakenFromHeader =
162           IsCSInTakenPath ? NewCSTakenFromNextCond : NewCSTakenFromHeader;
163 
164       assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
165              "Unexpected predicate in an OR condition");
166 
167       // Set the constant value for agruments in the call predicated based on
168       // the OR condition.
169       Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ
170                                          ? CallTakenFromHeader
171                                          : CallUntakenFromHeader;
172       setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal);
173 
174       // Add the NonNull attribute if compared with the null pointer.
175       if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
176         Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ
177                                           ? CallUntakenFromHeader
178                                           : CallTakenFromHeader;
179         addNonNullAttribute(Instr, CallToSetAttr, Arg);
180       }
181       continue;
182     }
183 
184     if (Pred == ICmpInst::ICMP_EQ) {
185       if (PBI->getSuccessor(0) == Instr->getParent()) {
186         // Set the constant value for the call taken from the second block in
187         // the OR condition.
188         setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
189       } else {
190         // Add the NonNull attribute if compared with the null pointer for the
191         // call taken from the second block in the OR condition.
192         if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
193           addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
194       }
195     } else {
196       if (PBI->getSuccessor(0) == Instr->getParent()) {
197         // Add the NonNull attribute if compared with the null pointer for the
198         // call taken from the second block in the OR condition.
199         if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue())
200           addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg);
201       } else if (Pred == ICmpInst::ICMP_NE) {
202         // Set the constant value for the call in the untaken path from the
203         // header block.
204         setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal);
205       } else
206         llvm_unreachable("Unexpected condition");
207     }
208   }
209   return NewCSTakenFromHeader || NewCSTakenFromNextCond;
210 }
211 
212 static bool canSplitCallSite(CallSite CS) {
213   // FIXME: As of now we handle only CallInst. InvokeInst could be handled
214   // without too much effort.
215   Instruction *Instr = CS.getInstruction();
216   if (!isa<CallInst>(Instr))
217     return false;
218 
219   // Allow splitting a call-site only when there is no instruction before the
220   // call-site in the basic block. Based on this constraint, we only clone the
221   // call instruction, and we do not move a call-site across any other
222   // instruction.
223   BasicBlock *CallSiteBB = Instr->getParent();
224   if (Instr != CallSiteBB->getFirstNonPHIOrDbg())
225     return false;
226 
227   // Need 2 predecessors and cannot split an edge from an IndirectBrInst.
228   SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB));
229   if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) ||
230       isa<IndirectBrInst>(Preds[1]->getTerminator()))
231     return false;
232 
233   return CallSiteBB->canSplitPredecessors();
234 }
235 
236 /// Return true if the CS is split into its new predecessors which are directly
237 /// hooked to each of its orignial predecessors pointed by PredBB1 and PredBB2.
238 /// In OR predicated case, PredBB1 will point the header, and PredBB2 will point
239 /// to the second compare block. CallInst1 and CallInst2 will be the new
240 /// call-sites placed in the new predecessors split for PredBB1 and PredBB2,
241 /// repectively. Therefore, CallInst1 will be the call-site placed
242 /// between Header and Tail, and CallInst2 will be the call-site between TBB and
243 /// Tail. For example, in the IR below with an OR condition, the call-site can
244 /// be split
245 ///
246 /// from :
247 ///
248 ///   Header:
249 ///     %c = icmp eq i32* %a, null
250 ///     br i1 %c %Tail, %TBB
251 ///   TBB:
252 ///     %c2 = icmp eq i32* %b, null
253 ///     br i1 %c %Tail, %End
254 ///   Tail:
255 ///     %ca = call i1  @callee (i32* %a, i32* %b)
256 ///
257 ///  to :
258 ///
259 ///   Header:                          // PredBB1 is Header
260 ///     %c = icmp eq i32* %a, null
261 ///     br i1 %c %Tail-split1, %TBB
262 ///   TBB:                             // PredBB2 is TBB
263 ///     %c2 = icmp eq i32* %b, null
264 ///     br i1 %c %Tail-split2, %End
265 ///   Tail-split1:
266 ///     %ca1 = call @callee (i32* null, i32* %b)         // CallInst1
267 ///    br %Tail
268 ///   Tail-split2:
269 ///     %ca2 = call @callee (i32* nonnull %a, i32* null) // CallInst2
270 ///    br %Tail
271 ///   Tail:
272 ///    %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2]
273 ///
274 /// Note that for an OR predicated case, CallInst1 and CallInst2 should be
275 /// created with more constrained arguments in
276 /// createCallSitesOnOrPredicatedArgument().
277 static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
278                           Instruction *CallInst1, Instruction *CallInst2) {
279   Instruction *Instr = CS.getInstruction();
280   BasicBlock *TailBB = Instr->getParent();
281   assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");
282 
283   BasicBlock *SplitBlock1 =
284       SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split");
285   BasicBlock *SplitBlock2 =
286       SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split");
287 
288   assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split.");
289 
290   if (!CallInst1)
291     CallInst1 = Instr->clone();
292   if (!CallInst2)
293     CallInst2 = Instr->clone();
294 
295   CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt());
296   CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt());
297 
298   CallSite CS1(CallInst1);
299   CallSite CS2(CallInst2);
300 
301   // Handle PHIs used as arguments in the call-site.
302   for (auto &PI : *TailBB) {
303     PHINode *PN = dyn_cast<PHINode>(&PI);
304     if (!PN)
305       break;
306     unsigned ArgNo = 0;
307     for (auto &CI : CS.args()) {
308       if (&*CI == PN) {
309         CS1.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock1));
310         CS2.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock2));
311       }
312       ++ArgNo;
313     }
314   }
315 
316   // Replace users of the original call with a PHI mering call-sites split.
317   if (Instr->getNumUses()) {
318     PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
319                                   TailBB->getFirstNonPHI());
320     PN->addIncoming(CallInst1, SplitBlock1);
321     PN->addIncoming(CallInst2, SplitBlock2);
322     Instr->replaceAllUsesWith(PN);
323   }
324   DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
325   DEBUG(dbgs() << "    " << *CallInst1 << " in " << SplitBlock1->getName()
326                << "\n");
327   DEBUG(dbgs() << "    " << *CallInst2 << " in " << SplitBlock2->getName()
328                << "\n");
329   Instr->eraseFromParent();
330   NumCallSiteSplit++;
331 }
332 
333 // Return true if the call-site has an argument which is a PHI with only
334 // constant incoming values.
335 static bool isPredicatedOnPHI(CallSite CS) {
336   Instruction *Instr = CS.getInstruction();
337   BasicBlock *Parent = Instr->getParent();
338   if (Instr != Parent->getFirstNonPHIOrDbg())
339     return false;
340 
341   for (auto &BI : *Parent) {
342     if (PHINode *PN = dyn_cast<PHINode>(&BI)) {
343       for (auto &I : CS.args())
344         if (&*I == PN) {
345           assert(PN->getNumIncomingValues() == 2 &&
346                  "Unexpected number of incoming values");
347           if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1))
348             return false;
349           if (PN->getIncomingValue(0) == PN->getIncomingValue(1))
350             continue;
351           if (isa<Constant>(PN->getIncomingValue(0)) &&
352               isa<Constant>(PN->getIncomingValue(1)))
353             return true;
354         }
355     }
356     break;
357   }
358   return false;
359 }
360 
361 static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
362   SmallVector<BasicBlock *, 2> Preds(predecessors((BB)));
363   assert(Preds.size() == 2 && "Expected exactly 2 predecessors!");
364   return Preds;
365 }
366 
367 static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) {
368   if (!isPredicatedOnPHI(CS))
369     return false;
370 
371   auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
372   splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr);
373   return true;
374 }
375 // Check if one of the predecessors is a single predecessors of the other.
376 // This is a requirement for control flow modeling an OR. HeaderBB points to
377 // the single predecessor and OrBB points to other node. HeaderBB potentially
378 // contains the first compare of the OR and OrBB the second.
379 static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) {
380   return OrBB->getSinglePredecessor() == HeaderBB &&
381          HeaderBB->getTerminator()->getNumSuccessors() == 2;
382 }
383 
384 static bool tryToSplitOnOrPredicatedArgument(CallSite CS) {
385   auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
386   BasicBlock *HeaderBB = nullptr;
387   BasicBlock *OrBB = nullptr;
388   if (isOrHeader(Preds[0], Preds[1])) {
389     HeaderBB = Preds[0];
390     OrBB = Preds[1];
391   } else if (isOrHeader(Preds[1], Preds[0])) {
392     HeaderBB = Preds[1];
393     OrBB = Preds[0];
394   } else
395     return false;
396 
397   Instruction *CallInst1 = nullptr;
398   Instruction *CallInst2 = nullptr;
399   if (!tryCreateCallSitesOnOrPredicatedArgument(CS, CallInst1, CallInst2,
400                                                 HeaderBB)) {
401     assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned.");
402     return false;
403   }
404 
405   splitCallSite(CS, HeaderBB, OrBB, CallInst1, CallInst2);
406   return true;
407 }
408 
409 static bool tryToSplitCallSite(CallSite CS) {
410   if (!CS.arg_size() || !canSplitCallSite(CS))
411     return false;
412   return tryToSplitOnOrPredicatedArgument(CS) ||
413          tryToSplitOnPHIPredicatedArgument(CS);
414 }
415 
416 static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) {
417   bool Changed = false;
418   for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) {
419     BasicBlock &BB = *BI++;
420     for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) {
421       Instruction *I = &*II++;
422       CallSite CS(cast<Value>(I));
423       if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI))
424         continue;
425 
426       Function *Callee = CS.getCalledFunction();
427       if (!Callee || Callee->isDeclaration())
428         continue;
429       Changed |= tryToSplitCallSite(CS);
430     }
431   }
432   return Changed;
433 }
434 
435 namespace {
436 struct CallSiteSplittingLegacyPass : public FunctionPass {
437   static char ID;
438   CallSiteSplittingLegacyPass() : FunctionPass(ID) {
439     initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry());
440   }
441 
442   void getAnalysisUsage(AnalysisUsage &AU) const override {
443     AU.addRequired<TargetLibraryInfoWrapperPass>();
444     FunctionPass::getAnalysisUsage(AU);
445   }
446 
447   bool runOnFunction(Function &F) override {
448     if (skipFunction(F))
449       return false;
450 
451     auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
452     return doCallSiteSplitting(F, TLI);
453   }
454 };
455 } // namespace
456 
457 char CallSiteSplittingLegacyPass::ID = 0;
458 INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting",
459                       "Call-site splitting", false, false)
460 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
461 INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting",
462                     "Call-site splitting", false, false)
463 FunctionPass *llvm::createCallSiteSplittingPass() {
464   return new CallSiteSplittingLegacyPass();
465 }
466 
467 PreservedAnalyses CallSiteSplittingPass::run(Function &F,
468                                              FunctionAnalysisManager &AM) {
469   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
470 
471   if (!doCallSiteSplitting(F, TLI))
472     return PreservedAnalyses::all();
473   PreservedAnalyses PA;
474   return PA;
475 }
476