1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass implements a demanded bits analysis. A demanded bit is one that
11 // contributes to a result; bits that are not demanded can be either zero or
12 // one without affecting control or data flow. For example in this sequence:
13 //
14 //   %1 = add i32 %x, %y
15 //   %2 = trunc i32 %1 to i16
16 //
17 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
18 // trunc.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "llvm/Analysis/DemandedBits.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/Analysis/AssumptionCache.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/IR/BasicBlock.h"
30 #include "llvm/IR/Constants.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/Dominators.h"
34 #include "llvm/IR/InstIterator.h"
35 #include "llvm/IR/InstrTypes.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/IntrinsicInst.h"
38 #include "llvm/IR/Intrinsics.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Operator.h"
41 #include "llvm/IR/PassManager.h"
42 #include "llvm/IR/PatternMatch.h"
43 #include "llvm/IR/Type.h"
44 #include "llvm/IR/Use.h"
45 #include "llvm/Pass.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/KnownBits.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include <algorithm>
51 #include <cstdint>
52 
53 using namespace llvm;
54 using namespace llvm::PatternMatch;
55 
56 #define DEBUG_TYPE "demanded-bits"
57 
// Identifier object for the legacy wrapper pass; it is handed to
// FunctionPass(ID) in the DemandedBitsWrapperPass constructor.
char DemandedBitsWrapperPass::ID = 0;

// Register the legacy wrapper pass under the name "demanded-bits" and record
// that it depends on AssumptionCacheTracker and DominatorTreeWrapperPass
// (the same analyses requested in getAnalysisUsage).
INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
                      "Demanded bits analysis", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
                    "Demanded bits analysis", false, false)
