1 //===- AMDGPULibCalls.cpp -------------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// \brief This file does AMD library function optimizations.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #define DEBUG_TYPE "amdgpu-simplifylib"
16 
17 #include "AMDGPU.h"
18 #include "AMDGPULibFunc.h"
19 #include "llvm/Analysis/AliasAnalysis.h"
20 #include "llvm/Analysis/Loads.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Instructions.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/LLVMContext.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/ValueSymbolTable.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <vector>
34 #include <cmath>
35 
36 using namespace llvm;
37 
38 static cl::opt<bool> EnablePreLink("amdgpu-prelink",
39   cl::desc("Enable pre-link mode optimizations"),
40   cl::init(false),
41   cl::Hidden);
42 
43 static cl::list<std::string> UseNative("amdgpu-use-native",
44   cl::desc("Comma separated list of functions to replace with native, or all"),
45   cl::CommaSeparated, cl::ValueOptional,
46   cl::Hidden);
47 
// Floating-point literals for the mathematical constants used by the
// constant-folding tables and folds in this file.
// pi, e and sqrt(2)
#define MATH_PI     3.14159265358979323846264338327950288419716939937511
#define MATH_E      2.71828182845904523536028747135266249775724709369996
#define MATH_SQRT2  1.41421356237309504880168872420969807856967187537695

