1 //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the Float2Int pass, which aims to demote floating
11 // point operations to work on integers, where that is losslessly possible.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #define DEBUG_TYPE "float2int"
16 #include "llvm/ADT/APInt.h"
17 #include "llvm/ADT/APSInt.h"
18 #include "llvm/ADT/EquivalenceClasses.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Analysis/AliasAnalysis.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/IR/ConstantRange.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstIterator.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include "llvm/Transforms/Scalar.h"
33 #include <deque>
34 #include <functional> // For std::function
35 using namespace llvm;
36 
37 // The algorithm is simple. Start at instructions that convert from the
38 // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
39 // graph, using an equivalence datastructure to unify graphs that interfere.
40 //
// Mappable instructions are those with an integer corollary that, given
// integer domain inputs, produce an integer output; fadd, for example.
43 //
44 // If a non-mappable instruction is seen, this entire def-use graph is marked
45 // as non-transformable. If we see an instruction that converts from the
46 // integer domain to FP domain (uitofp,sitofp), we terminate our walk.
47 
48 /// The largest integer type worth dealing with.
49 static cl::opt<unsigned>
50 MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
51              cl::desc("Max integer bitwidth to consider in float2int"
52                       "(default=64)"));
53 
namespace {
  /// Function pass that rewrites chains of floating point operations as
  /// integer operations when the tracked value ranges allow it.
  ///
  /// All analysis state lives in the members at the bottom of the struct and
  /// is reset at the start of every runOnFunction() call.
  struct Float2Int : public FunctionPass {
    static char ID; // Pass identification, replacement for typeid
    Float2Int() : FunctionPass(ID) {
      initializeFloat2IntPass(*PassRegistry::getPassRegistry());
    }

    /// Pass entry point: runs root discovery, both walks and the transform.
    /// Returns true if the function was modified.
    bool runOnFunction(Function &F) override;

    /// The pass declares that it preserves the CFG and GlobalsAA.
    void getAnalysisUsage(AnalysisUsage &AU) const override {
      AU.setPreservesCFG();
      AU.addPreserved<GlobalsAAWrapperPass>();
    }

