1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a class for representing known zeros and ones used by 10 // computeKnownBits. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/KnownBits.h" 15 #include <cassert> 16 17 using namespace llvm; 18 19 static KnownBits computeForAddCarry( 20 const KnownBits &LHS, const KnownBits &RHS, 21 bool CarryZero, bool CarryOne) { 22 assert(!(CarryZero && CarryOne) && 23 "Carry can't be zero and one at the same time"); 24 25 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; 26 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; 27 28 // Compute known bits of the carry. 29 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); 30 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; 31 32 // Compute set of known bits (where all three relevant bits are known). 33 APInt LHSKnownUnion = LHS.Zero | LHS.One; 34 APInt RHSKnownUnion = RHS.Zero | RHS.One; 35 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; 36 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; 37 38 assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && 39 "known bits of sum differ"); 40 41 // Compute known bits of the result. 42 KnownBits KnownOut; 43 KnownOut.Zero = ~std::move(PossibleSumZero) & Known; 44 KnownOut.One = std::move(PossibleSumOne) & Known; 45 return KnownOut; 46 } 47 48 KnownBits KnownBits::computeForAddCarry( 49 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { 50 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit"); 51 return ::computeForAddCarry( 52 LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); 53 } 54 55 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, 56 const KnownBits &LHS, KnownBits RHS) { 57 KnownBits KnownOut; 58 if (Add) { 59 // Sum = LHS + RHS + 0 60 KnownOut = ::computeForAddCarry( 61 LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); 62 } else { 63 // Sum = LHS + ~RHS + 1 64 std::swap(RHS.Zero, RHS.One); 65 KnownOut = ::computeForAddCarry( 66 LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); 67 } 68 69 // Are we still trying to solve for the sign bit? 70 if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { 71 if (NSW) { 72 // Adding two non-negative numbers, or subtracting a negative number from 73 // a non-negative one, can't wrap into negative. 74 if (LHS.isNonNegative() && RHS.isNonNegative()) 75 KnownOut.makeNonNegative(); 76 // Adding two negative numbers, or subtracting a non-negative number from 77 // a negative one, can't wrap into non-negative. 78 else if (LHS.isNegative() && RHS.isNegative()) 79 KnownOut.makeNegative(); 80 } 81 } 82 83 return KnownOut; 84 } 85 86 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { 87 unsigned BitWidth = getBitWidth(); 88 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth && 89 "Illegal sext-in-register"); 90 91 if (SrcBitWidth == BitWidth) 92 return *this; 93 94 unsigned ExtBits = BitWidth - SrcBitWidth; 95 KnownBits Result; 96 Result.One = One << ExtBits; 97 Result.Zero = Zero << ExtBits; 98 Result.One.ashrInPlace(ExtBits); 99 Result.Zero.ashrInPlace(ExtBits); 100 return Result; 101 } 102 103 KnownBits KnownBits::makeGE(const APInt &Val) const { 104 // Count the number of leading bit positions where our underlying value is 105 // known to be less than or equal to Val. 106 unsigned N = (Zero | Val).countLeadingOnes(); 107 108 // For each of those bit positions, if Val has a 1 in that bit then our 109 // underlying value must also have a 1. 110 APInt MaskedVal(Val); 111 MaskedVal.clearLowBits(getBitWidth() - N); 112 return KnownBits(Zero, One | MaskedVal); 113 } 114 115 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { 116 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for 117 // RHS. Ideally our caller would already have spotted these cases and 118 // optimized away the umax operation, but we handle them here for 119 // completeness. 120 if (LHS.getMinValue().uge(RHS.getMaxValue())) 121 return LHS; 122 if (RHS.getMinValue().uge(LHS.getMaxValue())) 123 return RHS; 124 125 // If the result of the umax is LHS then it must be greater than or equal to 126 // the minimum possible value of RHS. Likewise for RHS. Any known bits that 127 // are common to these two values are also known in the result. 128 KnownBits L = LHS.makeGE(RHS.getMinValue()); 129 KnownBits R = RHS.makeGE(LHS.getMinValue()); 130 return KnownBits::commonBits(L, R); 131 } 132 133 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { 134 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] 135 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; 136 return Flip(umax(Flip(LHS), Flip(RHS))); 137 } 138 139 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { 140 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] 141 auto Flip = [](const KnownBits &Val) { 142 unsigned SignBitPosition = Val.getBitWidth() - 1; 143 APInt Zero = Val.Zero; 144 APInt One = Val.One; 145 Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 146 One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 147 return KnownBits(Zero, One); 148 }; 149 return Flip(umax(Flip(LHS), Flip(RHS))); 150 } 151 152 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { 153 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] 154 auto Flip = [](const KnownBits &Val) { 155 unsigned SignBitPosition = Val.getBitWidth() - 1; 156 APInt Zero = Val.One; 157 APInt One = Val.Zero; 158 Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 159 One.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 160 return KnownBits(Zero, One); 161 }; 162 return Flip(umax(Flip(LHS), Flip(RHS))); 163 } 164 165 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) { 166 unsigned BitWidth = LHS.getBitWidth(); 167 KnownBits Known(BitWidth); 168 169 // If the shift amount is a valid constant then transform LHS directly. 170 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 171 unsigned Shift = RHS.getConstant().getZExtValue(); 172 Known = LHS; 173 Known.Zero <<= Shift; 174 Known.One <<= Shift; 175 // Low bits are known zero. 176 Known.Zero.setLowBits(Shift); 177 return Known; 178 } 179 180 // No matter the shift amount, the trailing zeros will stay zero. 181 unsigned MinTrailingZeros = LHS.countMinTrailingZeros(); 182 183 // Minimum shift amount low bits are known zero. 184 APInt MinShiftAmount = RHS.getMinValue(); 185 if (MinShiftAmount.ult(BitWidth)) { 186 MinTrailingZeros += MinShiftAmount.getZExtValue(); 187 MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); 188 } 189 190 Known.Zero.setLowBits(MinTrailingZeros); 191 return Known; 192 } 193 194 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) { 195 unsigned BitWidth = LHS.getBitWidth(); 196 KnownBits Known(BitWidth); 197 198 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 199 unsigned Shift = RHS.getConstant().getZExtValue(); 200 Known = LHS; 201 Known.Zero.lshrInPlace(Shift); 202 Known.One.lshrInPlace(Shift); 203 // High bits are known zero. 204 Known.Zero.setHighBits(Shift); 205 return Known; 206 } 207 208 // No matter the shift amount, the leading zeros will stay zero. 209 unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); 210 211 // Minimum shift amount high bits are known zero. 212 APInt MinShiftAmount = RHS.getMinValue(); 213 if (MinShiftAmount.ult(BitWidth)) { 214 MinLeadingZeros += MinShiftAmount.getZExtValue(); 215 MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); 216 } 217 218 Known.Zero.setHighBits(MinLeadingZeros); 219 return Known; 220 } 221 222 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) { 223 unsigned BitWidth = LHS.getBitWidth(); 224 KnownBits Known(BitWidth); 225 226 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 227 unsigned Shift = RHS.getConstant().getZExtValue(); 228 Known = LHS; 229 Known.Zero.ashrInPlace(Shift); 230 Known.One.ashrInPlace(Shift); 231 return Known; 232 } 233 234 // No matter the shift amount, the leading sign bits will stay. 235 unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); 236 unsigned MinLeadingOnes = LHS.countMinLeadingOnes(); 237 238 // Minimum shift amount high bits are known sign bits. 239 APInt MinShiftAmount = RHS.getMinValue(); 240 if (MinShiftAmount.ult(BitWidth)) { 241 if (MinLeadingZeros) { 242 MinLeadingZeros += MinShiftAmount.getZExtValue(); 243 MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); 244 } 245 if (MinLeadingOnes) { 246 MinLeadingOnes += MinShiftAmount.getZExtValue(); 247 MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); 248 } 249 } 250 251 Known.Zero.setHighBits(MinLeadingZeros); 252 Known.One.setHighBits(MinLeadingOnes); 253 return Known; 254 } 255 256 Optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { 257 if (LHS.isConstant() && RHS.isConstant()) 258 return Optional<bool>(LHS.getConstant() == RHS.getConstant()); 259 if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero)) 260 return Optional<bool>(false); 261 return None; 262 } 263 264 Optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { 265 if (Optional<bool> KnownEQ = eq(LHS, RHS)) 266 return Optional<bool>(!KnownEQ.getValue()); 267 return None; 268 } 269 270 Optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { 271 // LHS >u RHS -> false if umax(LHS) <= umax(RHS) 272 if (LHS.getMaxValue().ule(RHS.getMinValue())) 273 return Optional<bool>(false); 274 // LHS >u RHS -> true if umin(LHS) > umax(RHS) 275 if (LHS.getMinValue().ugt(RHS.getMaxValue())) 276 return Optional<bool>(true); 277 return None; 278 } 279 280 Optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { 281 if (Optional<bool> IsUGT = ugt(RHS, LHS)) 282 return Optional<bool>(!IsUGT.getValue()); 283 return None; 284 } 285 286 Optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { 287 return ugt(RHS, LHS); 288 } 289 290 Optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { 291 return uge(RHS, LHS); 292 } 293 294 Optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { 295 // LHS >s RHS -> false if smax(LHS) <= smax(RHS) 296 if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue())) 297 return Optional<bool>(false); 298 // LHS >s RHS -> true if smin(LHS) > smax(RHS) 299 if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue())) 300 return Optional<bool>(true); 301 return None; 302 } 303 304 Optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { 305 if (Optional<bool> KnownSGT = sgt(RHS, LHS)) 306 return Optional<bool>(!KnownSGT.getValue()); 307 return None; 308 } 309 310 Optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { 311 return sgt(RHS, LHS); 312 } 313 314 Optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { 315 return sge(RHS, LHS); 316 } 317 318 KnownBits KnownBits::abs(bool IntMinIsPoison) const { 319 // If the source's MSB is zero then we know the rest of the bits already. 320 if (isNonNegative()) 321 return *this; 322 323 // Absolute value preserves trailing zero count. 324 KnownBits KnownAbs(getBitWidth()); 325 KnownAbs.Zero.setLowBits(countMinTrailingZeros()); 326 327 // We only know that the absolute values's MSB will be zero if INT_MIN is 328 // poison, or there is a set bit that isn't the sign bit (otherwise it could 329 // be INT_MIN). 330 if (IntMinIsPoison || (!One.isNullValue() && !One.isMinSignedValue())) 331 KnownAbs.Zero.setSignBit(); 332 333 // FIXME: Handle known negative input? 334 // FIXME: Calculate the negated Known bits and combine them? 335 return KnownAbs; 336 } 337 338 KnownBits KnownBits::computeForMul(const KnownBits &LHS, const KnownBits &RHS) { 339 unsigned BitWidth = LHS.getBitWidth(); 340 341 assert(!LHS.hasConflict() && !RHS.hasConflict()); 342 // Compute a conservative estimate for high known-0 bits. 343 unsigned LeadZ = 344 std::max(LHS.countMinLeadingZeros() + RHS.countMinLeadingZeros(), 345 BitWidth) - 346 BitWidth; 347 LeadZ = std::min(LeadZ, BitWidth); 348 349 // The result of the bottom bits of an integer multiply can be 350 // inferred by looking at the bottom bits of both operands and 351 // multiplying them together. 352 // We can infer at least the minimum number of known trailing bits 353 // of both operands. Depending on number of trailing zeros, we can 354 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming 355 // a and b are divisible by m and n respectively. 356 // We then calculate how many of those bits are inferrable and set 357 // the output. For example, the i8 mul: 358 // a = XXXX1100 (12) 359 // b = XXXX1110 (14) 360 // We know the bottom 3 bits are zero since the first can be divided by 361 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). 362 // Applying the multiplication to the trimmed arguments gets: 363 // XX11 (3) 364 // X111 (7) 365 // ------- 366 // XX11 367 // XX11 368 // XX11 369 // XX11 370 // ------- 371 // XXXXX01 372 // Which allows us to infer the 2 LSBs. Since we're multiplying the result 373 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. 374 // The proof for this can be described as: 375 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && 376 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + 377 // umin(countTrailingZeros(C2), C6) + 378 // umin(C5 - umin(countTrailingZeros(C1), C5), 379 // C6 - umin(countTrailingZeros(C2), C6)))) - 1) 380 // %aa = shl i8 %a, C5 381 // %bb = shl i8 %b, C6 382 // %aaa = or i8 %aa, C1 383 // %bbb = or i8 %bb, C2 384 // %mul = mul i8 %aaa, %bbb 385 // %mask = and i8 %mul, C7 386 // => 387 // %mask = i8 ((C1*C2)&C7) 388 // Where C5, C6 describe the known bits of %a, %b 389 // C1, C2 describe the known bottom bits of %a, %b. 390 // C7 describes the mask of the known bits of the result. 391 const APInt &Bottom0 = LHS.One; 392 const APInt &Bottom1 = RHS.One; 393 394 // How many times we'd be able to divide each argument by 2 (shr by 1). 395 // This gives us the number of trailing zeros on the multiplication result. 396 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes(); 397 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes(); 398 unsigned TrailZero0 = LHS.countMinTrailingZeros(); 399 unsigned TrailZero1 = RHS.countMinTrailingZeros(); 400 unsigned TrailZ = TrailZero0 + TrailZero1; 401 402 // Figure out the fewest known-bits operand. 403 unsigned SmallestOperand = 404 std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1); 405 unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); 406 407 APInt BottomKnown = 408 Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1); 409 410 KnownBits Res(BitWidth); 411 Res.Zero.setHighBits(LeadZ); 412 Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); 413 Res.One = BottomKnown.getLoBits(ResultBitsKnown); 414 return Res; 415 } 416 417 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) { 418 unsigned BitWidth = LHS.getBitWidth(); 419 assert(!LHS.hasConflict() && !RHS.hasConflict()); 420 KnownBits Known(BitWidth); 421 422 // For the purposes of computing leading zeros we can conservatively 423 // treat a udiv as a logical right shift by the power of 2 known to 424 // be less than the denominator. 425 unsigned LeadZ = LHS.countMinLeadingZeros(); 426 unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros(); 427 428 if (RHSMaxLeadingZeros != BitWidth) 429 LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1); 430 431 Known.Zero.setHighBits(LeadZ); 432 return Known; 433 } 434 435 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { 436 unsigned BitWidth = LHS.getBitWidth(); 437 assert(!LHS.hasConflict() && !RHS.hasConflict()); 438 KnownBits Known(BitWidth); 439 440 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 441 // The upper bits are all zero, the lower ones are unchanged. 442 APInt LowBits = RHS.getConstant() - 1; 443 Known.Zero = LHS.Zero | ~LowBits; 444 Known.One = LHS.One & LowBits; 445 return Known; 446 } 447 448 // Since the result is less than or equal to either operand, any leading 449 // zero bits in either operand must also exist in the result. 450 uint32_t Leaders = 451 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()); 452 Known.Zero.setHighBits(Leaders); 453 return Known; 454 } 455 456 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { 457 unsigned BitWidth = LHS.getBitWidth(); 458 assert(!LHS.hasConflict() && !RHS.hasConflict()); 459 KnownBits Known(BitWidth); 460 461 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 462 // The low bits of the first operand are unchanged by the srem. 463 APInt LowBits = RHS.getConstant() - 1; 464 Known.Zero = LHS.Zero & LowBits; 465 Known.One = LHS.One & LowBits; 466 467 // If the first operand is non-negative or has all low bits zero, then 468 // the upper bits are all zero. 469 if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero)) 470 Known.Zero |= ~LowBits; 471 472 // If the first operand is negative and not all low bits are zero, then 473 // the upper bits are all one. 474 if (LHS.isNegative() && LowBits.intersects(LHS.One)) 475 Known.One |= ~LowBits; 476 return Known; 477 } 478 479 // The sign bit is the LHS's sign bit, except when the result of the 480 // remainder is zero. The magnitude of the result should be less than or 481 // equal to the magnitude of the LHS. Therefore any leading zeros that exist 482 // in the left hand side must also exist in the result. 483 Known.Zero.setHighBits(LHS.countMinLeadingZeros()); 484 return Known; 485 } 486 487 KnownBits &KnownBits::operator&=(const KnownBits &RHS) { 488 // Result bit is 0 if either operand bit is 0. 489 Zero |= RHS.Zero; 490 // Result bit is 1 if both operand bits are 1. 491 One &= RHS.One; 492 return *this; 493 } 494 495 KnownBits &KnownBits::operator|=(const KnownBits &RHS) { 496 // Result bit is 0 if both operand bits are 0. 497 Zero &= RHS.Zero; 498 // Result bit is 1 if either operand bit is 1. 499 One |= RHS.One; 500 return *this; 501 } 502 503 KnownBits &KnownBits::operator^=(const KnownBits &RHS) { 504 // Result bit is 0 if both operand bits are 0 or both are 1. 505 APInt Z = (Zero & RHS.Zero) | (One & RHS.One); 506 // Result bit is 1 if one operand bit is 0 and the other is 1. 507 One = (Zero & RHS.One) | (One & RHS.Zero); 508 Zero = std::move(Z); 509 return *this; 510 } 511