// Value of log2(e) and of log10(e)
#define MATH_LOG2E     1.4426950408889634073599246810018921374266459541529859
#define MATH_LOG10E    0.4342944819032518276511289189166050822943970058036665
// Value of log2(10)
#define MATH_LOG2_10   3.3219280948873623478703194294893901758648313930245806
// Value of 1 / log2(10), i.e. log10(2)
#define MATH_RLOG2_10  0.3010299956639811952137388947244930267681898814621085
// Value of 1 / M_LOG2E_F = 1 / log2(e), i.e. ln(2)
#define MATH_RLOG2_E   0.6931471805599453094172321214581765680755001343602552
60 
61 namespace llvm {
62 
63 class AMDGPULibCalls {
64 private:
65 
66   typedef llvm::AMDGPULibFunc FuncInfo;
67 
68   // -fuse-native.
69   bool AllNative = false;
70 
71   bool useNativeFunc(const StringRef F) const;
72 
73   // Return a pointer (pointer expr) to the function if function defintion with
74   // "FuncName" exists. It may create a new function prototype in pre-link mode.
75   Constant *getFunction(Module *M, const FuncInfo& fInfo);
76 
77   // Replace a normal function with its native version.
78   bool replaceWithNative(CallInst *CI, const FuncInfo &FInfo);
79 
80   bool parseFunctionName(const StringRef& FMangledName,
81                          FuncInfo *FInfo=nullptr /*out*/);
82 
83   bool TDOFold(CallInst *CI, const FuncInfo &FInfo);
84 
85   /* Specialized optimizations */
86 
87   // recip (half or native)
88   bool fold_recip(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
89 
90   // divide (half or native)
91   bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
92 
93   // pow/powr/pown
94   bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
95 
96   // rootn
97   bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
98 
99   // fma/mad
100   bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
101 
102   // -fuse-native for sincos
103   bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);
104 
105   // evaluate calls if calls' arguments are constants.
106   bool evaluateScalarMathFunc(FuncInfo &FInfo, double& Res0,
107     double& Res1, Constant *copr0, Constant *copr1, Constant *copr2);
108   bool evaluateCall(CallInst *aCI, FuncInfo &FInfo);
109 
110   // exp
111   bool fold_exp(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
112 
113   // exp2
114   bool fold_exp2(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
115 
116   // exp10
117   bool fold_exp10(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
118 
119   // log
120   bool fold_log(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
121 
122   // log2
123   bool fold_log2(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
124 
125   // log10
126   bool fold_log10(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
127 
128   // sqrt
129   bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
130 
131   // sin/cos
132   bool fold_sincos(CallInst * CI, IRBuilder<> &B, AliasAnalysis * AA);
133 
134   // Get insertion point at entry.
135   BasicBlock::iterator getEntryIns(CallInst * UI);
136   // Insert an Alloc instruction.
137   AllocaInst* insertAlloca(CallInst * UI, IRBuilder<> &B, const char *prefix);
138   // Get a scalar native builtin signle argument FP function
139   Constant* getNativeFunction(Module* M, const FuncInfo &FInfo);
140 
141 protected:
142   CallInst *CI;
143 
144   bool isUnsafeMath(const CallInst *CI) const;
145 
146   void replaceCall(Value *With) {
147     CI->replaceAllUsesWith(With);
148     CI->eraseFromParent();
149   }
150 
151 public:
152   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
153 
154   void initNativeFuncs();
155 
156   // Replace a normal math function call with that native version
157   bool useNative(CallInst *CI);
158 };
159 
160 } // end llvm namespace
161 
162 namespace {
163 
164   class AMDGPUSimplifyLibCalls : public FunctionPass {
165 
166   AMDGPULibCalls Simplifier;
167 
168   public:
169     static char ID; // Pass identification
170 
171     AMDGPUSimplifyLibCalls() : FunctionPass(ID) {
172       initializeAMDGPUSimplifyLibCallsPass(*PassRegistry::getPassRegistry());
173     }
174 
175     void getAnalysisUsage(AnalysisUsage &AU) const override {
176       AU.addRequired<AAResultsWrapperPass>();
177     }
178 
179     bool runOnFunction(Function &M) override;
180   };
181 
182   class AMDGPUUseNativeCalls : public FunctionPass {
183 
184   AMDGPULibCalls Simplifier;
185 
186   public:
187     static char ID; // Pass identification
188 
189     AMDGPUUseNativeCalls() : FunctionPass(ID) {
190       initializeAMDGPUUseNativeCallsPass(*PassRegistry::getPassRegistry());
191       Simplifier.initNativeFuncs();
192     }
193 
194     bool runOnFunction(Function &F) override;
195   };
196 
197 } // end anonymous namespace.
198 
// Definitions of the static pass identifiers that the two pass classes hand
// to the FunctionPass constructor.
char AMDGPUSimplifyLibCalls::ID = 0;
char AMDGPUUseNativeCalls::ID = 0;

// Register the simplification pass as "amdgpu-simplifylib" and declare its
// dependency on AAResultsWrapperPass.
INITIALIZE_PASS_BEGIN(AMDGPUSimplifyLibCalls, "amdgpu-simplifylib",
                      "Simplify well-known AMD library calls", false, false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(AMDGPUSimplifyLibCalls, "amdgpu-simplifylib",
                    "Simplify well-known AMD library calls", false, false)

// Register the native-replacement pass as "amdgpu-usenative"; it declares no
// pass dependencies.
INITIALIZE_PASS(AMDGPUUseNativeCalls, "amdgpu-usenative",
                "Replace builtin math calls with that native versions.",
                false, false)
211 
212 template <typename IRB>
213 CallInst *CreateCallEx(IRB &B, Value *Callee, Value *Arg, const Twine &Name="")
214 {
215   CallInst *R = B.CreateCall(Callee, Arg, Name);
216   if (Function* F = dyn_cast<Function>(Callee))
217     R->setCallingConv(F->getCallingConv());
218   return R;
219 }
220 
221 template <typename IRB>
222 CallInst *CreateCallEx2(IRB &B, Value *Callee, Value *Arg1, Value *Arg2,
223                         const Twine &Name="") {
224   CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
225   if (Function* F = dyn_cast<Function>(Callee))
226     R->setCallingConv(F->getCallingConv());
227   return R;
228 }
229 
// Data structures for table-driven optimizations.
// One entry pairs a constant input of a one-argument math function with the
// value the call folds to. Values are kept as doubles; the tables serve both
// f32 and f64 functions.

struct TableEntry {
  double result; // value the call is replaced with
  double input;  // constant argument that selects this entry
};
237 
238 /* a list of {result, input} */
239 static const TableEntry tbl_acos[] = {
240   {MATH_PI/2.0, 0.0},
241   {MATH_PI/2.0, -0.0},
242   {0.0, 1.0},
243   {MATH_PI, -1.0}
244 };
245 static const TableEntry tbl_acosh[] = {
246   {0.0, 1.0}
247 };
248 static const TableEntry tbl_acospi[] = {
249   {0.5, 0.0},
250   {0.5, -0.0},
251   {0.0, 1.0},
252   {1.0, -1.0}
253 };
254 static const TableEntry tbl_asin[] = {
255   {0.0, 0.0},
256   {-0.0, -0.0},
257   {MATH_PI/2.0, 1.0},
258   {-MATH_PI/2.0, -1.0}
259 };
260 static const TableEntry tbl_asinh[] = {
261   {0.0, 0.0},
262   {-0.0, -0.0}
263 };
264 static const TableEntry tbl_asinpi[] = {
265   {0.0, 0.0},
266   {-0.0, -0.0},
267   {0.5, 1.0},
268   {-0.5, -1.0}
269 };
270 static const TableEntry tbl_atan[] = {
271   {0.0, 0.0},
272   {-0.0, -0.0},
273   {MATH_PI/4.0, 1.0},
274   {-MATH_PI/4.0, -1.0}
275 };
276 static const TableEntry tbl_atanh[] = {
277   {0.0, 0.0},
278   {-0.0, -0.0}
279 };
280 static const TableEntry tbl_atanpi[] = {
281   {0.0, 0.0},
282   {-0.0, -0.0},
283   {0.25, 1.0},
284   {-0.25, -1.0}
285 };
286 static const TableEntry tbl_cbrt[] = {
287   {0.0, 0.0},
288   {-0.0, -0.0},
289   {1.0, 1.0},
290   {-1.0, -1.0},
291 };
292 static const TableEntry tbl_cos[] = {
293   {1.0, 0.0},
294   {1.0, -0.0}
295 };
296 static const TableEntry tbl_cosh[] = {
297   {1.0, 0.0},
298   {1.0, -0.0}
299 };
300 static const TableEntry tbl_cospi[] = {
301   {1.0, 0.0},
302   {1.0, -0.0}
303 };
304 static const TableEntry tbl_erfc[] = {
305   {1.0, 0.0},
306   {1.0, -0.0}
307 };
308 static const TableEntry tbl_erf[] = {
309   {0.0, 0.0},
310   {-0.0, -0.0}
311 };
312 static const TableEntry tbl_exp[] = {
313   {1.0, 0.0},
314   {1.0, -0.0},
315   {MATH_E, 1.0}
316 };
317 static const TableEntry tbl_exp2[] = {
318   {1.0, 0.0},
319   {1.0, -0.0},
320   {2.0, 1.0}
321 };
322 static const TableEntry tbl_exp10[] = {
323   {1.0, 0.0},
324   {1.0, -0.0},
325   {10.0, 1.0}
326 };
327 static const TableEntry tbl_expm1[] = {
328   {0.0, 0.0},
329   {-0.0, -0.0}
330 };
331 static const TableEntry tbl_log[] = {
332   {0.0, 1.0},
333   {1.0, MATH_E}
334 };
335 static const TableEntry tbl_log2[] = {
336   {0.0, 1.0},
337   {1.0, 2.0}
338 };
339 static const TableEntry tbl_log10[] = {
340   {0.0, 1.0},
341   {1.0, 10.0}
342 };
343 static const TableEntry tbl_rsqrt[] = {
344   {1.0, 1.0},
345   {1.0/MATH_SQRT2, 2.0}
346 };
347 static const TableEntry tbl_sin[] = {
348   {0.0, 0.0},
349   {-0.0, -0.0}
350 };
351 static const TableEntry tbl_sinh[] = {
352   {0.0, 0.0},
353   {-0.0, -0.0}
354 };
355 static const TableEntry tbl_sinpi[] = {
356   {0.0, 0.0},
357   {-0.0, -0.0}
358 };
359 static const TableEntry tbl_sqrt[] = {
360   {0.0, 0.0},
361   {1.0, 1.0},
362   {MATH_SQRT2, 2.0}
363 };
364 static const TableEntry tbl_tan[] = {
365   {0.0, 0.0},
366   {-0.0, -0.0}
367 };
368 static const TableEntry tbl_tanh[] = {
369   {0.0, 0.0},
370   {-0.0, -0.0}
371 };
372 static const TableEntry tbl_tanpi[] = {
373   {0.0, 0.0},
374   {-0.0, -0.0}
375 };
376 static const TableEntry tbl_tgamma[] = {
377   {1.0, 1.0},
378   {1.0, 2.0},
379   {2.0, 3.0},
380   {6.0, 4.0}
381 };
382 
383 static bool HasNative(AMDGPULibFunc::EFuncId id) {
384   switch(id) {
385   case AMDGPULibFunc::EI_DIVIDE:
386   case AMDGPULibFunc::EI_COS:
387   case AMDGPULibFunc::EI_EXP:
388   case AMDGPULibFunc::EI_EXP2:
389   case AMDGPULibFunc::EI_EXP10:
390   case AMDGPULibFunc::EI_LOG:
391   case AMDGPULibFunc::EI_LOG2:
392   case AMDGPULibFunc::EI_LOG10:
393   case AMDGPULibFunc::EI_POWR:
394   case AMDGPULibFunc::EI_RECIP:
395   case AMDGPULibFunc::EI_RSQRT:
396   case AMDGPULibFunc::EI_SIN:
397   case AMDGPULibFunc::EI_SINCOS:
398   case AMDGPULibFunc::EI_SQRT:
399   case AMDGPULibFunc::EI_TAN:
400     return true;
401   default:;
402   }
403   return false;
404 }
405 
406 struct TableRef {
407   size_t size;
408   const TableEntry *table; // variable size: from 0 to (size - 1)
409 
410   TableRef() : size(0), table(nullptr) {}
411 
412   template <size_t N>
413   TableRef(const TableEntry (&tbl)[N]) : size(N), table(&tbl[0]) {}
414 };
415 
416 static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
417   switch(id) {
418   case AMDGPULibFunc::EI_ACOS:    return TableRef(tbl_acos);
419   case AMDGPULibFunc::EI_ACOSH:   return TableRef(tbl_acosh);
420   case AMDGPULibFunc::EI_ACOSPI:  return TableRef(tbl_acospi);
421   case AMDGPULibFunc::EI_ASIN:    return TableRef(tbl_asin);
422   case AMDGPULibFunc::EI_ASINH:   return TableRef(tbl_asinh);
423   case AMDGPULibFunc::EI_ASINPI:  return TableRef(tbl_asinpi);
424   case AMDGPULibFunc::EI_ATAN:    return TableRef(tbl_atan);
425   case AMDGPULibFunc::EI_ATANH:   return TableRef(tbl_atanh);
426   case AMDGPULibFunc::EI_ATANPI:  return TableRef(tbl_atanpi);
427   case AMDGPULibFunc::EI_CBRT:    return TableRef(tbl_cbrt);
428   case AMDGPULibFunc::EI_NCOS:
429   case AMDGPULibFunc::EI_COS:     return TableRef(tbl_cos);
430   case AMDGPULibFunc::EI_COSH:    return TableRef(tbl_cosh);
431   case AMDGPULibFunc::EI_COSPI:   return TableRef(tbl_cospi);
432   case AMDGPULibFunc::EI_ERFC:    return TableRef(tbl_erfc);
433   case AMDGPULibFunc::EI_ERF:     return TableRef(tbl_erf);
434   case AMDGPULibFunc::EI_EXP:     return TableRef(tbl_exp);
435   case AMDGPULibFunc::EI_NEXP2:
436   case AMDGPULibFunc::EI_EXP2:    return TableRef(tbl_exp2);
437   case AMDGPULibFunc::EI_EXP10:   return TableRef(tbl_exp10);
438   case AMDGPULibFunc::EI_EXPM1:   return TableRef(tbl_expm1);
439   case AMDGPULibFunc::EI_LOG:     return TableRef(tbl_log);
440   case AMDGPULibFunc::EI_NLOG2:
441   case AMDGPULibFunc::EI_LOG2:    return TableRef(tbl_log2);
442   case AMDGPULibFunc::EI_LOG10:   return TableRef(tbl_log10);
443   case AMDGPULibFunc::EI_NRSQRT:
444   case AMDGPULibFunc::EI_RSQRT:   return TableRef(tbl_rsqrt);
445   case AMDGPULibFunc::EI_NSIN:
446   case AMDGPULibFunc::EI_SIN:     return TableRef(tbl_sin);
447   case AMDGPULibFunc::EI_SINH:    return TableRef(tbl_sinh);
448   case AMDGPULibFunc::EI_SINPI:   return TableRef(tbl_sinpi);
449   case AMDGPULibFunc::EI_NSQRT:
450   case AMDGPULibFunc::EI_SQRT:    return TableRef(tbl_sqrt);
451   case AMDGPULibFunc::EI_TAN:     return TableRef(tbl_tan);
452   case AMDGPULibFunc::EI_TANH:    return TableRef(tbl_tanh);
453   case AMDGPULibFunc::EI_TANPI:   return TableRef(tbl_tanpi);
454   case AMDGPULibFunc::EI_TGAMMA:  return TableRef(tbl_tgamma);
455   default:;
456   }
457   return TableRef();
458 }
459 
460 static inline int getVecSize(const AMDGPULibFunc& FInfo) {
461   return FInfo.Leads[0].VectorSize;
462 }
463 
464 static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
465   return (AMDGPULibFunc::EType)FInfo.Leads[0].ArgType;
466 }
467 
468 Constant *AMDGPULibCalls::getFunction(Module *M, const FuncInfo& fInfo) {
469   // If we are doing PreLinkOpt, the function is external. So it is safe to
470   // use getOrInsertFunction() at this stage.
471 
472   return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
473                        : AMDGPULibFunc::getFunction(M, fInfo);
474 }
475 
476 bool AMDGPULibCalls::parseFunctionName(const StringRef& FMangledName,
477                                     FuncInfo *FInfo) {
478   return AMDGPULibFunc::parse(FMangledName, *FInfo);
479 }
480 
481 bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
482   if (auto Op = dyn_cast<FPMathOperator>(CI))
483     if (Op->hasUnsafeAlgebra())
484       return true;
485   const Function *F = CI->getParent()->getParent();
486   Attribute Attr = F->getFnAttribute("unsafe-fp-math");
487   return Attr.getValueAsString() == "true";
488 }
489 
490 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
491   return AllNative ||
492          std::find(UseNative.begin(), UseNative.end(), F) != UseNative.end();
493 }
494 
495 void AMDGPULibCalls::initNativeFuncs() {
496   AllNative = useNativeFunc("all") ||
497               (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
498                UseNative.begin()->empty());
499 }
500 
501 bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
502   bool native_sin = useNativeFunc("sin");
503   bool native_cos = useNativeFunc("cos");
504 
505   if (native_sin && native_cos) {
506     Module *M = aCI->getModule();
507     Value *opr0 = aCI->getArgOperand(0);
508 
509     AMDGPULibFunc nf;
510     nf.Leads[0].ArgType = FInfo.Leads[0].ArgType;
511     nf.Leads[0].VectorSize = FInfo.Leads[0].VectorSize;
512 
513     nf.setPrefix(AMDGPULibFunc::NATIVE);
514     nf.setId(AMDGPULibFunc::EI_SIN);
515     Constant *sinExpr = getFunction(M, nf);
516 
517     nf.setPrefix(AMDGPULibFunc::NATIVE);
518     nf.setId(AMDGPULibFunc::EI_COS);
519     Constant *cosExpr = getFunction(M, nf);
520     if (sinExpr && cosExpr) {
521       Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
522       Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
523       new StoreInst(cosval, aCI->getArgOperand(1), aCI);
524 
525       DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
526                                           << " with native version of sin/cos");
527 
528       replaceCall(sinval);
529       return true;
530     }
531   }
532   return false;
533 }
534 
535 bool AMDGPULibCalls::useNative(CallInst *aCI) {
536   CI = aCI;
537   Function *Callee = aCI->getCalledFunction();
538 
539   FuncInfo FInfo;
540   if (!parseFunctionName(Callee->getName(), &FInfo) ||
541       FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
542       getArgType(FInfo) == AMDGPULibFunc::F64 ||
543       !HasNative(FInfo.getId()) ||
544       !(AllNative || useNativeFunc(FInfo.getName())) ) {
545     return false;
546   }
547 
548   if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
549     return sincosUseNative(aCI, FInfo);
550 
551   FInfo.setPrefix(AMDGPULibFunc::NATIVE);
552   Constant *F = getFunction(aCI->getModule(), FInfo);
553   if (!F)
554     return false;
555 
556   aCI->setCalledFunction(F);
557   DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
558                                       << " with native version");
559   return true;
560 }
561 
562 // This function returns false if no change; return true otherwise.
563 bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
564   this->CI = CI;
565   Function *Callee = CI->getCalledFunction();
566 
567   // Ignore indirect calls.
568   if (Callee == 0) return false;
569 
570   FuncInfo FInfo;
571   if (!parseFunctionName(Callee->getName(), &FInfo))
572     return false;
573 
574   // Further check the number of arguments to see if they match.
575   if (CI->getNumArgOperands() != FInfo.getNumArgs())
576     return false;
577 
578   BasicBlock *BB = CI->getParent();
579   LLVMContext &Context = CI->getParent()->getContext();
580   IRBuilder<> B(Context);
581 
582   // Set the builder to the instruction after the call.
583   B.SetInsertPoint(BB, CI->getIterator());
584 
585   // Copy fast flags from the original call.
586   if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
587     B.setFastMathFlags(FPOp->getFastMathFlags());
588 
589   if (TDOFold(CI, FInfo))
590     return true;
591 
592   // Under unsafe-math, evaluate calls if possible.
593   // According to Brian Sumner, we can do this for all f32 function calls
594   // using host's double function calls.
595   if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
596     return true;
597 
598   // Specilized optimizations for each function call
599   switch (FInfo.getId()) {
600   case AMDGPULibFunc::EI_RECIP:
601     // skip vector function
602     assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
603              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
604             "recip must be an either native or half function");
605     return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
606 
607   case AMDGPULibFunc::EI_DIVIDE:
608     // skip vector function
609     assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
610              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
611             "divide must be an either native or half function");
612     return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
613 
614   case AMDGPULibFunc::EI_POW:
615   case AMDGPULibFunc::EI_POWR:
616   case AMDGPULibFunc::EI_POWN:
617     return fold_pow(CI, B, FInfo);
618 
619   case AMDGPULibFunc::EI_ROOTN:
620     // skip vector function
621     return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
622 
623   case AMDGPULibFunc::EI_FMA:
624   case AMDGPULibFunc::EI_MAD:
625   case AMDGPULibFunc::EI_NFMA:
626     // skip vector function
627     return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
628 
629   case AMDGPULibFunc::EI_SQRT:
630     return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
631   case AMDGPULibFunc::EI_COS:
632   case AMDGPULibFunc::EI_SIN:
633     if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
634          getArgType(FInfo) == AMDGPULibFunc::F64)
635         && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
636       return fold_sincos(CI, B, AA);
637 
638     break;
639 
640   default:
641     break;
642   }
643 
644   return false;
645 }
646 
647 bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
648   // Table-Driven optimization
649   const TableRef tr = getOptTable(FInfo.getId());
650   if (tr.size==0)
651     return false;
652 
653   int const sz = (int)tr.size;
654   const TableEntry * const ftbl = tr.table;
655   Value *opr0 = CI->getArgOperand(0);
656 
657   if (getVecSize(FInfo) > 1) {
658     if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) {
659       SmallVector<double, 0> DVal;
660       for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) {
661         ConstantFP *eltval = dyn_cast<ConstantFP>(
662                                CV->getElementAsConstant((unsigned)eltNo));
663         assert(eltval && "Non-FP arguments in math function!");
664         bool found = false;
665         for (int i=0; i < sz; ++i) {
666           if (eltval->isExactlyValue(ftbl[i].input)) {
667             DVal.push_back(ftbl[i].result);
668             found = true;
669             break;
670           }
671         }
672         if (!found) {
673           // This vector constants not handled yet.
674           return false;
675         }
676       }
677       LLVMContext &context = CI->getParent()->getParent()->getContext();
678       Constant *nval;
679       if (getArgType(FInfo) == AMDGPULibFunc::F32) {
680         SmallVector<float, 0> FVal;
681         for (unsigned i = 0; i < DVal.size(); ++i) {
682           FVal.push_back((float)DVal[i]);
683         }
684         ArrayRef<float> tmp(FVal);
685         nval = ConstantDataVector::get(context, tmp);
686       } else { // F64
687         ArrayRef<double> tmp(DVal);
688         nval = ConstantDataVector::get(context, tmp);
689       }
690       DEBUG(errs() << "AMDIC: " << *CI
691                    << " ---> " << *nval << "\n");
692       replaceCall(nval);
693       return true;
694     }
695   } else {
696     // Scalar version
697     if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
698       for (int i = 0; i < sz; ++i) {
699         if (CF->isExactlyValue(ftbl[i].input)) {
700           Value *nval = ConstantFP::get(CF->getType(), ftbl[i].result);
701           DEBUG(errs() << "AMDIC: " << *CI
702                        << " ---> " << *nval << "\n");
703           replaceCall(nval);
704           return true;
705         }
706       }
707     }
708   }
709 
710   return false;
711 }
712 
713 bool AMDGPULibCalls::replaceWithNative(CallInst *CI, const FuncInfo &FInfo) {
714   Module *M = CI->getModule();
715   if (getArgType(FInfo) != AMDGPULibFunc::F32 ||
716       FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
717       !HasNative(FInfo.getId()))
718     return false;
719 
720   AMDGPULibFunc nf = FInfo;
721   nf.setPrefix(AMDGPULibFunc::NATIVE);
722   if (Constant *FPExpr = getFunction(M, nf)) {
723     DEBUG(dbgs() << "AMDIC: " << *CI << " ---> ");
724 
725     CI->setCalledFunction(FPExpr);
726 
727     DEBUG(dbgs() << *CI << '\n');
728 
729     return true;
730   }
731   return false;
732 }
733 
734 //  [native_]half_recip(c) ==> 1.0/c
735 bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
736                                 const FuncInfo &FInfo) {
737   Value *opr0 = CI->getArgOperand(0);
738   if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
739     // Just create a normal div. Later, InstCombine will be able
740     // to compute the divide into a constant (avoid check float infinity
741     // or subnormal at this point).
742     Value *nval = B.CreateFDiv(ConstantFP::get(CF->getType(), 1.0),
743                                opr0,
744                                "recip2div");
745     DEBUG(errs() << "AMDIC: " << *CI
746                  << " ---> " << *nval << "\n");
747     replaceCall(nval);
748     return true;
749   }
750   return false;
751 }
752 
753 //  [native_]half_divide(x, c) ==> x/c
754 bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
755                                  const FuncInfo &FInfo) {
756   Value *opr0 = CI->getArgOperand(0);
757   Value *opr1 = CI->getArgOperand(1);
758   ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
759   ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);
760 
761   if ((CF0 && CF1) ||  // both are constants
762       (CF1 && (getArgType(FInfo) == AMDGPULibFunc::F32)))
763       // CF1 is constant && f32 divide
764   {
765     Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
766                                 opr1, "__div2recip");
767     Value *nval  = B.CreateFMul(opr0, nval1, "__div2mul");
768     replaceCall(nval);
769     return true;
770   }
771   return false;
772 }
773 
namespace llvm {
// Base-2 logarithm of V. Uses the C library's log2 when the feature-test
// macros below advertise it, and log(V) divided by ln(2) otherwise.
static double log2(double V) {
#if _XOPEN_SOURCE >= 600 || _ISOC99_SOURCE || _POSIX_C_SOURCE >= 200112L
  const double Result = ::log2(V);
#else
  const double Result = log(V) / 0.693147180559945309417;
#endif
  return Result;
}
}
783 
784 bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
785                               const FuncInfo &FInfo) {
786   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
787           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
788           FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
789          "fold_pow: encounter a wrong function call");
790 
791   Value *opr0, *opr1;
792   ConstantFP *CF;
793   ConstantInt *CINT;
794   ConstantAggregateZero *CZero;
795   Type *eltType;
796 
797   opr0 = CI->getArgOperand(0);
798   opr1 = CI->getArgOperand(1);
799   CZero = dyn_cast<ConstantAggregateZero>(opr1);
800   if (getVecSize(FInfo) == 1) {
801     eltType = opr0->getType();
802     CF = dyn_cast<ConstantFP>(opr1);
803     CINT = dyn_cast<ConstantInt>(opr1);
804   } else {
805     VectorType *VTy = dyn_cast<VectorType>(opr0->getType());
806     assert(VTy && "Oprand of vector function should be of vectortype");
807     eltType = VTy->getElementType();
808     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1);
809 
810     // Now, only Handle vector const whose elements have the same value.
811     CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr;
812     CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr;
813   }
814 
815   // No unsafe math , no constant argument, do nothing
816   if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
817     return false;
818 
819   // 0x1111111 means that we don't do anything for this call.
820   int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
821 
822   if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
823     //  pow/powr/pown(x, 0) == 1
824     DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n");
825     Constant *cnval = ConstantFP::get(eltType, 1.0);
826     if (getVecSize(FInfo) > 1) {
827       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
828     }
829     replaceCall(cnval);
830     return true;
831   }
832   if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
833     // pow/powr/pown(x, 1.0) = x
834     DEBUG(errs() << "AMDIC: " << *CI
835                  << " ---> " << *opr0 << "\n");
836     replaceCall(opr0);
837     return true;
838   }
839   if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
840     // pow/powr/pown(x, 2.0) = x*x
841     DEBUG(errs() << "AMDIC: " << *CI
842                  << " ---> " << *opr0 << " * " << *opr0 << "\n");
843     Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
844     replaceCall(nval);
845     return true;
846   }
847   if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
848     // pow/powr/pown(x, -1.0) = 1.0/x
849     DEBUG(errs() << "AMDIC: " << *CI
850                  << " ---> 1 / " << *opr0 << "\n");
851     Constant *cnval = ConstantFP::get(eltType, 1.0);
852     if (getVecSize(FInfo) > 1) {
853       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
854     }
855     Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
856     replaceCall(nval);
857     return true;
858   }
859 
860   Module *M = CI->getModule();
861   if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) {
862     // pow[r](x, [-]0.5) = sqrt(x)
863     bool issqrt = CF->isExactlyValue(0.5);
864     if (Constant *FPExpr = getFunction(M,
865         AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
866                              : AMDGPULibFunc::EI_RSQRT, FInfo))) {
867       DEBUG(errs() << "AMDIC: " << *CI << " ---> "
868                    << FInfo.getName().c_str() << "(" << *opr0 << ")\n");
869       Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
870                                                         : "__pow2rsqrt");
871       replaceCall(nval);
872       return true;
873     }
874   }
875 
876   if (!isUnsafeMath(CI))
877     return false;
878 
879   // Unsafe Math optimization
880 
881   // Remember that ci_opr1 is set if opr1 is integral
882   if (CF) {
883     double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
884                     ? (double)CF->getValueAPF().convertToFloat()
885                     : CF->getValueAPF().convertToDouble();
886     int ival = (int)dval;
887     if ((double)ival == dval) {
888       ci_opr1 = ival;
889     } else
890       ci_opr1 = 0x11111111;
891   }
892 
893   // pow/powr/pown(x, c) = [1/](x*x*..x); where
894   //   trunc(c) == c && the number of x == c && |c| <= 12
895   unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
896   if (abs_opr1 <= 12) {
897     Constant *cnval;
898     Value *nval;
899     if (abs_opr1 == 0) {
900       cnval = ConstantFP::get(eltType, 1.0);
901       if (getVecSize(FInfo) > 1) {
902         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
903       }
904       nval = cnval;
905     } else {
906       Value *valx2 = nullptr;
907       nval = nullptr;
908       while (abs_opr1 > 0) {
909         valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0;
910         if (abs_opr1 & 1) {
911           nval = nval ? B.CreateFMul(nval, valx2, "__powprod") : valx2;
912         }
913         abs_opr1 >>= 1;
914       }
915     }
916 
917     if (ci_opr1 < 0) {
918       cnval = ConstantFP::get(eltType, 1.0);
919       if (getVecSize(FInfo) > 1) {
920         cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
921       }
922       nval = B.CreateFDiv(cnval, nval, "__1powprod");
923     }
924     DEBUG(errs() << "AMDIC: " << *CI << " ---> "
925                  <<  ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0 << ")\n");
926     replaceCall(nval);
927     return true;
928   }
929 
930   // powr ---> exp2(y * log2(x))
931   // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
932   Constant *ExpExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2,
933                                                    FInfo));
934   if (!ExpExpr)
935     return false;
936 
937   bool needlog = false;
938   bool needabs = false;
939   bool needcopysign = false;
940   Constant *cnval = nullptr;
941   if (getVecSize(FInfo) == 1) {
942     CF = dyn_cast<ConstantFP>(opr0);
943 
944     if (CF) {
945       double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
946                    ? (double)CF->getValueAPF().convertToFloat()
947                    : CF->getValueAPF().convertToDouble();
948 
949       V = log2(std::abs(V));
950       cnval = ConstantFP::get(eltType, V);
951       needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
952                      CF->isNegative();
953     } else {
954       needlog = true;
955       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR &&
956                                (!CF || CF->isNegative());
957     }
958   } else {
959     ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0);
960 
961     if (!CDV) {
962       needlog = true;
963       needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
964     } else {
965       assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
966               "Wrong vector size detected");
967 
968       SmallVector<double, 0> DVal;
969       for (int i=0; i < getVecSize(FInfo); ++i) {
970         double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
971                      ? (double)CDV->getElementAsFloat(i)
972                      : CDV->getElementAsDouble(i);
973         if (V < 0.0) needcopysign = true;
974         V = log2(std::abs(V));
975         DVal.push_back(V);
976       }
977       if (getArgType(FInfo) == AMDGPULibFunc::F32) {
978         SmallVector<float, 0> FVal;
979         for (unsigned i=0; i < DVal.size(); ++i) {
980           FVal.push_back((float)DVal[i]);
981         }
982         ArrayRef<float> tmp(FVal);
983         cnval = ConstantDataVector::get(M->getContext(), tmp);
984       } else {
985         ArrayRef<double> tmp(DVal);
986         cnval = ConstantDataVector::get(M->getContext(), tmp);
987       }
988     }
989   }
990 
991   if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
992     // We cannot handle corner cases for a general pow() function, give up
993     // unless y is a constant integral value. Then proceed as if it were pown.
994     if (getVecSize(FInfo) == 1) {
995       if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
996         double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
997                    ? (double)CF->getValueAPF().convertToFloat()
998                    : CF->getValueAPF().convertToDouble();
999         if (y != (double)(int64_t)y)
1000           return false;
1001       } else
1002         return false;
1003     } else {
1004       if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
1005         for (int i=0; i < getVecSize(FInfo); ++i) {
1006           double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
1007                      ? (double)CDV->getElementAsFloat(i)
1008                      : CDV->getElementAsDouble(i);
1009           if (y != (double)(int64_t)y)
1010             return false;
1011         }
1012       } else
1013         return false;
1014     }
1015   }
1016 
1017   Value *nval;
1018   if (needabs) {
1019     Constant *AbsExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_FABS,
1020                                                      FInfo));
1021     if (!AbsExpr)
1022       return false;
1023     nval = CreateCallEx(B, AbsExpr, opr0, "__fabs");
1024   } else {
1025     nval = cnval ? cnval : opr0;
1026   }
1027   if (needlog) {
1028     Constant *LogExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2,
1029                                                      FInfo));
1030     if (!LogExpr)
1031       return false;
1032     nval = CreateCallEx(B,LogExpr, nval, "__log2");
1033   }
1034 
1035   if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
1036     // convert int(32) to fp(f32 or f64)
1037     opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F");
1038   }
1039   nval = B.CreateFMul(opr1, nval, "__ylogx");
1040   nval = CreateCallEx(B,ExpExpr, nval, "__exp2");
1041 
1042   if (needcopysign) {
1043     Value *opr_n;
1044     Type* rTy = opr0->getType();
1045     Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty();
1046     Type *nTy = nTyS;
1047     if (const VectorType *vTy = dyn_cast<VectorType>(rTy))
1048       nTy = VectorType::get(nTyS, vTy->getNumElements());
1049     unsigned size = nTy->getScalarSizeInBits();
1050     opr_n = CI->getArgOperand(1);
1051     if (opr_n->getType()->isIntegerTy())
1052       opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
1053     else
1054       opr_n = B.CreateFPToSI(opr1, nTy, "__ytou");
1055 
1056     Value *sign = B.CreateShl(opr_n, size-1, "__yeven");
1057     sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign");
1058     nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign);
1059     nval = B.CreateBitCast(nval, opr0->getType());
1060   }
1061 
1062   DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1063                << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
1064   replaceCall(nval);
1065 
1066   return true;
1067 }
1068 
1069 bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
1070                                 const FuncInfo &FInfo) {
1071   Value *opr0 = CI->getArgOperand(0);
1072   Value *opr1 = CI->getArgOperand(1);
1073 
1074   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
1075   if (!CINT) {
1076     return false;
1077   }
1078   int ci_opr1 = (int)CINT->getSExtValue();
1079   if (ci_opr1 == 1) {  // rootn(x, 1) = x
1080     DEBUG(errs() << "AMDIC: " << *CI
1081                  << " ---> " << *opr0 << "\n");
1082     replaceCall(opr0);
1083     return true;
1084   }
1085   if (ci_opr1 == 2) {  // rootn(x, 2) = sqrt(x)
1086     std::vector<const Type*> ParamsTys;
1087     ParamsTys.push_back(opr0->getType());
1088     Module *M = CI->getModule();
1089     if (Constant *FPExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT,
1090                                                         FInfo))) {
1091       DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n");
1092       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1093       replaceCall(nval);
1094       return true;
1095     }
1096   } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1097     Module *M = CI->getModule();
1098     if (Constant *FPExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT,
1099                                                         FInfo))) {
1100       DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n");
1101       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
1102       replaceCall(nval);
1103       return true;
1104     }
1105   } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
1106     DEBUG(errs() << "AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << "\n");
1107     Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
1108                                opr0,
1109                                "__rootn2div");
1110     replaceCall(nval);
1111     return true;
1112   } else if (ci_opr1 == -2) {  // rootn(x, -2) = rsqrt(x)
1113     std::vector<const Type*> ParamsTys;
1114     ParamsTys.push_back(opr0->getType());
1115     Module *M = CI->getModule();
1116     if (Constant *FPExpr = getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT,
1117                                                         FInfo))) {
1118       DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0 << ")\n");
1119       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1120       replaceCall(nval);
1121       return true;
1122     }
1123   }
1124   return false;
1125 }
1126 
1127 bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
1128                                   const FuncInfo &FInfo) {
1129   Value *opr0 = CI->getArgOperand(0);
1130   Value *opr1 = CI->getArgOperand(1);
1131   Value *opr2 = CI->getArgOperand(2);
1132 
1133   ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
1134   ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);
1135   if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
1136     // fma/mad(a, b, c) = c if a=0 || b=0
1137     DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr2 << "\n");
1138     replaceCall(opr2);
1139     return true;
1140   }
1141   if (CF0 && CF0->isExactlyValue(1.0f)) {
1142     // fma/mad(a, b, c) = b+c if a=1
1143     DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1144                  << *opr1 << " + " << *opr2 << "\n");
1145     Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
1146     replaceCall(nval);
1147     return true;
1148   }
1149   if (CF1 && CF1->isExactlyValue(1.0f)) {
1150     // fma/mad(a, b, c) = a+c if b=1
1151     DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1152                  << *opr0 << " + " << *opr2 << "\n");
1153     Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
1154     replaceCall(nval);
1155     return true;
1156   }
1157   if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
1158     if (CF->isZero()) {
1159       // fma/mad(a, b, c) = a*b if c=0
1160       DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1161                    << *opr0 << " * " << *opr1 << "\n");
1162       Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
1163       replaceCall(nval);
1164       return true;
1165     }
1166   }
1167 
1168   return false;
1169 }
1170 
1171 // Get a scalar native builtin signle argument FP function
1172 Constant* AMDGPULibCalls::getNativeFunction(Module* M, const FuncInfo& FInfo) {
1173   FuncInfo nf = FInfo;
1174   nf.setPrefix(AMDGPULibFunc::NATIVE);
1175   return getFunction(M, nf);
1176 }
1177 
1178 // fold sqrt -> native_sqrt (x)
1179 bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
1180                                const FuncInfo &FInfo) {
1181   if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
1182        getArgType(FInfo) == AMDGPULibFunc::F64) &&
1183       (getVecSize(FInfo) == 1) &&
1184       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
1185     if (Constant *FPExpr = getNativeFunction(
1186         CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1187       Value *opr0 = CI->getArgOperand(0);
1188       DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1189                    << "sqrt(" << *opr0 << ")\n");
1190       Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
1191       replaceCall(nval);
1192       return true;
1193     }
1194   }
1195   return false;
1196 }
1197 
// fold sin, cos -> sincos.
//
// CI is a call to sin or cos (asserted below). The function looks for a call
// to the other function of the pair that uses the same argument value, lives
// in the same basic block and appears before CI inside a bounded scan window.
// When such a partner (UI) is found, a single sincos call is emitted at UI:
// its return value replaces the sin call and a load from the stack slot
// passed as its second argument replaces the cos call. Both original calls
// are then erased.
//
// Returns true if the two calls were merged. Returns false if the callee name
// cannot be parsed, no partner call is found, or no sincos function is
// available. Note that the load-forwarding step below may already have
// changed the IR even when false is returned.
bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
                                 AliasAnalysis *AA) {
  AMDGPULibFunc fInfo;
  if (!AMDGPULibFunc::parse(CI->getCalledFunction()->getName(), fInfo))
    return false;

  assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
         fInfo.getId() == AMDGPULibFunc::EI_COS);
  bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

  Value *CArgVal = CI->getArgOperand(0);
  BasicBlock * const CBB = CI->getParent();

  // Upper bound on the number of instructions inspected by both the load
  // forwarding and the partner search.
  int const MaxScan = 30;

  { // fold in load value.
    // If the argument is a load in CI's block and an equivalent value is
    // already available, forward that value to every user of the load so
    // that sin and cos calls fed by separate loads end up sharing one
    // argument value.
    LoadInst *LI = dyn_cast<LoadInst>(CArgVal);
    if (LI && LI->getParent() == CBB) {
      BasicBlock::iterator BBI = LI->getIterator();
      Value *AvailableVal = FindAvailableLoadedValue(LI, CBB, BBI, MaxScan, AA);
      if (AvailableVal) {
        CArgVal->replaceAllUsesWith(AvailableVal);
        if (CArgVal->getNumUses() == 0)
          LI->eraseFromParent();
        // Re-read the operand: it now refers to the forwarded value.
        CArgVal = CI->getArgOperand(0);
      }
    }
  }

  Module *M = CI->getModule();
  // Switch fInfo to the other function of the sin/cos pair and compute the
  // mangled name a partner call must have.
  fInfo.setId(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN);
  std::string const PairName = fInfo.mangle();

  // Search the users of the argument for a partner call in the same block.
  CallInst *UI = nullptr;
  for (User* U : CArgVal->users()) {
    CallInst *XI = dyn_cast_or_null<CallInst>(U);
    if (!XI || XI == CI || XI->getParent() != CBB)
      continue;

    Function *UCallee = XI->getCalledFunction();
    if (!UCallee || !UCallee->getName().equals(PairName))
      continue;

    // Walk backwards from the instruction just before CI, visiting at most
    // MaxScan instructions. The walk ends as soon as BBI reaches the first
    // instruction of the block, and that instruction itself is not compared
    // against XI.
    BasicBlock::iterator BBI = CI->getIterator();
    if (BBI == CI->getParent()->begin())
      break;
    --BBI;
    for (int I = MaxScan; I > 0 && BBI != CBB->begin(); --BBI, --I) {
      if (cast<Instruction>(BBI) == XI) {
        UI = XI;
        break;
      }
    }
    if (UI) break;
  }

  if (!UI) return false;

  // Merge the sin and cos.

  // for OpenCL 2.0 we have only generic implementation of sincos
  // function.
  AMDGPULibFunc nf(AMDGPULibFunc::EI_SINCOS, fInfo);
  nf.Leads[0].PtrKind = AMDGPULibFunc::GENERIC;
  Function *Fsincos = dyn_cast_or_null<Function>(getFunction(M, nf));
  if (!Fsincos) return false;

  // Remember the builder's current position; insertAlloca() moves the
  // insert point to the entry block, and it is moved again to UI right after.
  BasicBlock::iterator ItOld = B.GetInsertPoint();
  AllocaInst *Alloc = insertAlloca(UI, B, "__sincos_");
  B.SetInsertPoint(UI);

  Value *P = Alloc;
  Type *PTy = Fsincos->getFunctionType()->getParamType(1);
  // The alloca lives in the private address space. When the pointer
  // parameter of sincos uses a different address space, the alloca has to be
  // address-space-cast to that parameter type.
  // In OpenCL 2.0 this is generic, while in 1.2 that is private.
  const AMDGPUAS AS = AMDGPU::getAMDGPUAS(*M);
  if (PTy->getPointerAddressSpace() != AS.PRIVATE_ADDRESS)
    P = B.CreateAddrSpaceCast(Alloc, PTy);
  CallInst *Call = CreateCallEx2(B, Fsincos, UI->getArgOperand(0), P);

  DEBUG(errs() << "AMDIC: fold_sincos (" << *CI << ", " << *UI
               << ") with " << *Call << "\n");

  if (!isSin) { // CI->cos, UI->sin
    // The sincos result stands in for the earlier sin call; the cos value is
    // reloaded from the stack slot at the builder's original position.
    B.SetInsertPoint(&*ItOld);
    UI->replaceAllUsesWith(&*Call);
    Instruction *Reload = B.CreateLoad(Alloc);
    CI->replaceAllUsesWith(Reload);
    UI->eraseFromParent();
    CI->eraseFromParent();
  } else { // CI->sin, UI->cos
    // The builder is still positioned at UI, so the reload is emitted right
    // after the sincos call and replaces the earlier cos call; the sincos
    // result replaces CI.
    Instruction *Reload = B.CreateLoad(Alloc);
    UI->replaceAllUsesWith(Reload);
    CI->replaceAllUsesWith(Call);
    UI->eraseFromParent();
    CI->eraseFromParent();
  }
  return true;
}
1299 
1300 // Get insertion point at entry.
1301 BasicBlock::iterator AMDGPULibCalls::getEntryIns(CallInst * UI) {
1302   Function * Func = UI->getParent()->getParent();
1303   BasicBlock * BB = &Func->getEntryBlock();
1304   assert(BB && "Entry block not found!");
1305   BasicBlock::iterator ItNew = BB->begin();
1306   return ItNew;
1307 }
1308 
1309 // Insert a AllocsInst at the beginning of function entry block.
1310 AllocaInst* AMDGPULibCalls::insertAlloca(CallInst *UI, IRBuilder<> &B,
1311                                          const char *prefix) {
1312   BasicBlock::iterator ItNew = getEntryIns(UI);
1313   Function *UCallee = UI->getCalledFunction();
1314   Type *RetType = UCallee->getReturnType();
1315   B.SetInsertPoint(&*ItNew);
1316   AllocaInst *Alloc = B.CreateAlloca(RetType, 0,
1317     std::string(prefix) + UI->getName());
1318   Alloc->setAlignment(UCallee->getParent()->getDataLayout()
1319                        .getTypeAllocSize(RetType));
1320   return Alloc;
1321 }
1322 
1323 bool AMDGPULibCalls::evaluateScalarMathFunc(FuncInfo &FInfo,
1324                                             double& Res0, double& Res1,
1325                                             Constant *copr0, Constant *copr1,
1326                                             Constant *copr2) {
1327   // By default, opr0/opr1/opr3 holds values of float/double type.
1328   // If they are not float/double, each function has to its
1329   // operand separately.
1330   double opr0=0.0, opr1=0.0, opr2=0.0;
1331   ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
1332   ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
1333   ConstantFP *fpopr2 = dyn_cast_or_null<ConstantFP>(copr2);
1334   if (fpopr0) {
1335     opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1336              ? fpopr0->getValueAPF().convertToDouble()
1337              : (double)fpopr0->getValueAPF().convertToFloat();
1338   }
1339 
1340   if (fpopr1) {
1341     opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1342              ? fpopr1->getValueAPF().convertToDouble()
1343              : (double)fpopr1->getValueAPF().convertToFloat();
1344   }
1345 
1346   if (fpopr2) {
1347     opr2 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1348              ? fpopr2->getValueAPF().convertToDouble()
1349              : (double)fpopr2->getValueAPF().convertToFloat();
1350   }
1351 
1352   switch (FInfo.getId()) {
1353   default : return false;
1354 
1355   case AMDGPULibFunc::EI_ACOS:
1356     Res0 = acos(opr0);
1357     return true;
1358 
1359   case AMDGPULibFunc::EI_ACOSH:
1360     // acosh(x) == log(x + sqrt(x*x - 1))
1361     Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
1362     return true;
1363 
1364   case AMDGPULibFunc::EI_ACOSPI:
1365     Res0 = acos(opr0) / MATH_PI;
1366     return true;
1367 
1368   case AMDGPULibFunc::EI_ASIN:
1369     Res0 = asin(opr0);
1370     return true;
1371 
1372   case AMDGPULibFunc::EI_ASINH:
1373     // asinh(x) == log(x + sqrt(x*x + 1))
1374     Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
1375     return true;
1376 
1377   case AMDGPULibFunc::EI_ASINPI:
1378     Res0 = asin(opr0) / MATH_PI;
1379     return true;
1380 
1381   case AMDGPULibFunc::EI_ATAN:
1382     Res0 = atan(opr0);
1383     return true;
1384 
1385   case AMDGPULibFunc::EI_ATANH:
1386     // atanh(x) == (log(x+1) - log(x-1))/2;
1387     Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
1388     return true;
1389 
1390   case AMDGPULibFunc::EI_ATANPI:
1391     Res0 = atan(opr0) / MATH_PI;
1392     return true;
1393 
1394   case AMDGPULibFunc::EI_CBRT:
1395     Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
1396     return true;
1397 
1398   case AMDGPULibFunc::EI_COS:
1399     Res0 = cos(opr0);
1400     return true;
1401 
1402   case AMDGPULibFunc::EI_COSH:
1403     Res0 = cosh(opr0);
1404     return true;
1405 
1406   case AMDGPULibFunc::EI_COSPI:
1407     Res0 = cos(MATH_PI * opr0);
1408     return true;
1409 
1410   case AMDGPULibFunc::EI_EXP:
1411     Res0 = exp(opr0);
1412     return true;
1413 
1414   case AMDGPULibFunc::EI_EXP2:
1415     Res0 = pow(2.0, opr0);
1416     return true;
1417 
1418   case AMDGPULibFunc::EI_EXP10:
1419     Res0 = pow(10.0, opr0);
1420     return true;
1421 
1422   case AMDGPULibFunc::EI_EXPM1:
1423     Res0 = exp(opr0) - 1.0;
1424     return true;
1425 
1426   case AMDGPULibFunc::EI_LOG:
1427     Res0 = log(opr0);
1428     return true;
1429 
1430   case AMDGPULibFunc::EI_LOG2:
1431     Res0 = log(opr0) / log(2.0);
1432     return true;
1433 
1434   case AMDGPULibFunc::EI_LOG10:
1435     Res0 = log(opr0) / log(10.0);
1436     return true;
1437 
1438   case AMDGPULibFunc::EI_RSQRT:
1439     Res0 = 1.0 / sqrt(opr0);
1440     return true;
1441 
1442   case AMDGPULibFunc::EI_SIN:
1443     Res0 = sin(opr0);
1444     return true;
1445 
1446   case AMDGPULibFunc::EI_SINH:
1447     Res0 = sinh(opr0);
1448     return true;
1449 
1450   case AMDGPULibFunc::EI_SINPI:
1451     Res0 = sin(MATH_PI * opr0);
1452     return true;
1453 
1454   case AMDGPULibFunc::EI_SQRT:
1455     Res0 = sqrt(opr0);
1456     return true;
1457 
1458   case AMDGPULibFunc::EI_TAN:
1459     Res0 = tan(opr0);
1460     return true;
1461 
1462   case AMDGPULibFunc::EI_TANH:
1463     Res0 = tanh(opr0);
1464     return true;
1465 
1466   case AMDGPULibFunc::EI_TANPI:
1467     Res0 = tan(MATH_PI * opr0);
1468     return true;
1469 
1470   case AMDGPULibFunc::EI_RECIP:
1471     Res0 = 1.0 / opr0;
1472     return true;
1473 
1474   // two-arg functions
1475   case AMDGPULibFunc::EI_DIVIDE:
1476     Res0 = opr0 / opr1;
1477     return true;
1478 
1479   case AMDGPULibFunc::EI_POW:
1480   case AMDGPULibFunc::EI_POWR:
1481     Res0 = pow(opr0, opr1);
1482     return true;
1483 
1484   case AMDGPULibFunc::EI_POWN: {
1485     if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1486       double val = (double)iopr1->getSExtValue();
1487       Res0 = pow(opr0, val);
1488       return true;
1489     }
1490     return false;
1491   }
1492 
1493   case AMDGPULibFunc::EI_ROOTN: {
1494     if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1495       double val = (double)iopr1->getSExtValue();
1496       Res0 = pow(opr0, 1.0 / val);
1497       return true;
1498     }
1499     return false;
1500   }
1501 
1502   // with ptr arg
1503   case AMDGPULibFunc::EI_SINCOS:
1504     Res0 = sin(opr0);
1505     Res1 = cos(opr0);
1506     return true;
1507 
1508   // three-arg functions
1509   case AMDGPULibFunc::EI_FMA:
1510   case AMDGPULibFunc::EI_MAD:
1511     Res0 = opr0 * opr1 + opr2;
1512     return true;
1513   }
1514 
1515   return false;
1516 }
1517 
1518 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, FuncInfo &FInfo) {
1519   int numArgs = (int)aCI->getNumArgOperands();
1520   if (numArgs > 3)
1521     return false;
1522 
1523   Constant *copr0 = nullptr;
1524   Constant *copr1 = nullptr;
1525   Constant *copr2 = nullptr;
1526   if (numArgs > 0) {
1527     if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr)
1528       return false;
1529   }
1530 
1531   if (numArgs > 1) {
1532     if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) {
1533       if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS)
1534         return false;
1535     }
1536   }
1537 
1538   if (numArgs > 2) {
1539     if ((copr2 = dyn_cast<Constant>(aCI->getArgOperand(2))) == nullptr)
1540       return false;
1541   }
1542 
1543   // At this point, all arguments to aCI are constants.
1544 
1545   // max vector size is 16, and sincos will generate two results.
1546   double DVal0[16], DVal1[16];
1547   bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS);
1548   if (getVecSize(FInfo) == 1) {
1549     if (!evaluateScalarMathFunc(FInfo, DVal0[0],
1550                                 DVal1[0], copr0, copr1, copr2)) {
1551       return false;
1552     }
1553   } else {
1554     ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0);
1555     ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1);
1556     ConstantDataVector *CDV2 = dyn_cast_or_null<ConstantDataVector>(copr2);
1557     for (int i=0; i < getVecSize(FInfo); ++i) {
1558       Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr;
1559       Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr;
1560       Constant *celt2 = CDV2 ? CDV2->getElementAsConstant(i) : nullptr;
1561       if (!evaluateScalarMathFunc(FInfo, DVal0[i],
1562                                   DVal1[i], celt0, celt1, celt2)) {
1563         return false;
1564       }
1565     }
1566   }
1567 
1568   LLVMContext &context = CI->getParent()->getParent()->getContext();
1569   Constant *nval0, *nval1;
1570   if (getVecSize(FInfo) == 1) {
1571     nval0 = ConstantFP::get(CI->getType(), DVal0[0]);
1572     if (hasTwoResults)
1573       nval1 = ConstantFP::get(CI->getType(), DVal1[0]);
1574   } else {
1575     if (getArgType(FInfo) == AMDGPULibFunc::F32) {
1576       SmallVector <float, 0> FVal0, FVal1;
1577       for (int i=0; i < getVecSize(FInfo); ++i)
1578         FVal0.push_back((float)DVal0[i]);
1579       ArrayRef<float> tmp0(FVal0);
1580       nval0 = ConstantDataVector::get(context, tmp0);
1581       if (hasTwoResults) {
1582         for (int i=0; i < getVecSize(FInfo); ++i)
1583           FVal1.push_back((float)DVal1[i]);
1584         ArrayRef<float> tmp1(FVal1);
1585         nval1 = ConstantDataVector::get(context, tmp1);
1586       }
1587     } else {
1588       ArrayRef<double> tmp0(DVal0);
1589       nval0 = ConstantDataVector::get(context, tmp0);
1590       if (hasTwoResults) {
1591         ArrayRef<double> tmp1(DVal1);
1592         nval1 = ConstantDataVector::get(context, tmp1);
1593       }
1594     }
1595   }
1596 
1597   if (hasTwoResults) {
1598     // sincos
1599     assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS &&
1600            "math function with ptr arg not supported yet");
1601     new StoreInst(nval1, aCI->getArgOperand(1), aCI);
1602   }
1603 
1604   replaceCall(nval0);
1605   return true;
1606 }
1607 
// Public interface to the Simplify LibCalls pass.
// Allocates a new AMDGPUSimplifyLibCalls object and returns it as a
// FunctionPass pointer.
FunctionPass *llvm::createAMDGPUSimplifyLibCallsPass() {
  return new AMDGPUSimplifyLibCalls();
}
1612 
// Public interface to the Use Native Calls pass.
// Allocates a new AMDGPUUseNativeCalls object and returns it as a
// FunctionPass pointer.
FunctionPass *llvm::createAMDGPUUseNativeCallsPass() {
  return new AMDGPUUseNativeCalls();
}
1616 
1617 bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
1618   if (skipFunction(F))
1619     return false;
1620 
1621   bool Changed = false;
1622   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
1623 
1624   DEBUG(dbgs() << "AMDIC: process function ";
1625         F.printAsOperand(dbgs(), false, F.getParent());
1626         dbgs() << '\n';);
1627 
1628   for (auto &BB : F) {
1629     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
1630       // Ignore non-calls.
1631       CallInst *CI = dyn_cast<CallInst>(I);
1632       ++I;
1633       if (!CI) continue;
1634 
1635       // Ignore indirect calls.
1636       Function *Callee = CI->getCalledFunction();
1637       if (Callee == 0) continue;
1638 
1639       DEBUG(dbgs() << "AMDIC: try folding " << *CI << "\n";
1640             dbgs().flush());
1641       if(Simplifier.fold(CI, AA))
1642         Changed = true;
1643     }
1644   }
1645   return Changed;
1646 }
1647 
1648 bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
1649   if (skipFunction(F) || UseNative.empty())
1650     return false;
1651 
1652   bool Changed = false;
1653   for (auto &BB : F) {
1654     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
1655       // Ignore non-calls.
1656       CallInst *CI = dyn_cast<CallInst>(I);
1657       ++I;
1658       if (!CI) continue;
1659 
1660       // Ignore indirect calls.
1661       Function *Callee = CI->getCalledFunction();
1662       if (Callee == 0) continue;
1663 
1664       if(Simplifier.useNative(CI))
1665         Changed = true;
1666     }
1667   }
1668   return Changed;
1669 }
1670