    /// Collect into \p Roots the scalar fptoui/fptosi instructions and the
    /// fcmp instructions whose predicate has an integer equivalent.
    void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots);
    /// Record (or overwrite) range \p R for \p I in SeenInsts; returns \p R.
    ConstantRange seen(Instruction *I, ConstantRange R);
    /// Range marking an instruction as untransformable (poison).
    ConstantRange badRange();
    /// Placeholder range for instructions whose range is not yet computed.
    ConstantRange unknownRange();
    /// Return \p R, or badRange() if \p R is wider than MaxIntegerBW + 1 bits.
    ConstantRange validateRange(ConstantRange R);
    /// Walk the use-def graph upwards from \p Roots, populating SeenInsts
    /// and the equivalence classes.
    void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots);
    /// Compute real ranges for the instructions walkBackwards() left with
    /// the unknown range.
    void walkForwards();
    /// Convert every equivalence class that is safe to convert; returns true
    /// if anything was converted.
    bool validateAndTransform();
    /// Return the integer replacement of \p I in type \p ToTy, creating it
    /// (and the replacements of its operands) if necessary.
    Value *convert(Instruction *I, Type *ToTy);
    /// Erase the original instructions that were converted.
    void cleanup();

    /// Instructions visited so far, in discovery order, with their ranges.
    MapVector<Instruction*, ConstantRange > SeenInsts;
    /// Instructions that convert from the FP domain to the integer domain.
    SmallPtrSet<Instruction*,8> Roots;
    /// Partitions of the def-use graph; each partition is converted or
    /// rejected as a unit.
    EquivalenceClasses<Instruction*> ECs;
    /// Original instruction -> integer replacement, in creation order.
    MapVector<Instruction*, Value*> ConvertedInsts;
    /// Context of the function currently being processed.
    LLVMContext *Ctx;
  };
}
85 
// Storage for the pass identifier; the constructor passes it to FunctionPass.
char Float2Int::ID = 0;
// Register the pass under the command line name "float2int". The two
// trailing 'false' arguments are the macro's CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(Float2Int, "float2int", "Float to int", false, false)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
INITIALIZE_PASS_END(Float2Int, "float2int", "Float to int", false, false)
90 
91 // Given a FCmp predicate, return a matching ICmp predicate if one
92 // exists, otherwise return BAD_ICMP_PREDICATE.
93 static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
94   switch (P) {
95   case CmpInst::FCMP_OEQ:
96   case CmpInst::FCMP_UEQ:
97     return CmpInst::ICMP_EQ;
98   case CmpInst::FCMP_OGT:
99   case CmpInst::FCMP_UGT:
100     return CmpInst::ICMP_SGT;
101   case CmpInst::FCMP_OGE:
102   case CmpInst::FCMP_UGE:
103     return CmpInst::ICMP_SGE;
104   case CmpInst::FCMP_OLT:
105   case CmpInst::FCMP_ULT:
106     return CmpInst::ICMP_SLT;
107   case CmpInst::FCMP_OLE:
108   case CmpInst::FCMP_ULE:
109     return CmpInst::ICMP_SLE;
110   case CmpInst::FCMP_ONE:
111   case CmpInst::FCMP_UNE:
112     return CmpInst::ICMP_NE;
113   default:
114     return CmpInst::BAD_ICMP_PREDICATE;
115   }
116 }
117 
118 // Given a floating point binary operator, return the matching
119 // integer version.
120 static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
121   switch (Opcode) {
122   default: llvm_unreachable("Unhandled opcode!");
123   case Instruction::FAdd: return Instruction::Add;
124   case Instruction::FSub: return Instruction::Sub;
125   case Instruction::FMul: return Instruction::Mul;
126   }
127 }
128 
129 // Find the roots - instructions that convert from the FP domain to
130 // integer domain.
131 void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
132   for (auto &I : instructions(F)) {
133     if (isa<VectorType>(I.getType()))
134       continue;
135     switch (I.getOpcode()) {
136     default: break;
137     case Instruction::FPToUI:
138     case Instruction::FPToSI:
139       Roots.insert(&I);
140       break;
141     case Instruction::FCmp:
142       if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
143           CmpInst::BAD_ICMP_PREDICATE)
144         Roots.insert(&I);
145       break;
146     }
147   }
148 }
149 
150 // Helper - mark I as having been traversed, having range R.
151 ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) {
152   DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
153   if (SeenInsts.find(I) != SeenInsts.end())
154     SeenInsts.find(I)->second = R;
155   else
156     SeenInsts.insert(std::make_pair(I, R));
157   return R;
158 }
159 
// Helper - get a range representing a poison value: the full set of
// (MaxIntegerBW + 1)-bit integers. validateAndTransform() refuses to convert
// any partition whose combined range is the full set.
ConstantRange Float2Int::badRange() {
  return ConstantRange(MaxIntegerBW + 1, true);
}
// Helper - get the placeholder range for an instruction whose real range has
// not been computed yet: the empty set of (MaxIntegerBW + 1)-bit integers.
ConstantRange Float2Int::unknownRange() {
  return ConstantRange(MaxIntegerBW + 1, false);
}
167 ConstantRange Float2Int::validateRange(ConstantRange R) {
168   if (R.getBitWidth() > MaxIntegerBW + 1)
169     return badRange();
170   return R;
171 }
172 
// The most obvious way to structure the search is a depth-first, eager
// search from each root. However, that requires direct recursion and so
// can only handle small instruction sequences. Instead, we split the search
// up into two phases:
//   - walkBackwards:  A worklist-driven walk of the use-def graph starting
//                     from the roots. Populate "SeenInsts" with interesting
//                     instructions and poison values if they're obvious and
//                     cheap to compute. Calculate the equivalence set structure
//                     while we're here too.
//   - walkForwards:  Iterate over SeenInsts in reverse order, so we visit
//                     defs before their uses. Calculate the real range info.
184 
185 // Breadth-first walk of the use-def graph; determine the set of nodes
186 // we care about and eagerly determine if some of them are poisonous.
187 void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
188   std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
189   while (!Worklist.empty()) {
190     Instruction *I = Worklist.back();
191     Worklist.pop_back();
192 
193     if (SeenInsts.find(I) != SeenInsts.end())
194       // Seen already.
195       continue;
196 
197     switch (I->getOpcode()) {
198       // FIXME: Handle select and phi nodes.
199     default:
200       // Path terminated uncleanly.
201       seen(I, badRange());
202       break;
203 
204     case Instruction::UIToFP: {
205       // Path terminated cleanly.
206       unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
207       APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1);
208       APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1);
209       seen(I, validateRange(ConstantRange(Min, Max)));
210       continue;
211     }
212 
213     case Instruction::SIToFP: {
214       // Path terminated cleanly.
215       unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
216       APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1);
217       APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1);
218       seen(I, validateRange(ConstantRange(SMin, SMax)));
219       continue;
220     }
221 
222     case Instruction::FAdd:
223     case Instruction::FSub:
224     case Instruction::FMul:
225     case Instruction::FPToUI:
226     case Instruction::FPToSI:
227     case Instruction::FCmp:
228       seen(I, unknownRange());
229       break;
230     }
231 
232     for (Value *O : I->operands()) {
233       if (Instruction *OI = dyn_cast<Instruction>(O)) {
234         // Unify def-use chains if they interfere.
235         ECs.unionSets(I, OI);
236         if (SeenInsts.find(I)->second != badRange())
237           Worklist.push_back(OI);
238       } else if (!isa<ConstantFP>(O)) {
239         // Not an instruction or ConstantFP? we can't do anything.
240         seen(I, badRange());
241       }
242     }
243   }
244 }
245 
246 // Walk forwards down the list of seen instructions, so we visit defs before
247 // uses.
248 void Float2Int::walkForwards() {
249   for (auto &It : make_range(SeenInsts.rbegin(), SeenInsts.rend())) {
250     if (It.second != unknownRange())
251       continue;
252 
253     Instruction *I = It.first;
254     std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
255     switch (I->getOpcode()) {
256       // FIXME: Handle select and phi nodes.
257     default:
258     case Instruction::UIToFP:
259     case Instruction::SIToFP:
260       llvm_unreachable("Should have been handled in walkForwards!");
261 
262     case Instruction::FAdd:
263       Op = [](ArrayRef<ConstantRange> Ops) {
264         assert(Ops.size() == 2 && "FAdd is a binary operator!");
265         return Ops[0].add(Ops[1]);
266       };
267       break;
268 
269     case Instruction::FSub:
270       Op = [](ArrayRef<ConstantRange> Ops) {
271         assert(Ops.size() == 2 && "FSub is a binary operator!");
272         return Ops[0].sub(Ops[1]);
273       };
274       break;
275 
276     case Instruction::FMul:
277       Op = [](ArrayRef<ConstantRange> Ops) {
278         assert(Ops.size() == 2 && "FMul is a binary operator!");
279         return Ops[0].multiply(Ops[1]);
280       };
281       break;
282 
283     //
284     // Root-only instructions - we'll only see these if they're the
285     //                          first node in a walk.
286     //
287     case Instruction::FPToUI:
288     case Instruction::FPToSI:
289       Op = [](ArrayRef<ConstantRange> Ops) {
290         assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
291         return Ops[0];
292       };
293       break;
294 
295     case Instruction::FCmp:
296       Op = [](ArrayRef<ConstantRange> Ops) {
297         assert(Ops.size() == 2 && "FCmp is a binary operator!");
298         return Ops[0].unionWith(Ops[1]);
299       };
300       break;
301     }
302 
303     bool Abort = false;
304     SmallVector<ConstantRange,4> OpRanges;
305     for (Value *O : I->operands()) {
306       if (Instruction *OI = dyn_cast<Instruction>(O)) {
307         assert(SeenInsts.find(OI) != SeenInsts.end() &&
308                "def not seen before use!");
309         OpRanges.push_back(SeenInsts.find(OI)->second);
310       } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
311         // Work out if the floating point number can be losslessly represented
312         // as an integer.
313         // APFloat::convertToInteger(&Exact) purports to do what we want, but
314         // the exactness can be too precise. For example, negative zero can
315         // never be exactly converted to an integer.
316         //
317         // Instead, we ask APFloat to round itself to an integral value - this
318         // preserves sign-of-zero - then compare the result with the original.
319         //
320         APFloat F = CF->getValueAPF();
321 
322         // First, weed out obviously incorrect values. Non-finite numbers
323         // can't be represented and neither can negative zero, unless
324         // we're in fast math mode.
325         if (!F.isFinite() ||
326             (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
327              !I->hasNoSignedZeros())) {
328           seen(I, badRange());
329           Abort = true;
330           break;
331         }
332 
333         APFloat NewF = F;
334         auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
335         if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) {
336           seen(I, badRange());
337           Abort = true;
338           break;
339         }
340         // OK, it's representable. Now get it.
341         APSInt Int(MaxIntegerBW+1, false);
342         bool Exact;
343         CF->getValueAPF().convertToInteger(Int,
344                                            APFloat::rmNearestTiesToEven,
345                                            &Exact);
346         OpRanges.push_back(ConstantRange(Int));
347       } else {
348         llvm_unreachable("Should have already marked this as badRange!");
349       }
350     }
351 
352     // Reduce the operands' ranges to a single range and return.
353     if (!Abort)
354       seen(I, Op(OpRanges));
355   }
356 }
357 
358 // If there is a valid transform to be done, do it.
359 bool Float2Int::validateAndTransform() {
360   bool MadeChange = false;
361 
362   // Iterate over every disjoint partition of the def-use graph.
363   for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
364     ConstantRange R(MaxIntegerBW + 1, false);
365     bool Fail = false;
366     Type *ConvertedToTy = nullptr;
367 
368     // For every member of the partition, union all the ranges together.
369     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
370          MI != ME; ++MI) {
371       Instruction *I = *MI;
372       auto SeenI = SeenInsts.find(I);
373       if (SeenI == SeenInsts.end())
374         continue;
375 
376       R = R.unionWith(SeenI->second);
377       // We need to ensure I has no users that have not been seen.
378       // If it does, transformation would be illegal.
379       //
380       // Don't count the roots, as they terminate the graphs.
381       if (Roots.count(I) == 0) {
382         // Set the type of the conversion while we're here.
383         if (!ConvertedToTy)
384           ConvertedToTy = I->getType();
385         for (User *U : I->users()) {
386           Instruction *UI = dyn_cast<Instruction>(U);
387           if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
388             DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
389             Fail = true;
390             break;
391           }
392         }
393       }
394       if (Fail)
395         break;
396     }
397 
398     // If the set was empty, or we failed, or the range is poisonous,
399     // bail out.
400     if (ECs.member_begin(It) == ECs.member_end() || Fail ||
401         R.isFullSet() || R.isSignWrappedSet())
402       continue;
403     assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
404 
405     // The number of bits required is the maximum of the upper and
406     // lower limits, plus one so it can be signed.
407     unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
408                               R.getUpper().getMinSignedBits()) + 1;
409     DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
410 
411     // If we've run off the realms of the exactly representable integers,
412     // the floating point result will differ from an integer approximation.
413 
414     // Do we need more bits than are in the mantissa of the type we converted
415     // to? semanticsPrecision returns the number of mantissa bits plus one
416     // for the sign bit.
417     unsigned MaxRepresentableBits
418       = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
419     if (MinBW > MaxRepresentableBits) {
420       DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
421       continue;
422     }
423     if (MinBW > 64) {
424       DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
425       continue;
426     }
427 
428     // OK, R is known to be representable. Now pick a type for it.
429     // FIXME: Pick the smallest legal type that will fit.
430     Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
431 
432     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
433          MI != ME; ++MI)
434       convert(*MI, Ty);
435     MadeChange = true;
436   }
437 
438   return MadeChange;
439 }
440 
441 Value *Float2Int::convert(Instruction *I, Type *ToTy) {
442   if (ConvertedInsts.find(I) != ConvertedInsts.end())
443     // Already converted this instruction.
444     return ConvertedInsts[I];
445 
446   SmallVector<Value*,4> NewOperands;
447   for (Value *V : I->operands()) {
448     // Don't recurse if we're an instruction that terminates the path.
449     if (I->getOpcode() == Instruction::UIToFP ||
450         I->getOpcode() == Instruction::SIToFP) {
451       NewOperands.push_back(V);
452     } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
453       NewOperands.push_back(convert(VI, ToTy));
454     } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
455       APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false);
456       bool Exact;
457       CF->getValueAPF().convertToInteger(Val,
458                                          APFloat::rmNearestTiesToEven,
459                                          &Exact);
460       NewOperands.push_back(ConstantInt::get(ToTy, Val));
461     } else {
462       llvm_unreachable("Unhandled operand type?");
463     }
464   }
465 
466   // Now create a new instruction.
467   IRBuilder<> IRB(I);
468   Value *NewV = nullptr;
469   switch (I->getOpcode()) {
470   default: llvm_unreachable("Unhandled instruction!");
471 
472   case Instruction::FPToUI:
473     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
474     break;
475 
476   case Instruction::FPToSI:
477     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
478     break;
479 
480   case Instruction::FCmp: {
481     CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
482     assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
483     NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
484     break;
485   }
486 
487   case Instruction::UIToFP:
488     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
489     break;
490 
491   case Instruction::SIToFP:
492     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
493     break;
494 
495   case Instruction::FAdd:
496   case Instruction::FSub:
497   case Instruction::FMul:
498     NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
499                            NewOperands[0], NewOperands[1],
500                            I->getName());
501     break;
502   }
503 
504   // If we're a root instruction, RAUW.
505   if (Roots.count(I))
506     I->replaceAllUsesWith(NewV);
507 
508   ConvertedInsts[I] = NewV;
509   return NewV;
510 }
511 
512 // Perform dead code elimination on the instructions we just modified.
513 void Float2Int::cleanup() {
514   for (auto &I : make_range(ConvertedInsts.rbegin(), ConvertedInsts.rend()))
515     I.first->eraseFromParent();
516 }
517 
518 bool Float2Int::runOnFunction(Function &F) {
519   if (skipFunction(F))
520     return false;
521 
522   DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
523   // Clear out all state.
524   ECs = EquivalenceClasses<Instruction*>();
525   SeenInsts.clear();
526   ConvertedInsts.clear();
527   Roots.clear();
528 
529   Ctx = &F.getParent()->getContext();
530 
531   findRoots(F, Roots);
532 
533   walkBackwards(Roots);
534   walkForwards();
535 
536   bool Modified = validateAndTransform();
537   if (Modified)
538     cleanup();
539   return Modified;
540 }
541 
// Factory for the pass: returns a newly allocated Float2Int instance.
FunctionPass *llvm::createFloat2IntPass() { return new Float2Int(); }
543