1 //===--- PartiallyInlineLibCalls.cpp - Partially inline libcalls ----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass tries to partially inline the fast path of well-known library
11 // functions, such as using square-root instructions for cases where sqrt()
12 // does not need to set errno.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/TargetLibraryInfo.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/Transforms/Scalar.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 
22 using namespace llvm;
23 
24 #define DEBUG_TYPE "partially-inline-libcalls"
25 
26 namespace {
27   class PartiallyInlineLibCalls : public FunctionPass {
28   public:
29     static char ID;
30 
31     PartiallyInlineLibCalls() :
32       FunctionPass(ID) {
33       initializePartiallyInlineLibCallsPass(*PassRegistry::getPassRegistry());
34     }
35 
36     void getAnalysisUsage(AnalysisUsage &AU) const override;
37     bool runOnFunction(Function &F) override;
38   };
39 
40   char PartiallyInlineLibCalls::ID = 0;
41 }
42 
43 INITIALIZE_PASS(PartiallyInlineLibCalls, "partially-inline-libcalls",
44                 "Partially inline calls to library functions", false, false)
45 
46 void PartiallyInlineLibCalls::getAnalysisUsage(AnalysisUsage &AU) const {
47   AU.addRequired<TargetLibraryInfoWrapperPass>();
48   AU.addRequired<TargetTransformInfoWrapperPass>();
49   FunctionPass::getAnalysisUsage(AU);
50 }
51 
52 
53 static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
54                          BasicBlock &CurrBB, Function::iterator &BB) {
55   // There is no need to change the IR, since backend will emit sqrt
56   // instruction if the call has already been marked read-only.
57   if (Call->onlyReadsMemory())
58     return false;
59 
60   // The call must have the expected result type.
61   if (!Call->getType()->isFloatingPointTy())
62     return false;
63 
64   // Do the following transformation:
65   //
66   // (before)
67   // dst = sqrt(src)
68   //
69   // (after)
70   // v0 = sqrt_noreadmem(src) # native sqrt instruction.
71   // if (v0 is a NaN)
72   //   v1 = sqrt(src)         # library call.
73   // dst = phi(v0, v1)
74   //
75 
76   // Move all instructions following Call to newly created block JoinBB.
77   // Create phi and replace all uses.
78   BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode());
79   IRBuilder<> Builder(JoinBB, JoinBB->begin());
80   PHINode *Phi = Builder.CreatePHI(Call->getType(), 2);
81   Call->replaceAllUsesWith(Phi);
82 
83   // Create basic block LibCallBB and insert a call to library function sqrt.
84   BasicBlock *LibCallBB = BasicBlock::Create(CurrBB.getContext(), "call.sqrt",
85                                              CurrBB.getParent(), JoinBB);
86   Builder.SetInsertPoint(LibCallBB);
87   Instruction *LibCall = Call->clone();
88   Builder.Insert(LibCall);
89   Builder.CreateBr(JoinBB);
90 
91   // Add attribute "readnone" so that backend can use a native sqrt instruction
92   // for this call. Insert a FP compare instruction and a conditional branch
93   // at the end of CurrBB.
94   Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
95   CurrBB.getTerminator()->eraseFromParent();
96   Builder.SetInsertPoint(&CurrBB);
97   Value *FCmp = Builder.CreateFCmpOEQ(Call, Call);
98   Builder.CreateCondBr(FCmp, JoinBB, LibCallBB);
99 
100   // Add phi operands.
101   Phi->addIncoming(Call, &CurrBB);
102   Phi->addIncoming(LibCall, LibCallBB);
103 
104   BB = JoinBB->getIterator();
105   return true;
106 }
107 
108 bool PartiallyInlineLibCalls::runOnFunction(Function &F) {
109   if (skipFunction(F))
110     return false;
111 
112   bool Changed = false;
113   Function::iterator CurrBB;
114   TargetLibraryInfo *TLI =
115       &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
116   const TargetTransformInfo *TTI =
117       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
118   for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) {
119     CurrBB = BB++;
120 
121     for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end();
122          II != IE; ++II) {
123       CallInst *Call = dyn_cast<CallInst>(&*II);
124       Function *CalledFunc;
125 
126       if (!Call || !(CalledFunc = Call->getCalledFunction()))
127         continue;
128 
129       // Skip if function either has local linkage or is not a known library
130       // function.
131       LibFunc::Func LibFunc;
132       if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() ||
133           !TLI->getLibFunc(CalledFunc->getName(), LibFunc))
134         continue;
135 
136       switch (LibFunc) {
137       case LibFunc::sqrtf:
138       case LibFunc::sqrt:
139         if (TTI->haveFastSqrt(Call->getType()) &&
140             optimizeSQRT(Call, CalledFunc, *CurrBB, BB))
141           break;
142         continue;
143       default:
144         continue;
145       }
146 
147       Changed = true;
148       break;
149     }
150   }
151 
152   return Changed;
153 }
154 
155 FunctionPass *llvm::createPartiallyInlineLibCallsPass() {
156   return new PartiallyInlineLibCalls();
157 }
158