1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass implements a demanded bits analysis. A demanded bit is one that
11 // contributes to a result; bits that are not demanded can be either zero or
12 // one without affecting control or data flow. For example in this sequence:
13 //
14 //   %1 = add i32 %x, %y
15 //   %2 = trunc i32 %1 to i16
16 //
17 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
18 // trunc.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "llvm/Analysis/DemandedBits.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/Analysis/AssumptionCache.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/IR/BasicBlock.h"
30 #include "llvm/IR/Constants.h"
31 #include "llvm/IR/DataLayout.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/Dominators.h"
34 #include "llvm/IR/InstIterator.h"
35 #include "llvm/IR/InstrTypes.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/IntrinsicInst.h"
38 #include "llvm/IR/Intrinsics.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Operator.h"
41 #include "llvm/IR/PassManager.h"
42 #include "llvm/IR/PatternMatch.h"
43 #include "llvm/IR/Type.h"
44 #include "llvm/IR/Use.h"
45 #include "llvm/Pass.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/KnownBits.h"
49 #include "llvm/Support/raw_ostream.h"
50 #include <algorithm>
51 #include <cstdint>
52 
53 using namespace llvm;
54 using namespace llvm::PatternMatch;
55 
56 #define DEBUG_TYPE "demanded-bits"
57 
// Pass identifier handed to FunctionPass(ID) by the legacy wrapper's
// constructor.
char DemandedBitsWrapperPass::ID = 0;

// Register the legacy wrapper pass under the name "demanded-bits", declaring
// the two analyses it depends on (the same ones requested in
// getAnalysisUsage).
INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
                      "Demanded bits analysis", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
                    "Demanded bits analysis", false, false)
