1 //===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains a pass that keeps track of @llvm.assume intrinsics in
11 // the functions of a module.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/AssumptionCache.h"
16 #include "llvm/IR/CallSite.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/Function.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/PassManager.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 using namespace llvm;
25 using namespace llvm::PatternMatch;
26 
27 SmallVector<WeakVH, 1> &AssumptionCache::getOrInsertAffectedValues(Value *V) {
28   // Try using find_as first to avoid creating extra value handles just for the
29   // purpose of doing the lookup.
30   auto AVI = AffectedValues.find_as(V);
31   if (AVI != AffectedValues.end())
32     return AVI->second;
33 
34   auto AVIP = AffectedValues.insert({
35       AffectedValueCallbackVH(V, this), SmallVector<WeakVH, 1>()});
36   return AVIP.first->second;
37 }
38 
/// Record the @llvm.assume call \p CI in the AffectedValues list of every
/// value its condition mentions:
///  - the condition itself;
///  - both operands of an icmp condition;
///  - for an `icmp eq`, selected operands of those operands: through a 'not',
///    both sides of and/or/xor, and the shifted value of a shift by a
///    constant.
/// Through AddAffected, the operand of a bitcast/ptrtoint/'not' instruction
/// is also recorded. Only arguments and instructions are recorded.
void AssumptionCache::updateAffectedValues(CallInst *CI) {
  // Note: This code must be kept in-sync with the code in
  // computeKnownBitsFromAssume in ValueTracking.

  // Values to associate with CI. This may contain duplicates; the final loop
  // avoids adding CI to the same list twice.
  SmallVector<Value *, 16> Affected;
  // Append V if it is an argument or instruction. For an instruction that is
  // a bitcast, ptrtoint or 'not', its operand is appended too, provided the
  // operand is an argument or instruction.
  auto AddAffected = [&Affected](Value *V) {
    if (isa<Argument>(V)) {
      Affected.push_back(V);
    } else if (auto *I = dyn_cast<Instruction>(V)) {
      Affected.push_back(I);

      // Peek through unary operators to find the source of the condition.
      Value *Op;
      if (match(I, m_BitCast(m_Value(Op))) ||
          match(I, m_PtrToInt(m_Value(Op))) ||
          match(I, m_Not(m_Value(Op)))) {
        if (isa<Instruction>(Op) || isa<Argument>(Op))
          Affected.push_back(Op);
      }
    }
  };

  // The assumed condition is the call's first argument. A and B are only
  // bound (and read) if the icmp match below succeeds.
  Value *Cond = CI->getArgOperand(0), *A, *B;
  AddAffected(Cond);

  CmpInst::Predicate Pred;
  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
    AddAffected(A);
    AddAffected(B);

    if (Pred == ICmpInst::ICMP_EQ) {
      // For equality comparisons, we handle the case of bit inversion.
      // Note: this lambda's local A and B shadow the outer icmp operands.
      auto AddAffectedFromEq = [&AddAffected](Value *V) {
        Value *A;
        // Look through a 'not', recording its operand and continuing the
        // match on it.
        if (match(V, m_Not(m_Value(A)))) {
          AddAffected(A);
          V = A;
        }

        Value *B;
        ConstantInt *C;
        // (A & B) or (A | B) or (A ^ B).
        if (match(V,
                  m_CombineOr(m_And(m_Value(A), m_Value(B)),
                    m_CombineOr(m_Or(m_Value(A), m_Value(B)),
                                m_Xor(m_Value(A), m_Value(B)))))) {
          AddAffected(A);
          AddAffected(B);
        // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
        } else if (match(V,
                         m_CombineOr(m_Shl(m_Value(A), m_ConstantInt(C)),
                           m_CombineOr(m_LShr(m_Value(A), m_ConstantInt(C)),
                                       m_AShr(m_Value(A),
                                              m_ConstantInt(C)))))) {
          AddAffected(A);
        }
      };

      AddAffectedFromEq(A);
      AddAffectedFromEq(B);
    }
  }

  // Attach CI to each affected value's list, skipping it if already present.
  // The AVV reference is used before the next getOrInsertAffectedValues call,
  // which may grow the map.
  for (auto &AV : Affected) {
    auto &AVV = getOrInsertAffectedValues(AV);
    if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
      AVV.push_back(CI);
  }
}
108 
109 void AssumptionCache::AffectedValueCallbackVH::deleted() {
110   auto AVI = AC->AffectedValues.find(getValPtr());
111   if (AVI != AC->AffectedValues.end())
112     AC->AffectedValues.erase(AVI);
113   // 'this' now dangles!
114 }
115 
116 void AssumptionCache::copyAffectedValuesInCache(Value *OV, Value *NV) {
117   auto &NAVV = getOrInsertAffectedValues(NV);
118   auto AVI = AffectedValues.find(OV);
119   if (AVI == AffectedValues.end())
120     return;
121 
122   for (auto &A : AVI->second)
123     if (std::find(NAVV.begin(), NAVV.end(), A) == NAVV.end())
124       NAVV.push_back(A);
125 }
126 
127 void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
128   if (!isa<Instruction>(NV) && !isa<Argument>(NV))
129     return;
130 
131   // Any assumptions that affected this value now affect the new value.
132 
133   AC->copyAffectedValuesInCache(getValPtr(), NV);
134   // 'this' now might dangle! If the AffectedValues map was resized to add an
135   // entry for NV then this object might have been destroyed in favor of some
136   // copy in the grown map.
137 }
138 
139 void AssumptionCache::scanFunction() {
140   assert(!Scanned && "Tried to scan the function twice!");
141   assert(AssumeHandles.empty() && "Already have assumes when scanning!");
142 
143   // Go through all instructions in all blocks, add all calls to @llvm.assume
144   // to this cache.
145   for (BasicBlock &B : F)
146     for (Instruction &II : B)
147       if (match(&II, m_Intrinsic<Intrinsic::assume>()))
148         AssumeHandles.push_back(&II);
149 
150   // Mark the scan as complete.
151   Scanned = true;
152 
153   // Update affected values.
154   for (auto &A : AssumeHandles)
155     updateAffectedValues(cast<CallInst>(A));
156 }
157 
158 void AssumptionCache::registerAssumption(CallInst *CI) {
159   assert(match(CI, m_Intrinsic<Intrinsic::assume>()) &&
160          "Registered call does not call @llvm.assume");
161 
162   // If we haven't scanned the function yet, just drop this assumption. It will
163   // be found when we scan later.
164   if (!Scanned)
165     return;
166 
167   AssumeHandles.push_back(CI);
168 
169 #ifndef NDEBUG
170   assert(CI->getParent() &&
171          "Cannot register @llvm.assume call not in a basic block");
172   assert(&F == CI->getParent()->getParent() &&
173          "Cannot register @llvm.assume call not in this function");
174 
175   // We expect the number of assumptions to be small, so in an asserts build
176   // check that we don't accumulate duplicates and that all assumptions point
177   // to the same function.
178   SmallPtrSet<Value *, 16> AssumptionSet;
179   for (auto &VH : AssumeHandles) {
180     if (!VH)
181       continue;
182 
183     assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
184            "Cached assumption not inside this function!");
185     assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
186            "Cached something other than a call to @llvm.assume!");
187     assert(AssumptionSet.insert(VH).second &&
188            "Cache contains multiple copies of a call!");
189   }
190 #endif
191 
192   updateAffectedValues(CI);
193 }
194 
// Out-of-line definition of AssumptionAnalysis's static AnalysisKey. It
// identifies the analysis when its result is requested from an analysis
// manager, e.g. AM.getResult<AssumptionAnalysis>(F).
AnalysisKey AssumptionAnalysis::Key;
196 
197 PreservedAnalyses AssumptionPrinterPass::run(Function &F,
198                                              FunctionAnalysisManager &AM) {
199   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
200 
201   OS << "Cached assumptions for function: " << F.getName() << "\n";
202   for (auto &VH : AC.assumptions())
203     if (VH)
204       OS << "  " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
205 
206   return PreservedAnalyses::all();
207 }
208 
209 void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
210   auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr()));
211   if (I != ACT->AssumptionCaches.end())
212     ACT->AssumptionCaches.erase(I);
213   // 'this' now dangles!
214 }
215 
216 AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
217   // We probe the function map twice to try and avoid creating a value handle
218   // around the function in common cases. This makes insertion a bit slower,
219   // but if we have to insert we're going to scan the whole function so that
220   // shouldn't matter.
221   auto I = AssumptionCaches.find_as(&F);
222   if (I != AssumptionCaches.end())
223     return *I->second;
224 
225   // Ok, build a new cache by scanning the function, insert it and the value
226   // handle into our map, and return the newly populated cache.
227   auto IP = AssumptionCaches.insert(std::make_pair(
228       FunctionCallbackVH(&F, this), llvm::make_unique<AssumptionCache>(F)));
229   assert(IP.second && "Scanning function already in the map?");
230   return *IP.first->second;
231 }
232 
233 void AssumptionCacheTracker::verifyAnalysis() const {
234 #ifndef NDEBUG
235   SmallPtrSet<const CallInst *, 4> AssumptionSet;
236   for (const auto &I : AssumptionCaches) {
237     for (auto &VH : I.second->assumptions())
238       if (VH)
239         AssumptionSet.insert(cast<CallInst>(VH));
240 
241     for (const BasicBlock &B : cast<Function>(*I.first))
242       for (const Instruction &II : B)
243         if (match(&II, m_Intrinsic<Intrinsic::assume>()))
244           assert(AssumptionSet.count(cast<CallInst>(&II)) &&
245                  "Assumption in scanned function not in cache");
246   }
247 #endif
248 }
249 
250 AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {
251   initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry());
252 }
253 
254 AssumptionCacheTracker::~AssumptionCacheTracker() {}
255 
// Register AssumptionCacheTracker with the legacy pass registry under the
// command-line name "assumption-cache-tracker". The trailing false/true are
// presumably INITIALIZE_PASS's cfg-only / is-analysis flags; confirm against
// PassSupport.h.
INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
                "Assumption Cache Tracker", false, true)
// Pass identifier handed to ImmutablePass(ID) by the constructor. Presumably
// only its address is significant (legacy pass manager convention).
char AssumptionCacheTracker::ID = 0;
259