66 
67 DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
68   initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry());
69 }
70 
71 void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
72   AU.setPreservesCFG();
73   AU.addRequired<AssumptionCacheTracker>();
74   AU.addRequired<DominatorTreeWrapperPass>();
75   AU.setPreservesAll();
76 }
77 
78 void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
79   DB->print(OS);
80 }
81 
82 static bool isAlwaysLive(Instruction *I) {
83   return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
84          I->mayHaveSideEffects();
85 }
86 
// Compute AB, the bits of operand number OperandNo of UserI (whose value is
// the instruction I) that are needed to produce the alive output bits AOut of
// UserI.
//
// AB is an in/out parameter. Its width is the operand's scalar bit width.
// For opcodes/operands not handled below it is left exactly as passed in;
// the caller (performAnalysis) passes all-ones, i.e. "every bit is live".
// When UserI does not produce an integer, the caller passes a
// default-constructed APInt as AOut, so AOut carries no information then.
//
// Known and Known2 are scratch storage owned by the caller and shared
// between the calls for the different operands of the same UserI. For And/Or
// they hold the known bits of operand 0 and operand 1 respectively.
void DemandedBits::determineLiveOperandBits(
    const Instruction *UserI, const Instruction *I, unsigned OperandNo,
    const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2) {
  unsigned BitWidth = AB.getBitWidth();

  // We're called once per operand, but for some instructions, we need to
  // compute known bits of both operands in order to determine the live bits of
  // either (when both operands are instructions themselves). We don't,
  // however, want to do this twice, so we cache the result in the KnownBits
  // objects Known/Known2 that live in the caller. For the two-relevant-operands
  // case, both operand values are provided here: V1's known bits go to Known,
  // V2's (if non-null) to Known2.
  auto ComputeKnownBits =
      [&](unsigned BitWidth, const Value *V1, const Value *V2) {
        const DataLayout &DL = I->getModule()->getDataLayout();
        Known = KnownBits(BitWidth);
        computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);

        if (V2) {
          Known2 = KnownBits(BitWidth);
          computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
        }
      };

  switch (UserI->getOpcode()) {
  default: break;
  case Instruction::Call:
  case Instruction::Invoke:
    // Only a few bit-manipulation intrinsics are modelled; any other call
    // keeps AB unchanged.
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI))
      switch (II->getIntrinsicID()) {
      default: break;
      case Intrinsic::bswap:
        // The alive bits of the input are the swapped alive bits of
        // the output.
        AB = AOut.byteSwap();
        break;
      case Intrinsic::bitreverse:
        // The alive bits of the input are the reversed alive bits of
        // the output.
        AB = AOut.reverseBits();
        break;
      case Intrinsic::ctlz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the left of, and including, the leftmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, I, nullptr);
          AB = APInt::getHighBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxLeadingZeros()+1));
        }
        break;
      case Intrinsic::cttz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the right of, and including, the rightmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, I, nullptr);
          AB = APInt::getLowBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxTrailingZeros()+1));
        }
        break;
      case Intrinsic::fshl:
      case Intrinsic::fshr: {
        // fshl(X, Y, Z) == (X << Z) | (Y >> (BW - Z)) with Z taken modulo BW;
        // fshr(X, Y, Z) is the same as fshl(X, Y, BW - Z).
        const APInt *SA;
        if (OperandNo == 2) {
          // Shift amount is modulo the bitwidth. For powers of two we have
          // SA % BW == SA & (BW - 1).
          if (isPowerOf2_32(BitWidth))
            AB = BitWidth - 1;
        } else if (match(II->getOperand(2), m_APInt(SA))) {
          // Normalize to funnel shift left. APInt shifts of BitWidth are well-
          // defined, so no need to special-case zero shifts here.
          uint64_t ShiftAmt = SA->urem(BitWidth);
          if (II->getIntrinsicID() == Intrinsic::fshr)
            ShiftAmt = BitWidth - ShiftAmt;

          // Output bit k comes from X bit (k - ShiftAmt) or from Y bit
          // (k + BW - ShiftAmt), so map the alive output bits back accordingly.
          if (OperandNo == 0)
            AB = AOut.lshr(ShiftAmt);
          else if (OperandNo == 1)
            AB = AOut.shl(BitWidth - ShiftAmt);
        }
        break;
      }
      }
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    // Find the highest live output bit. We don't need any more input
    // bits than that (adds, and thus subtracts, ripple only to the
    // left). The same holds for multiplication: bit k of a product depends
    // only on bits 0..k of its operands.
    AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
    break;
  case Instruction::Shl:
    // Only constant shift amounts are modelled; otherwise AB is unchanged.
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        // Input bit i lands in output bit i + ShiftAmt.
        AB = AOut.lshr(ShiftAmt);

        // If the shift is nuw/nsw, then the high bits are not dead
        // (because we've promised that they *must* be zero).
        // For nsw the shifted-out bits must also match the resulting sign
        // bit, hence the extra bit.
        const ShlOperator *S = cast<ShlOperator>(UserI);
        if (S->hasNoSignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
        else if (S->hasNoUnsignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::LShr:
    // Only constant shift amounts are modelled; otherwise AB is unchanged.
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        // Input bit i lands in output bit i - ShiftAmt.
        AB = AOut.shl(ShiftAmt);

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<LShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::AShr:
    // Only constant shift amounts are modelled; otherwise AB is unchanged.
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.shl(ShiftAmt);
        // Because the high input bit is replicated into the
        // high-order bits of the result, if we need any of those
        // bits, then we must keep the highest input bit.
        if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
            .getBoolValue())
          AB.setSignBit();

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<AShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::And:
    AB = AOut;

    // For bits that are known zero, the corresponding bits in the
    // other operand are dead (unless they're both zero, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    if (OperandNo == 0) {
      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
      AB &= ~Known2.Zero;
    } else {
      // If operand 0 is an instruction, Known/Known2 were already filled in
      // by the OperandNo == 0 call for this same UserI; otherwise compute
      // them now.
      if (!isa<Instruction>(UserI->getOperand(0)))
        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
      AB &= ~(Known.Zero & ~Known2.Zero);
    }
    break;
  case Instruction::Or:
    AB = AOut;

    // For bits that are known one, the corresponding bits in the
    // other operand are dead (unless they're both one, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    if (OperandNo == 0) {
      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
      AB &= ~Known2.One;
    } else {
      // Same caching scheme as for And above.
      if (!isa<Instruction>(UserI->getOperand(0)))
        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
      AB &= ~(Known.One & ~Known2.One);
    }
    break;
  case Instruction::Xor:
  case Instruction::PHI:
    // Bitwise: each output bit depends only on the same bit of the operands.
    AB = AOut;
    break;
  case Instruction::Trunc:
    // BitWidth is the (wider) source width; the dropped high bits are dead.
    AB = AOut.zext(BitWidth);
    break;
  case Instruction::ZExt:
    // BitWidth is the (narrower) source width.
    AB = AOut.trunc(BitWidth);
    break;
  case Instruction::SExt:
    AB = AOut.trunc(BitWidth);
    // Because the high input bit is replicated into the
    // high-order bits of the result, if we need any of those
    // bits, then we must keep the highest input bit.
    if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
                                      AOut.getBitWidth() - BitWidth))
        .getBoolValue())
      AB.setSignBit();
    break;
  case Instruction::Select:
    // The condition (operand 0) keeps AB unchanged; the chosen values pass
    // the alive bits straight through.
    if (OperandNo != 0)
      AB = AOut;
    break;
  case Instruction::ExtractElement:
    // The vector operand passes the alive bits through; the index operand
    // keeps AB unchanged.
    if (OperandNo == 0)
      AB = AOut;
    break;
  case Instruction::InsertElement:
  case Instruction::ShuffleVector:
    // Both vector/element operands pass the alive bits through; any further
    // operand keeps AB unchanged.
    if (OperandNo == 0 || OperandNo == 1)
      AB = AOut;
    break;
  }
}
297 
298 bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
299   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
300   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
301   DB.emplace(F, AC, DT);
302   return false;
303 }
304 
305 void DemandedBitsWrapperPass::releaseMemory() {
306   DB.reset();
307 }
308 
// Run the backward demanded-bits propagation over F, filling:
//  - AliveBits: for each reached integer-typed instruction, the bits of its
//    (scalar) value that are demanded;
//  - Visited:   reached instructions that are not integer-typed;
//  - DeadUses:  integer-typed uses for which no bits are demanded.
// Once Analyzed is set, later calls return immediately.
void DemandedBits::performAnalysis() {
  if (Analyzed)
    // Analysis already completed for this function.
    return;
  Analyzed = true;

  Visited.clear();
  AliveBits.clear();
  DeadUses.clear();

  SmallVector<Instruction*, 128> Worklist;

  // Collect the set of "root" instructions that are known live.
  for (Instruction &I : instructions(F)) {
    if (!isAlwaysLive(&I))
      continue;

    LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
    // For integer-valued instructions, set up an initial empty set of alive
    // bits and add the instruction to the work list. For other instructions
    // add their operands to the work list (for integer-valued operands, mark
    // all bits as live).
    Type *T = I.getType();
    if (T->isIntOrIntVectorTy()) {
      if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
        Worklist.push_back(&I);

      continue;
    }

    // Non-integer-typed instructions...
    for (Use &OI : I.operands()) {
      if (Instruction *J = dyn_cast<Instruction>(OI)) {
        Type *T = J->getType();
        if (T->isIntOrIntVectorTy())
          AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
        Worklist.push_back(J);
      }
    }
    // To save memory, we don't add I to the Visited set here. Instead, we
    // check isAlwaysLive on every instruction when searching for dead
    // instructions later (we need to check isAlwaysLive for the
    // integer-typed instructions anyway).
  }

  // Propagate liveness backwards to operands.
  while (!Worklist.empty()) {
    Instruction *UserI = Worklist.pop_back_val();

    LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
    // AOut stays a default-constructed APInt for non-integer users.
    APInt AOut;
    if (UserI->getType()->isIntOrIntVectorTy()) {
      AOut = AliveBits[UserI];
      // NOTE: getLimitedValue() saturates at UINT64_MAX, so this debug output
      // is only exact for masks that fit in 64 bits.
      LLVM_DEBUG(dbgs() << " Alive Out: 0x"
                        << Twine::utohexstr(AOut.getLimitedValue()));
    }
    LLVM_DEBUG(dbgs() << "\n");

    // Integer-typed instructions are tracked via AliveBits instead.
    if (!UserI->getType()->isIntOrIntVectorTy())
      Visited.insert(UserI);

    // Scratch known-bits cache shared by all operands of this UserI; see
    // determineLiveOperandBits.
    KnownBits Known, Known2;
    // Compute the set of alive bits for each operand. These are or'ed into the
    // existing set, if any, and if that changes the set of alive bits, the
    // operand is added to the work-list.
    for (Use &OI : UserI->operands()) {
      if (Instruction *I = dyn_cast<Instruction>(OI)) {
        Type *T = I->getType();
        if (T->isIntOrIntVectorTy()) {
          unsigned BitWidth = T->getScalarSizeInBits();
          APInt AB = APInt::getAllOnesValue(BitWidth);
          if (UserI->getType()->isIntOrIntVectorTy() && !AOut &&
              !isAlwaysLive(UserI)) {
            // If all bits of the output are dead, then all bits of the input
            // are also dead.
            AB = APInt(BitWidth, 0);
          } else {
            // Bits of each operand that are used to compute alive bits of the
            // output are alive, all others are dead.
            determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB,
                                     Known, Known2);

            // Keep track of uses which have no demanded bits.
            if (AB.isNullValue())
              DeadUses.insert(&OI);
            else
              DeadUses.erase(&OI);
          }

          // If we've added to the set of alive bits (or the operand has not
          // been previously visited), then re-queue the operand to be visited
          // again.
          APInt ABPrev(BitWidth, 0);
          auto ABI = AliveBits.find(I);
          if (ABI != AliveBits.end())
            ABPrev = ABI->second;

          APInt ABNew = AB | ABPrev;
          if (ABNew != ABPrev || ABI == AliveBits.end()) {
            AliveBits[I] = std::move(ABNew);
            Worklist.push_back(I);
          }
        } else if (!Visited.count(I)) {
          // Non-integer operands are visited once, to reach their own
          // operands.
          Worklist.push_back(I);
        }
      }
    }
  }
}
418 
419 APInt DemandedBits::getDemandedBits(Instruction *I) {
420   performAnalysis();
421 
422   auto Found = AliveBits.find(I);
423   if (Found != AliveBits.end())
424     return Found->second;
425 
426   const DataLayout &DL = I->getModule()->getDataLayout();
427   return APInt::getAllOnesValue(
428       DL.getTypeSizeInBits(I->getType()->getScalarType()));
429 }
430 
431 bool DemandedBits::isInstructionDead(Instruction *I) {
432   performAnalysis();
433 
434   return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
435     !isAlwaysLive(I);
436 }
437 
438 bool DemandedBits::isUseDead(Use *U) {
439   // We only track integer uses, everything else is assumed live.
440   if (!(*U)->getType()->isIntOrIntVectorTy())
441     return false;
442 
443   // Uses by always-live instructions are never dead.
444   Instruction *UserI = cast<Instruction>(U->getUser());
445   if (isAlwaysLive(UserI))
446     return false;
447 
448   performAnalysis();
449   if (DeadUses.count(U))
450     return true;
451 
452   // If no output bits are demanded, no input bits are demanded and the use
453   // is dead. These uses might not be explicitly present in the DeadUses map.
454   if (UserI->getType()->isIntOrIntVectorTy()) {
455     auto Found = AliveBits.find(UserI);
456     if (Found != AliveBits.end() && Found->second.isNullValue())
457       return true;
458   }
459 
460   return false;
461 }
462 
463 void DemandedBits::print(raw_ostream &OS) {
464   performAnalysis();
465   for (auto &KV : AliveBits) {
466     OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue())
467        << " for " << *KV.first << '\n';
468   }
469 }
470 
471 FunctionPass *llvm::createDemandedBitsWrapperPass() {
472   return new DemandedBitsWrapperPass();
473 }
474 
// Identity key for DemandedBitsAnalysis, used when its result is requested
// through a FunctionAnalysisManager (e.g. AM.getResult<DemandedBitsAnalysis>).
AnalysisKey DemandedBitsAnalysis::Key;
476 
477 DemandedBits DemandedBitsAnalysis::run(Function &F,
478                                              FunctionAnalysisManager &AM) {
479   auto &AC = AM.getResult<AssumptionAnalysis>(F);
480   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
481   return DemandedBits(F, AC, DT);
482 }
483 
484 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
485                                                FunctionAnalysisManager &AM) {
486   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
487   return PreservedAnalyses::all();
488 }
489