66 
67 DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
68   initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry());
69 }
70 
71 void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
72   AU.setPreservesCFG();
73   AU.addRequired<AssumptionCacheTracker>();
74   AU.addRequired<DominatorTreeWrapperPass>();
75   AU.setPreservesAll();
76 }
77 
78 void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
79   DB->print(OS);
80 }
81 
82 static bool isAlwaysLive(Instruction *I) {
83   return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
84          I->mayHaveSideEffects();
85 }
86 
// Compute AB, the alive bits of operand number OperandNo of UserI (whose value
// is the instruction I), given AOut, the alive bits of UserI's result.
// AB arrives holding a conservative value (performAnalysis passes all-ones).
// Cases below that can prove bits dead overwrite or narrow it; every other
// opcode/operand combination leaves it unchanged. BitWidth is the scalar width
// of the operand, which can differ from AOut's width (e.g. trunc/zext/sext).
// Known and Known2 are scratch known-bits results owned by the caller and
// shared across the per-operand calls made for the same UserI.
void DemandedBits::determineLiveOperandBits(
    const Instruction *UserI, const Instruction *I, unsigned OperandNo,
    const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2) {
  unsigned BitWidth = AB.getBitWidth();

  // We're called once per operand, but for some instructions, we need to
  // compute known bits of both operands in order to determine the live bits of
  // either (when both operands are instructions themselves). We don't,
  // however, want to do this twice, so we cache the result in the KnownBits
  // objects (Known, Known2) that live in the caller. For the
  // two-relevant-operands case, both operand values are provided here.
  auto ComputeKnownBits =
      [&](unsigned BitWidth, const Value *V1, const Value *V2) {
        const DataLayout &DL = I->getModule()->getDataLayout();
        Known = KnownBits(BitWidth);
        computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);

        if (V2) {
          Known2 = KnownBits(BitWidth);
          computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
        }
      };

  switch (UserI->getOpcode()) {
  default: break;
  case Instruction::Call:
  case Instruction::Invoke:
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI))
      switch (II->getIntrinsicID()) {
      default: break;
      case Intrinsic::bswap:
        // The alive bits of the input are the swapped alive bits of
        // the output.
        AB = AOut.byteSwap();
        break;
      case Intrinsic::bitreverse:
        // The alive bits of the input are the reversed alive bits of
        // the output.
        AB = AOut.reverseBits();
        break;
      case Intrinsic::ctlz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the left of, and including, the leftmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, I, nullptr);
          AB = APInt::getHighBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxLeadingZeros()+1));
        }
        break;
      case Intrinsic::cttz:
        if (OperandNo == 0) {
          // We need some output bits, so we need all bits of the
          // input to the right of, and including, the rightmost bit
          // known to be one.
          ComputeKnownBits(BitWidth, I, nullptr);
          AB = APInt::getLowBitsSet(BitWidth,
                 std::min(BitWidth, Known.countMaxTrailingZeros()+1));
        }
        break;
      case Intrinsic::fshl:
      case Intrinsic::fshr: {
        const APInt *SA;
        if (OperandNo == 2) {
          // Shift amount is modulo the bitwidth. For powers of two we have
          // SA % BW == SA & (BW - 1).
          if (isPowerOf2_32(BitWidth))
            AB = BitWidth - 1;
        } else if (match(II->getOperand(2), m_APInt(SA))) {
          // Normalize to funnel shift left. APInt shifts of BitWidth are well-
          // defined, so no need to special-case zero shifts here.
          uint64_t ShiftAmt = SA->urem(BitWidth);
          if (II->getIntrinsicID() == Intrinsic::fshr)
            ShiftAmt = BitWidth - ShiftAmt;

          if (OperandNo == 0)
            AB = AOut.lshr(ShiftAmt);
          else if (OperandNo == 1)
            AB = AOut.shl(BitWidth - ShiftAmt);
        }
        break;
      }
      }
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    // Find the highest live output bit. We don't need any more input
    // bits than that (adds, and thus subtracts, ripple only to the
    // left; the low N bits of a product likewise depend only on the low
    // N bits of its operands).
    AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
    break;
  case Instruction::Shl:
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.lshr(ShiftAmt);

        // If the shift is nuw/nsw, then the high bits are not dead
        // (because we've promised that they *must* be zero).
        const ShlOperator *S = cast<ShlOperator>(UserI);
        if (S->hasNoSignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
        else if (S->hasNoUnsignedWrap())
          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::LShr:
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.shl(ShiftAmt);

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<LShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::AShr:
    if (OperandNo == 0) {
      const APInt *ShiftAmtC;
      if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
        AB = AOut.shl(ShiftAmt);
        // Because the high input bit is replicated into the
        // high-order bits of the result, if we need any of those
        // bits, then we must keep the highest input bit.
        if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
            .getBoolValue())
          AB.setSignBit();

        // If the shift is exact, then the low bits are not dead
        // (they must be zero).
        if (cast<AShrOperator>(UserI)->isExact())
          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
      }
    }
    break;
  case Instruction::And:
    AB = AOut;

    // For bits that are known zero, the corresponding bits in the
    // other operand are dead (unless they're both zero, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    if (OperandNo == 0) {
      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
      AB &= ~Known2.Zero;
    } else {
      // When operand 0 is an instruction, this relies on the caller having
      // already made the OperandNo == 0 call for this same UserI, which
      // filled Known/Known2; otherwise compute both here.
      if (!isa<Instruction>(UserI->getOperand(0)))
        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
      AB &= ~(Known.Zero & ~Known2.Zero);
    }
    break;
  case Instruction::Or:
    AB = AOut;

    // For bits that are known one, the corresponding bits in the
    // other operand are dead (unless they're both one, in which
    // case they can't both be dead, so just mark the LHS bits as
    // dead).
    if (OperandNo == 0) {
      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
      AB &= ~Known2.One;
    } else {
      // Same caching contract as the And case above.
      if (!isa<Instruction>(UserI->getOperand(0)))
        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
      AB &= ~(Known.One & ~Known2.One);
    }
    break;
  case Instruction::Xor:
  case Instruction::PHI:
    AB = AOut;
    break;
  case Instruction::Trunc:
    AB = AOut.zext(BitWidth);
    break;
  case Instruction::ZExt:
    AB = AOut.trunc(BitWidth);
    break;
  case Instruction::SExt:
    AB = AOut.trunc(BitWidth);
    // Because the high input bit is replicated into the
    // high-order bits of the result, if we need any of those
    // bits, then we must keep the highest input bit.
    if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
                                      AOut.getBitWidth() - BitWidth))
        .getBoolValue())
      AB.setSignBit();
    break;
  case Instruction::Select:
    // The condition (operand 0) keeps the conservative value of AB.
    if (OperandNo != 0)
      AB = AOut;
    break;
  case Instruction::ExtractElement:
    // The index operand keeps the conservative value of AB.
    if (OperandNo == 0)
      AB = AOut;
    break;
  case Instruction::InsertElement:
  case Instruction::ShuffleVector:
    if (OperandNo == 0 || OperandNo == 1)
      AB = AOut;
    break;
  }
}
297 
298 bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
299   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
300   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
301   DB.emplace(F, AC, DT);
302   return false;
303 }
304 
305 void DemandedBitsWrapperPass::releaseMemory() {
306   DB.reset();
307 }
308 
309 void DemandedBits::performAnalysis() {
310   if (Analyzed)
311     // Analysis already completed for this function.
312     return;
313   Analyzed = true;
314 
315   Visited.clear();
316   AliveBits.clear();
317 
318   SmallVector<Instruction*, 128> Worklist;
319 
320   // Collect the set of "root" instructions that are known live.
321   for (Instruction &I : instructions(F)) {
322     if (!isAlwaysLive(&I))
323       continue;
324 
325     LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
326     // For integer-valued instructions, set up an initial empty set of alive
327     // bits and add the instruction to the work list. For other instructions
328     // add their operands to the work list (for integer values operands, mark
329     // all bits as live).
330     Type *T = I.getType();
331     if (T->isIntOrIntVectorTy()) {
332       if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
333         Worklist.push_back(&I);
334 
335       continue;
336     }
337 
338     // Non-integer-typed instructions...
339     for (Use &OI : I.operands()) {
340       if (Instruction *J = dyn_cast<Instruction>(OI)) {
341         Type *T = J->getType();
342         if (T->isIntOrIntVectorTy())
343           AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
344         Worklist.push_back(J);
345       }
346     }
347     // To save memory, we don't add I to the Visited set here. Instead, we
348     // check isAlwaysLive on every instruction when searching for dead
349     // instructions later (we need to check isAlwaysLive for the
350     // integer-typed instructions anyway).
351   }
352 
353   // Propagate liveness backwards to operands.
354   while (!Worklist.empty()) {
355     Instruction *UserI = Worklist.pop_back_val();
356 
357     LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
358     APInt AOut;
359     if (UserI->getType()->isIntOrIntVectorTy()) {
360       AOut = AliveBits[UserI];
361       LLVM_DEBUG(dbgs() << " Alive Out: " << AOut);
362     }
363     LLVM_DEBUG(dbgs() << "\n");
364 
365     if (!UserI->getType()->isIntOrIntVectorTy())
366       Visited.insert(UserI);
367 
368     KnownBits Known, Known2;
369     // Compute the set of alive bits for each operand. These are anded into the
370     // existing set, if any, and if that changes the set of alive bits, the
371     // operand is added to the work-list.
372     for (Use &OI : UserI->operands()) {
373       if (Instruction *I = dyn_cast<Instruction>(OI)) {
374         Type *T = I->getType();
375         if (T->isIntOrIntVectorTy()) {
376           unsigned BitWidth = T->getScalarSizeInBits();
377           APInt AB = APInt::getAllOnesValue(BitWidth);
378           if (UserI->getType()->isIntOrIntVectorTy() && !AOut &&
379               !isAlwaysLive(UserI)) {
380             AB = APInt(BitWidth, 0);
381           } else {
382             // If all bits of the output are dead, then all bits of the input
383             // Bits of each operand that are used to compute alive bits of the
384             // output are alive, all others are dead.
385             determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB,
386                                      Known, Known2);
387           }
388 
389           // If we've added to the set of alive bits (or the operand has not
390           // been previously visited), then re-queue the operand to be visited
391           // again.
392           APInt ABPrev(BitWidth, 0);
393           auto ABI = AliveBits.find(I);
394           if (ABI != AliveBits.end())
395             ABPrev = ABI->second;
396 
397           APInt ABNew = AB | ABPrev;
398           if (ABNew != ABPrev || ABI == AliveBits.end()) {
399             AliveBits[I] = std::move(ABNew);
400             Worklist.push_back(I);
401           }
402         } else if (!Visited.count(I)) {
403           Worklist.push_back(I);
404         }
405       }
406     }
407   }
408 }
409 
410 APInt DemandedBits::getDemandedBits(Instruction *I) {
411   performAnalysis();
412 
413   auto Found = AliveBits.find(I);
414   if (Found != AliveBits.end())
415     return Found->second;
416 
417   const DataLayout &DL = I->getModule()->getDataLayout();
418   return APInt::getAllOnesValue(
419       DL.getTypeSizeInBits(I->getType()->getScalarType()));
420 }
421 
422 bool DemandedBits::isInstructionDead(Instruction *I) {
423   performAnalysis();
424 
425   return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
426     !isAlwaysLive(I);
427 }
428 
429 void DemandedBits::print(raw_ostream &OS) {
430   performAnalysis();
431   for (auto &KV : AliveBits) {
432     OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue())
433        << " for " << *KV.first << '\n';
434   }
435 }
436 
437 FunctionPass *llvm::createDemandedBitsWrapperPass() {
438   return new DemandedBitsWrapperPass();
439 }
440 
441 AnalysisKey DemandedBitsAnalysis::Key;
442 
443 DemandedBits DemandedBitsAnalysis::run(Function &F,
444                                              FunctionAnalysisManager &AM) {
445   auto &AC = AM.getResult<AssumptionAnalysis>(F);
446   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
447   return DemandedBits(F, AC, DT);
448 }
449 
450 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
451                                                FunctionAnalysisManager &AM) {
452   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
453   return PreservedAnalyses::all();
454 }
455