1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a class for representing known zeros and ones used by 10 // computeKnownBits. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/KnownBits.h" 15 #include <cassert> 16 17 using namespace llvm; 18 19 static KnownBits computeForAddCarry( 20 const KnownBits &LHS, const KnownBits &RHS, 21 bool CarryZero, bool CarryOne) { 22 assert(!(CarryZero && CarryOne) && 23 "Carry can't be zero and one at the same time"); 24 25 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; 26 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; 27 28 // Compute known bits of the carry. 29 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); 30 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; 31 32 // Compute set of known bits (where all three relevant bits are known). 33 APInt LHSKnownUnion = LHS.Zero | LHS.One; 34 APInt RHSKnownUnion = RHS.Zero | RHS.One; 35 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; 36 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; 37 38 assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && 39 "known bits of sum differ"); 40 41 // Compute known bits of the result. 42 KnownBits KnownOut; 43 KnownOut.Zero = ~std::move(PossibleSumZero) & Known; 44 KnownOut.One = std::move(PossibleSumOne) & Known; 45 return KnownOut; 46 } 47 48 KnownBits KnownBits::computeForAddCarry( 49 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { 50 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit"); 51 return ::computeForAddCarry( 52 LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); 53 } 54 55 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, 56 const KnownBits &LHS, KnownBits RHS) { 57 KnownBits KnownOut; 58 if (Add) { 59 // Sum = LHS + RHS + 0 60 KnownOut = ::computeForAddCarry( 61 LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); 62 } else { 63 // Sum = LHS + ~RHS + 1 64 std::swap(RHS.Zero, RHS.One); 65 KnownOut = ::computeForAddCarry( 66 LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); 67 } 68 69 // Are we still trying to solve for the sign bit? 70 if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { 71 if (NSW) { 72 // Adding two non-negative numbers, or subtracting a negative number from 73 // a non-negative one, can't wrap into negative. 74 if (LHS.isNonNegative() && RHS.isNonNegative()) 75 KnownOut.makeNonNegative(); 76 // Adding two negative numbers, or subtracting a non-negative number from 77 // a negative one, can't wrap into non-negative. 78 else if (LHS.isNegative() && RHS.isNegative()) 79 KnownOut.makeNegative(); 80 } 81 } 82 83 return KnownOut; 84 } 85 86 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { 87 unsigned BitWidth = getBitWidth(); 88 assert(BitWidth >= SrcBitWidth && "Illegal sext-in-register"); 89 90 // Sign extension. Compute the demanded bits in the result that are not 91 // present in the input. 92 APInt NewBits = APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); 93 94 // If the sign extended bits are demanded, we know that the sign 95 // bit is demanded. 96 APInt InSignMask = APInt::getSignMask(SrcBitWidth).zext(BitWidth); 97 APInt InDemandedBits = APInt::getLowBitsSet(BitWidth, SrcBitWidth); 98 if (NewBits.getBoolValue()) 99 InDemandedBits |= InSignMask; 100 101 KnownBits Result; 102 Result.One = One & InDemandedBits; 103 Result.Zero = Zero & InDemandedBits; 104 105 // If the sign bit of the input is known set or clear, then we know the 106 // top bits of the result. 107 if (Result.Zero.intersects(InSignMask)) { // Input sign bit known clear 108 Result.Zero |= NewBits; 109 Result.One &= ~NewBits; 110 } else if (Result.One.intersects(InSignMask)) { // Input sign bit known set 111 Result.One |= NewBits; 112 Result.Zero &= ~NewBits; 113 } else { // Input sign bit unknown 114 Result.Zero &= ~NewBits; 115 Result.One &= ~NewBits; 116 } 117 118 return Result; 119 } 120 121 KnownBits KnownBits::makeGE(const APInt &Val) const { 122 // Count the number of leading bit positions where our underlying value is 123 // known to be less than or equal to Val. 124 unsigned N = (Zero | Val).countLeadingOnes(); 125 126 // For each of those bit positions, if Val has a 1 in that bit then our 127 // underlying value must also have a 1. 128 APInt MaskedVal(Val); 129 MaskedVal.clearLowBits(getBitWidth() - N); 130 return KnownBits(Zero, One | MaskedVal); 131 } 132 133 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { 134 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for 135 // RHS. Ideally our caller would already have spotted these cases and 136 // optimized away the umax operation, but we handle them here for 137 // completeness. 138 if (LHS.getMinValue().uge(RHS.getMaxValue())) 139 return LHS; 140 if (RHS.getMinValue().uge(LHS.getMaxValue())) 141 return RHS; 142 143 // If the result of the umax is LHS then it must be greater than or equal to 144 // the minimum possible value of RHS. Likewise for RHS. Any known bits that 145 // are common to these two values are also known in the result. 146 KnownBits L = LHS.makeGE(RHS.getMinValue()); 147 KnownBits R = RHS.makeGE(LHS.getMinValue()); 148 return KnownBits::commonBits(L, R); 149 } 150 151 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { 152 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] 153 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; 154 return Flip(umax(Flip(LHS), Flip(RHS))); 155 } 156 157 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { 158 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] 159 auto Flip = [](const KnownBits &Val) { 160 unsigned SignBitPosition = Val.getBitWidth() - 1; 161 APInt Zero = Val.Zero; 162 APInt One = Val.One; 163 Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 164 One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 165 return KnownBits(Zero, One); 166 }; 167 return Flip(umax(Flip(LHS), Flip(RHS))); 168 } 169 170 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { 171 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] 172 auto Flip = [](const KnownBits &Val) { 173 unsigned SignBitPosition = Val.getBitWidth() - 1; 174 APInt Zero = Val.One; 175 APInt One = Val.Zero; 176 Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 177 One.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 178 return KnownBits(Zero, One); 179 }; 180 return Flip(umax(Flip(LHS), Flip(RHS))); 181 } 182 183 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) { 184 unsigned BitWidth = LHS.getBitWidth(); 185 KnownBits Known(BitWidth); 186 187 // If the shift amount is a valid constant then transform LHS directly. 188 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 189 unsigned Shift = RHS.getConstant().getZExtValue(); 190 Known = LHS; 191 Known.Zero <<= Shift; 192 Known.One <<= Shift; 193 // Low bits are known zero. 194 Known.Zero.setLowBits(Shift); 195 return Known; 196 } 197 198 // No matter the shift amount, the trailing zeros will stay zero. 199 unsigned MinTrailingZeros = LHS.countMinTrailingZeros(); 200 201 // Minimum shift amount low bits are known zero. 202 if (RHS.getMinValue().ult(BitWidth)) { 203 MinTrailingZeros += RHS.getMinValue().getZExtValue(); 204 MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); 205 } 206 207 Known.Zero.setLowBits(MinTrailingZeros); 208 return Known; 209 } 210 211 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) { 212 unsigned BitWidth = LHS.getBitWidth(); 213 KnownBits Known(BitWidth); 214 215 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 216 unsigned Shift = RHS.getConstant().getZExtValue(); 217 Known = LHS; 218 Known.Zero.lshrInPlace(Shift); 219 Known.One.lshrInPlace(Shift); 220 // High bits are known zero. 221 Known.Zero.setHighBits(Shift); 222 return Known; 223 } 224 225 // No matter the shift amount, the leading zeros will stay zero. 226 unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); 227 228 // Minimum shift amount high bits are known zero. 229 if (RHS.getMinValue().ult(BitWidth)) { 230 MinLeadingZeros += RHS.getMinValue().getZExtValue(); 231 MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); 232 } 233 234 Known.Zero.setHighBits(MinLeadingZeros); 235 return Known; 236 } 237 238 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) { 239 unsigned BitWidth = LHS.getBitWidth(); 240 KnownBits Known(BitWidth); 241 242 if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { 243 unsigned Shift = RHS.getConstant().getZExtValue(); 244 Known = LHS; 245 Known.Zero.ashrInPlace(Shift); 246 Known.One.ashrInPlace(Shift); 247 return Known; 248 } 249 250 // No matter the shift amount, the leading sign bits will stay. 251 unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); 252 unsigned MinLeadingOnes = LHS.countMinLeadingOnes(); 253 254 // Minimum shift amount high bits are known sign bits. 255 if (RHS.getMinValue().ult(BitWidth)) { 256 if (MinLeadingZeros) { 257 MinLeadingZeros += RHS.getMinValue().getZExtValue(); 258 MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); 259 } 260 if (MinLeadingOnes) { 261 MinLeadingOnes += RHS.getMinValue().getZExtValue(); 262 MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); 263 } 264 } 265 266 Known.Zero.setHighBits(MinLeadingZeros); 267 Known.One.setHighBits(MinLeadingOnes); 268 return Known; 269 } 270 271 KnownBits KnownBits::abs(bool IntMinIsPoison) const { 272 // If the source's MSB is zero then we know the rest of the bits already. 273 if (isNonNegative()) 274 return *this; 275 276 // Absolute value preserves trailing zero count. 277 KnownBits KnownAbs(getBitWidth()); 278 KnownAbs.Zero.setLowBits(countMinTrailingZeros()); 279 280 // We only know that the absolute values's MSB will be zero if INT_MIN is 281 // poison, or there is a set bit that isn't the sign bit (otherwise it could 282 // be INT_MIN). 283 if (IntMinIsPoison || (!One.isNullValue() && !One.isMinSignedValue())) 284 KnownAbs.Zero.setSignBit(); 285 286 // FIXME: Handle known negative input? 287 // FIXME: Calculate the negated Known bits and combine them? 288 return KnownAbs; 289 } 290 291 KnownBits KnownBits::computeForMul(const KnownBits &LHS, const KnownBits &RHS) { 292 unsigned BitWidth = LHS.getBitWidth(); 293 294 assert(!LHS.hasConflict() && !RHS.hasConflict()); 295 // Compute a conservative estimate for high known-0 bits. 296 unsigned LeadZ = 297 std::max(LHS.countMinLeadingZeros() + RHS.countMinLeadingZeros(), 298 BitWidth) - 299 BitWidth; 300 LeadZ = std::min(LeadZ, BitWidth); 301 302 // The result of the bottom bits of an integer multiply can be 303 // inferred by looking at the bottom bits of both operands and 304 // multiplying them together. 305 // We can infer at least the minimum number of known trailing bits 306 // of both operands. Depending on number of trailing zeros, we can 307 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming 308 // a and b are divisible by m and n respectively. 309 // We then calculate how many of those bits are inferrable and set 310 // the output. For example, the i8 mul: 311 // a = XXXX1100 (12) 312 // b = XXXX1110 (14) 313 // We know the bottom 3 bits are zero since the first can be divided by 314 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). 315 // Applying the multiplication to the trimmed arguments gets: 316 // XX11 (3) 317 // X111 (7) 318 // ------- 319 // XX11 320 // XX11 321 // XX11 322 // XX11 323 // ------- 324 // XXXXX01 325 // Which allows us to infer the 2 LSBs. Since we're multiplying the result 326 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. 327 // The proof for this can be described as: 328 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && 329 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + 330 // umin(countTrailingZeros(C2), C6) + 331 // umin(C5 - umin(countTrailingZeros(C1), C5), 332 // C6 - umin(countTrailingZeros(C2), C6)))) - 1) 333 // %aa = shl i8 %a, C5 334 // %bb = shl i8 %b, C6 335 // %aaa = or i8 %aa, C1 336 // %bbb = or i8 %bb, C2 337 // %mul = mul i8 %aaa, %bbb 338 // %mask = and i8 %mul, C7 339 // => 340 // %mask = i8 ((C1*C2)&C7) 341 // Where C5, C6 describe the known bits of %a, %b 342 // C1, C2 describe the known bottom bits of %a, %b. 343 // C7 describes the mask of the known bits of the result. 344 const APInt &Bottom0 = LHS.One; 345 const APInt &Bottom1 = RHS.One; 346 347 // How many times we'd be able to divide each argument by 2 (shr by 1). 348 // This gives us the number of trailing zeros on the multiplication result. 349 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes(); 350 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes(); 351 unsigned TrailZero0 = LHS.countMinTrailingZeros(); 352 unsigned TrailZero1 = RHS.countMinTrailingZeros(); 353 unsigned TrailZ = TrailZero0 + TrailZero1; 354 355 // Figure out the fewest known-bits operand. 356 unsigned SmallestOperand = 357 std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1); 358 unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); 359 360 APInt BottomKnown = 361 Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1); 362 363 KnownBits Res(BitWidth); 364 Res.Zero.setHighBits(LeadZ); 365 Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); 366 Res.One = BottomKnown.getLoBits(ResultBitsKnown); 367 return Res; 368 } 369 370 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) { 371 unsigned BitWidth = LHS.getBitWidth(); 372 assert(!LHS.hasConflict() && !RHS.hasConflict()); 373 KnownBits Known(BitWidth); 374 375 // For the purposes of computing leading zeros we can conservatively 376 // treat a udiv as a logical right shift by the power of 2 known to 377 // be less than the denominator. 378 unsigned LeadZ = LHS.countMinLeadingZeros(); 379 unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros(); 380 381 if (RHSMaxLeadingZeros != BitWidth) 382 LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1); 383 384 Known.Zero.setHighBits(LeadZ); 385 return Known; 386 } 387 388 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { 389 unsigned BitWidth = LHS.getBitWidth(); 390 assert(!LHS.hasConflict() && !RHS.hasConflict()); 391 KnownBits Known(BitWidth); 392 393 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 394 // The upper bits are all zero, the lower ones are unchanged. 395 APInt LowBits = RHS.getConstant() - 1; 396 Known.Zero = LHS.Zero | ~LowBits; 397 Known.One = LHS.One & LowBits; 398 return Known; 399 } 400 401 // Since the result is less than or equal to either operand, any leading 402 // zero bits in either operand must also exist in the result. 403 uint32_t Leaders = 404 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()); 405 Known.Zero.setHighBits(Leaders); 406 return Known; 407 } 408 409 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { 410 unsigned BitWidth = LHS.getBitWidth(); 411 assert(!LHS.hasConflict() && !RHS.hasConflict()); 412 KnownBits Known(BitWidth); 413 414 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 415 // The low bits of the first operand are unchanged by the srem. 416 APInt LowBits = RHS.getConstant() - 1; 417 Known.Zero = LHS.Zero & LowBits; 418 Known.One = LHS.One & LowBits; 419 420 // If the first operand is non-negative or has all low bits zero, then 421 // the upper bits are all zero. 422 if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero)) 423 Known.Zero |= ~LowBits; 424 425 // If the first operand is negative and not all low bits are zero, then 426 // the upper bits are all one. 427 if (LHS.isNegative() && LowBits.intersects(LHS.One)) 428 Known.One |= ~LowBits; 429 return Known; 430 } 431 432 // The sign bit is the LHS's sign bit, except when the result of the 433 // remainder is zero. If it's known zero, our sign bit is also zero. 434 if (LHS.isNonNegative()) 435 Known.makeNonNegative(); 436 return Known; 437 } 438 439 KnownBits &KnownBits::operator&=(const KnownBits &RHS) { 440 // Result bit is 0 if either operand bit is 0. 441 Zero |= RHS.Zero; 442 // Result bit is 1 if both operand bits are 1. 443 One &= RHS.One; 444 return *this; 445 } 446 447 KnownBits &KnownBits::operator|=(const KnownBits &RHS) { 448 // Result bit is 0 if both operand bits are 0. 449 Zero &= RHS.Zero; 450 // Result bit is 1 if either operand bit is 1. 451 One |= RHS.One; 452 return *this; 453 } 454 455 KnownBits &KnownBits::operator^=(const KnownBits &RHS) { 456 // Result bit is 0 if both operand bits are 0 or both are 1. 457 APInt Z = (Zero & RHS.Zero) | (One & RHS.One); 458 // Result bit is 1 if one operand bit is 0 and the other is 1. 459 One = (Zero & RHS.One) | (One & RHS.Zero); 460 Zero = std::move(Z); 461 return *this; 462 } 463