//===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements the Float2Int pass, which aims to demote floating
// point operations to work on integers, where that is losslessly possible.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "float2int"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <deque>
#include <functional> // For std::function
using namespace llvm;

// The algorithm is simple. Start at instructions that convert from the
// float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
// graph, using an equivalence datastructure to unify graphs that interfere.
//
// Mappable instructions are those with an integer corollary that, given
// integer domain inputs, produce an integer output; fadd, for example.
//
// If a non-mappable instruction is seen, this entire def-use graph is marked
// as non-transformable. If we see an instruction that converts from the
// integer domain to FP domain (uitofp,sitofp), we terminate our walk.

/// The largest integer type worth dealing with.
48 static cl::opt<unsigned> 49 MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden, 50 cl::desc("Max integer bitwidth to consider in float2int" 51 "(default=64)")); 52 53 namespace { 54 struct Float2Int : public FunctionPass { 55 static char ID; // Pass identification, replacement for typeid 56 Float2Int() : FunctionPass(ID) { 57 initializeFloat2IntPass(*PassRegistry::getPassRegistry()); 58 } 59 60 bool runOnFunction(Function &F) override; 61 void getAnalysisUsage(AnalysisUsage &AU) const override { 62 AU.setPreservesCFG(); 63 } 64 65 void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots); 66 ConstantRange seen(Instruction *I, ConstantRange R); 67 ConstantRange badRange(); 68 ConstantRange unknownRange(); 69 ConstantRange validateRange(ConstantRange R); 70 void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots); 71 void walkForwards(); 72 bool validateAndTransform(); 73 Value *convert(Instruction *I, Type *ToTy); 74 void cleanup(); 75 76 MapVector<Instruction*, ConstantRange > SeenInsts; 77 SmallPtrSet<Instruction*,8> Roots; 78 EquivalenceClasses<Instruction*> ECs; 79 MapVector<Instruction*, Value*> ConvertedInsts; 80 LLVMContext *Ctx; 81 }; 82 } 83 84 char Float2Int::ID = 0; 85 INITIALIZE_PASS(Float2Int, "float2int", "Float to int", false, false) 86 87 // Given a FCmp predicate, return a matching ICmp predicate if one 88 // exists, otherwise return BAD_ICMP_PREDICATE. 
89 static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) { 90 switch (P) { 91 case CmpInst::FCMP_OEQ: 92 case CmpInst::FCMP_UEQ: 93 return CmpInst::ICMP_EQ; 94 case CmpInst::FCMP_OGT: 95 case CmpInst::FCMP_UGT: 96 return CmpInst::ICMP_SGT; 97 case CmpInst::FCMP_OGE: 98 case CmpInst::FCMP_UGE: 99 return CmpInst::ICMP_SGE; 100 case CmpInst::FCMP_OLT: 101 case CmpInst::FCMP_ULT: 102 return CmpInst::ICMP_SLT; 103 case CmpInst::FCMP_OLE: 104 case CmpInst::FCMP_ULE: 105 return CmpInst::ICMP_SLE; 106 case CmpInst::FCMP_ONE: 107 case CmpInst::FCMP_UNE: 108 return CmpInst::ICMP_NE; 109 default: 110 return CmpInst::BAD_ICMP_PREDICATE; 111 } 112 } 113 114 // Given a floating point binary operator, return the matching 115 // integer version. 116 static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) { 117 switch (Opcode) { 118 default: llvm_unreachable("Unhandled opcode!"); 119 case Instruction::FAdd: return Instruction::Add; 120 case Instruction::FSub: return Instruction::Sub; 121 case Instruction::FMul: return Instruction::Mul; 122 } 123 } 124 125 // Find the roots - instructions that convert from the FP domain to 126 // integer domain. 127 void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { 128 for (auto &I : inst_range(F)) { 129 switch (I.getOpcode()) { 130 default: break; 131 case Instruction::FPToUI: 132 case Instruction::FPToSI: 133 Roots.insert(&I); 134 break; 135 case Instruction::FCmp: 136 if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) != 137 CmpInst::BAD_ICMP_PREDICATE) 138 Roots.insert(&I); 139 break; 140 } 141 } 142 } 143 144 // Helper - mark I as having been traversed, having range R. 145 ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) { 146 DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n"); 147 if (SeenInsts.find(I) != SeenInsts.end()) 148 SeenInsts.find(I)->second = R; 149 else 150 SeenInsts.insert(std::make_pair(I, R)); 151 return R; 152 } 153 154 // Helper - get a range representing a poison value. 
155 ConstantRange Float2Int::badRange() { 156 return ConstantRange(MaxIntegerBW + 1, true); 157 } 158 ConstantRange Float2Int::unknownRange() { 159 return ConstantRange(MaxIntegerBW + 1, false); 160 } 161 ConstantRange Float2Int::validateRange(ConstantRange R) { 162 if (R.getBitWidth() > MaxIntegerBW + 1) 163 return badRange(); 164 return R; 165 } 166 167 // The most obvious way to structure the search is a depth-first, eager 168 // search from each root. However, that require direct recursion and so 169 // can only handle small instruction sequences. Instead, we split the search 170 // up into two phases: 171 // - walkBackwards: A breadth-first walk of the use-def graph starting from 172 // the roots. Populate "SeenInsts" with interesting 173 // instructions and poison values if they're obvious and 174 // cheap to compute. Calculate the equivalance set structure 175 // while we're here too. 176 // - walkForwards: Iterate over SeenInsts in reverse order, so we visit 177 // defs before their uses. Calculate the real range info. 178 179 // Breadth-first walk of the use-def graph; determine the set of nodes 180 // we care about and eagerly determine if some of them are poisonous. 181 void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { 182 std::deque<Instruction*> Worklist(Roots.begin(), Roots.end()); 183 while (!Worklist.empty()) { 184 Instruction *I = Worklist.back(); 185 Worklist.pop_back(); 186 187 if (SeenInsts.find(I) != SeenInsts.end()) 188 // Seen already. 189 continue; 190 191 switch (I->getOpcode()) { 192 // FIXME: Handle select and phi nodes. 193 default: 194 // Path terminated uncleanly. 195 seen(I, badRange()); 196 break; 197 198 case Instruction::UIToFP: { 199 // Path terminated cleanly. 
200 unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); 201 APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1); 202 APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1); 203 seen(I, validateRange(ConstantRange(Min, Max))); 204 continue; 205 } 206 207 case Instruction::SIToFP: { 208 // Path terminated cleanly. 209 unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); 210 APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1); 211 APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1); 212 seen(I, validateRange(ConstantRange(SMin, SMax))); 213 continue; 214 } 215 216 case Instruction::FAdd: 217 case Instruction::FSub: 218 case Instruction::FMul: 219 case Instruction::FPToUI: 220 case Instruction::FPToSI: 221 case Instruction::FCmp: 222 seen(I, unknownRange()); 223 break; 224 } 225 226 for (Value *O : I->operands()) { 227 if (Instruction *OI = dyn_cast<Instruction>(O)) { 228 // Unify def-use chains if they interfere. 229 ECs.unionSets(I, OI); 230 if (SeenInsts.find(I)->second != badRange()) 231 Worklist.push_back(OI); 232 } else if (!isa<ConstantFP>(O)) { 233 // Not an instruction or ConstantFP? we can't do anything. 234 seen(I, badRange()); 235 } 236 } 237 } 238 } 239 240 // Walk forwards down the list of seen instructions, so we visit defs before 241 // uses. 242 void Float2Int::walkForwards() { 243 for (auto It = SeenInsts.rbegin(), E = SeenInsts.rend(); It != E; ++It) { 244 if (It->second != unknownRange()) 245 continue; 246 247 Instruction *I = It->first; 248 std::function<ConstantRange(ArrayRef<ConstantRange>)> Op; 249 switch (I->getOpcode()) { 250 // FIXME: Handle select and phi nodes. 
251 default: 252 case Instruction::UIToFP: 253 case Instruction::SIToFP: 254 llvm_unreachable("Should have been handled in walkForwards!"); 255 256 case Instruction::FAdd: 257 Op = [](ArrayRef<ConstantRange> Ops) { 258 assert(Ops.size() == 2 && "FAdd is a binary operator!"); 259 return Ops[0].add(Ops[1]); 260 }; 261 break; 262 263 case Instruction::FSub: 264 Op = [](ArrayRef<ConstantRange> Ops) { 265 assert(Ops.size() == 2 && "FSub is a binary operator!"); 266 return Ops[0].sub(Ops[1]); 267 }; 268 break; 269 270 case Instruction::FMul: 271 Op = [](ArrayRef<ConstantRange> Ops) { 272 assert(Ops.size() == 2 && "FMul is a binary operator!"); 273 return Ops[0].multiply(Ops[1]); 274 }; 275 break; 276 277 // 278 // Root-only instructions - we'll only see these if they're the 279 // first node in a walk. 280 // 281 case Instruction::FPToUI: 282 case Instruction::FPToSI: 283 Op = [](ArrayRef<ConstantRange> Ops) { 284 assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!"); 285 return Ops[0]; 286 }; 287 break; 288 289 case Instruction::FCmp: 290 Op = [](ArrayRef<ConstantRange> Ops) { 291 assert(Ops.size() == 2 && "FCmp is a binary operator!"); 292 return Ops[0].unionWith(Ops[1]); 293 }; 294 break; 295 } 296 297 bool Abort = false; 298 SmallVector<ConstantRange,4> OpRanges; 299 for (Value *O : I->operands()) { 300 if (Instruction *OI = dyn_cast<Instruction>(O)) { 301 assert(SeenInsts.find(OI) != SeenInsts.end() && 302 "def not seen before use!"); 303 OpRanges.push_back(SeenInsts.find(OI)->second); 304 } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) { 305 // Work out if the floating point number can be losslessly represented 306 // as an integer. 307 // APFloat::convertToInteger(&Exact) purports to do what we want, but 308 // the exactness can be too precise. For example, negative zero can 309 // never be exactly converted to an integer. 
310 // 311 // Instead, we ask APFloat to round itself to an integral value - this 312 // preserves sign-of-zero - then compare the result with the original. 313 // 314 APFloat F = CF->getValueAPF(); 315 316 // First, weed out obviously incorrect values. Non-finite numbers 317 // can't be represented and neither can negative zero, unless 318 // we're in fast math mode. 319 if (!F.isFinite() || 320 (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) && 321 !I->hasNoSignedZeros())) { 322 seen(I, badRange()); 323 Abort = true; 324 break; 325 } 326 327 APFloat NewF = F; 328 auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); 329 if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) { 330 seen(I, badRange()); 331 Abort = true; 332 break; 333 } 334 // OK, it's representable. Now get it. 335 APSInt Int(MaxIntegerBW+1, false); 336 bool Exact; 337 CF->getValueAPF().convertToInteger(Int, 338 APFloat::rmNearestTiesToEven, 339 &Exact); 340 OpRanges.push_back(ConstantRange(Int)); 341 } else { 342 llvm_unreachable("Should have already marked this as badRange!"); 343 } 344 } 345 346 // Reduce the operands' ranges to a single range and return. 347 if (!Abort) 348 seen(I, Op(OpRanges)); 349 } 350 } 351 352 // If there is a valid transform to be done, do it. 353 bool Float2Int::validateAndTransform() { 354 bool MadeChange = false; 355 356 // Iterate over every disjoint partition of the def-use graph. 357 for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) { 358 ConstantRange R(MaxIntegerBW + 1, false); 359 bool Fail = false; 360 Type *ConvertedToTy = nullptr; 361 362 // For every member of the partition, union all the ranges together. 363 for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); 364 MI != ME; ++MI) { 365 Instruction *I = *MI; 366 auto SeenI = SeenInsts.find(I); 367 if (SeenI == SeenInsts.end()) 368 continue; 369 370 R = R.unionWith(SeenI->second); 371 // We need to ensure I has no users that have not been seen. 
372 // If it does, transformation would be illegal. 373 // 374 // Don't count the roots, as they terminate the graphs. 375 if (Roots.count(I) == 0) { 376 // Set the type of the conversion while we're here. 377 if (!ConvertedToTy) 378 ConvertedToTy = I->getType(); 379 for (User *U : I->users()) { 380 Instruction *UI = dyn_cast<Instruction>(U); 381 if (!UI || SeenInsts.find(UI) == SeenInsts.end()) { 382 DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n"); 383 Fail = true; 384 break; 385 } 386 } 387 } 388 if (Fail) 389 break; 390 } 391 392 // If the set was empty, or we failed, or the range is poisonous, 393 // bail out. 394 if (ECs.member_begin(It) == ECs.member_end() || Fail || 395 R.isFullSet() || R.isSignWrappedSet()) 396 continue; 397 assert(ConvertedToTy && "Must have set the convertedtoty by this point!"); 398 399 // The number of bits required is the maximum of the upper and 400 // lower limits, plus one so it can be signed. 401 unsigned MinBW = std::max(R.getLower().getMinSignedBits(), 402 R.getUpper().getMinSignedBits()) + 1; 403 DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); 404 405 // If we've run off the realms of the exactly representable integers, 406 // the floating point result will differ from an integer approximation. 407 408 // Do we need more bits than are in the mantissa of the type we converted 409 // to? semanticsPrecision returns the number of mantissa bits plus one 410 // for the sign bit. 411 unsigned MaxRepresentableBits 412 = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1; 413 if (MinBW > MaxRepresentableBits) { 414 DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n"); 415 continue; 416 } 417 if (MinBW > 64) { 418 DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n"); 419 continue; 420 } 421 422 // OK, R is known to be representable. Now pick a type for it. 423 // FIXME: Pick the smallest legal type that will fit. 424 Type *Ty = (MinBW > 32) ? 
Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx); 425 426 for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); 427 MI != ME; ++MI) 428 convert(*MI, Ty); 429 MadeChange = true; 430 } 431 432 return MadeChange; 433 } 434 435 Value *Float2Int::convert(Instruction *I, Type *ToTy) { 436 if (ConvertedInsts.find(I) != ConvertedInsts.end()) 437 // Already converted this instruction. 438 return ConvertedInsts[I]; 439 440 SmallVector<Value*,4> NewOperands; 441 for (Value *V : I->operands()) { 442 // Don't recurse if we're an instruction that terminates the path. 443 if (I->getOpcode() == Instruction::UIToFP || 444 I->getOpcode() == Instruction::SIToFP) { 445 NewOperands.push_back(V); 446 } else if (Instruction *VI = dyn_cast<Instruction>(V)) { 447 NewOperands.push_back(convert(VI, ToTy)); 448 } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) { 449 APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false); 450 bool Exact; 451 CF->getValueAPF().convertToInteger(Val, 452 APFloat::rmNearestTiesToEven, 453 &Exact); 454 NewOperands.push_back(ConstantInt::get(ToTy, Val)); 455 } else { 456 llvm_unreachable("Unhandled operand type?"); 457 } 458 } 459 460 // Now create a new instruction. 
461 IRBuilder<> IRB(I); 462 Value *NewV = nullptr; 463 switch (I->getOpcode()) { 464 default: llvm_unreachable("Unhandled instruction!"); 465 466 case Instruction::FPToUI: 467 NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType()); 468 break; 469 470 case Instruction::FPToSI: 471 NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType()); 472 break; 473 474 case Instruction::FCmp: { 475 CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate()); 476 assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!"); 477 NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName()); 478 break; 479 } 480 481 case Instruction::UIToFP: 482 NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy); 483 break; 484 485 case Instruction::SIToFP: 486 NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy); 487 break; 488 489 case Instruction::FAdd: 490 case Instruction::FSub: 491 case Instruction::FMul: 492 NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()), 493 NewOperands[0], NewOperands[1], 494 I->getName()); 495 break; 496 } 497 498 // If we're a root instruction, RAUW. 499 if (Roots.count(I)) 500 I->replaceAllUsesWith(NewV); 501 502 ConvertedInsts[I] = NewV; 503 return NewV; 504 } 505 506 // Perform dead code elimination on the instructions we just modified. 507 void Float2Int::cleanup() { 508 for (auto I = ConvertedInsts.rbegin(), E = ConvertedInsts.rend(); 509 I != E; ++I) 510 I->first->eraseFromParent(); 511 } 512 513 bool Float2Int::runOnFunction(Function &F) { 514 if (skipOptnoneFunction(F)) 515 return false; 516 517 DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); 518 // Clear out all state. 
519 ECs = EquivalenceClasses<Instruction*>(); 520 SeenInsts.clear(); 521 ConvertedInsts.clear(); 522 Roots.clear(); 523 524 Ctx = &F.getParent()->getContext(); 525 526 findRoots(F, Roots); 527 528 walkBackwards(Roots); 529 walkForwards(); 530 531 bool Modified = validateAndTransform(); 532 if (Modified) 533 cleanup(); 534 return Modified; 535 } 536 537 FunctionPass *llvm::createFloat2IntPass() { 538 return new Float2Int(); 539 } 540 541