1 //===- CalledValuePropagation.cpp - Propagate called values -----*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a transformation that attaches !callees metadata to
11 // indirect call sites. For a given call site, the metadata, if present,
12 // indicates the set of functions the call site could possibly target at
13 // run-time. This metadata is added to indirect call sites when the set of
14 // possible targets can be determined by analysis and is known to be small. The
15 // analysis driving the transformation is similar to constant propagation and
// makes use of the generic sparse propagation solver.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/CalledValuePropagation.h"
21 #include "llvm/Analysis/SparsePropagation.h"
22 #include "llvm/Analysis/ValueLatticeUtils.h"
23 #include "llvm/IR/InstVisitor.h"
24 #include "llvm/IR/MDBuilder.h"
25 #include "llvm/Transforms/IPO.h"
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "called-value-propagation"
29 
/// The maximum number of functions to track per lattice value. Once the number
/// of functions a call site can possibly target exceeds this threshold, its
/// lattice value becomes overdefined. The number of possible lattice values is
/// bounded by Ch(F, M), where F is the number of functions in the module and M
/// is MaxFunctionsPerValue. As such, this value should be kept very small. We
/// likely can't do anything useful for call sites with a large number of
/// possible targets, anyway.
///
/// The limit is exposed as the hidden command-line option
/// "cvp-max-functions-per-value" and is initialized to 4.
static cl::opt<unsigned> MaxFunctionsPerValue(
    "cvp-max-functions-per-value", cl::Hidden, cl::init(4),
    cl::desc("The maximum number of functions to track per lattice value"));
40 
41 namespace {
/// To enable interprocedural analysis, we assign LLVM values to the following
/// groups:
///   - Register: SSA registers.
///   - Return:   the return values of functions.
///   - Memory:   in-memory values.
/// An LLVM Value can technically be in more than one group. It's necessary to
/// distinguish these groups so we can, for example, track a global variable
/// separately from the value stored at its location.
enum class IPOGrouping { Register, Return, Memory };
49 
/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
/// The pointer half identifies the LLVM Value and the integer half (two bits
/// wide, which is enough for the three IPOGrouping enumerators) records the
/// group under which that Value is being tracked.
using CVPLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
52 
53 /// The lattice value type used by our custom lattice function. It holds the
54 /// lattice state, and a set of functions.
55 class CVPLatticeVal {
56 public:
57   /// The states of the lattice values. Only the FunctionSet state is
58   /// interesting. It indicates the set of functions to which an LLVM value may
59   /// refer.
60   enum CVPLatticeStateTy { Undefined, FunctionSet, Overdefined, Untracked };
61 
62   /// Comparator for sorting the functions set. We want to keep the order
63   /// deterministic for testing, etc.
64   struct Compare {
65     bool operator()(const Function *LHS, const Function *RHS) const {
66       return LHS->getName() < RHS->getName();
67     }
68   };
69 
70   CVPLatticeVal() : LatticeState(Undefined) {}
71   CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {}
72   CVPLatticeVal(std::set<Function *, Compare> &&Functions)
73       : LatticeState(FunctionSet), Functions(Functions) {}
74 
75   /// Get a reference to the functions held by this lattice value. The number
76   /// of functions will be zero for states other than FunctionSet.
77   const std::set<Function *, Compare> &getFunctions() const {
78     return Functions;
79   }
80 
81   /// Returns true if the lattice value is in the FunctionSet state.
82   bool isFunctionSet() const { return LatticeState == FunctionSet; }
83 
84   bool operator==(const CVPLatticeVal &RHS) const {
85     return LatticeState == RHS.LatticeState && Functions == RHS.Functions;
86   }
87 
88   bool operator!=(const CVPLatticeVal &RHS) const {
89     return LatticeState != RHS.LatticeState || Functions != RHS.Functions;
90   }
91 
92 private:
93   /// Holds the state this lattice value is in.
94   CVPLatticeStateTy LatticeState;
95 
96   /// Holds functions indicating the possible targets of call sites. This set
97   /// is empty for lattice values in the undefined, overdefined, and untracked
98   /// states. The maximum size of the set is controlled by
99   /// MaxFunctionsPerValue. Since most LLVM values are expected to be in
100   /// uninteresting states (i.e., overdefined), CVPLatticeVal objects should be
101   /// small and efficiently copyable.
102   std::set<Function *, Compare> Functions;
103 };
104 
105 /// The custom lattice function used by the generic sparse propagation solver.
106 /// It handles merging lattice values and computing new lattice values for
107 /// constants, arguments, values returned from trackable functions, and values
108 /// located in trackable global variables. It also computes the lattice values
109 /// that change as a result of executing instructions.
110 class CVPLatticeFunc
111     : public AbstractLatticeFunction<CVPLatticeKey, CVPLatticeVal> {
112 public:
113   CVPLatticeFunc()
114       : AbstractLatticeFunction(CVPLatticeVal(CVPLatticeVal::Undefined),
115                                 CVPLatticeVal(CVPLatticeVal::Overdefined),
116                                 CVPLatticeVal(CVPLatticeVal::Untracked)) {}
117 
118   /// Compute and return a CVPLatticeVal for the given CVPLatticeKey.
119   CVPLatticeVal ComputeLatticeVal(CVPLatticeKey Key) override {
120     switch (Key.getInt()) {
121     case IPOGrouping::Register:
122       if (isa<Instruction>(Key.getPointer())) {
123         return getUndefVal();
124       } else if (auto *A = dyn_cast<Argument>(Key.getPointer())) {
125         if (canTrackArgumentsInterprocedurally(A->getParent()))
126           return getUndefVal();
127       } else if (auto *C = dyn_cast<Constant>(Key.getPointer())) {
128         return computeConstant(C);
129       }
130       return getOverdefinedVal();
131     case IPOGrouping::Memory:
132     case IPOGrouping::Return:
133       if (auto *GV = dyn_cast<GlobalVariable>(Key.getPointer())) {
134         if (canTrackGlobalVariableInterprocedurally(GV))
135           return computeConstant(GV->getInitializer());
136       } else if (auto *F = cast<Function>(Key.getPointer()))
137         if (canTrackReturnsInterprocedurally(F))
138           return getUndefVal();
139     }
140     return getOverdefinedVal();
141   }
142 
143   /// Merge the two given lattice values. The interesting cases are merging two
144   /// FunctionSet values and a FunctionSet value with an Undefined value. For
145   /// these cases, we simply union the function sets. If the size of the union
146   /// is greater than the maximum functions we track, the merged value is
147   /// overdefined.
148   CVPLatticeVal MergeValues(CVPLatticeVal X, CVPLatticeVal Y) override {
149     if (X == getOverdefinedVal() || Y == getOverdefinedVal())
150       return getOverdefinedVal();
151     if (X == getUndefVal() && Y == getUndefVal())
152       return getUndefVal();
153     std::set<Function *, CVPLatticeVal::Compare> Union;
154     std::set_union(X.getFunctions().begin(), X.getFunctions().end(),
155                    Y.getFunctions().begin(), Y.getFunctions().end(),
156                    std::inserter(Union, Union.begin()),
157                    CVPLatticeVal::Compare{});
158     if (Union.size() > MaxFunctionsPerValue)
159       return getOverdefinedVal();
160     return CVPLatticeVal(std::move(Union));
161   }
162 
163   /// Compute the lattice values that change as a result of executing the given
164   /// instruction. The changed values are stored in \p ChangedValues. We handle
165   /// just a few kinds of instructions since we're only propagating values that
166   /// can be called.
167   void ComputeInstructionState(
168       Instruction &I, DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
169       SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) override {
170     switch (I.getOpcode()) {
171     case Instruction::Call:
172       return visitCallSite(cast<CallInst>(&I), ChangedValues, SS);
173     case Instruction::Invoke:
174       return visitCallSite(cast<InvokeInst>(&I), ChangedValues, SS);
175     case Instruction::Load:
176       return visitLoad(*cast<LoadInst>(&I), ChangedValues, SS);
177     case Instruction::Ret:
178       return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
179     case Instruction::Select:
180       return visitSelect(*cast<SelectInst>(&I), ChangedValues, SS);
181     case Instruction::Store:
182       return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
183     default:
184       return visitInst(I, ChangedValues, SS);
185     }
186   }
187 
188   /// Print the given CVPLatticeVal to the specified stream.
189   void PrintLatticeVal(CVPLatticeVal LV, raw_ostream &OS) override {
190     if (LV == getUndefVal())
191       OS << "Undefined  ";
192     else if (LV == getOverdefinedVal())
193       OS << "Overdefined";
194     else if (LV == getUntrackedVal())
195       OS << "Untracked  ";
196     else
197       OS << "FunctionSet";
198   }
199 
200   /// Print the given CVPLatticeKey to the specified stream.
201   void PrintLatticeKey(CVPLatticeKey Key, raw_ostream &OS) override {
202     if (Key.getInt() == IPOGrouping::Register)
203       OS << "<reg> ";
204     else if (Key.getInt() == IPOGrouping::Memory)
205       OS << "<mem> ";
206     else if (Key.getInt() == IPOGrouping::Return)
207       OS << "<ret> ";
208     if (isa<Function>(Key.getPointer()))
209       OS << Key.getPointer()->getName();
210     else
211       OS << *Key.getPointer();
212   }
213 
214   /// We collect a set of indirect calls when visiting call sites. This method
215   /// returns a reference to that set.
216   SmallPtrSetImpl<Instruction *> &getIndirectCalls() { return IndirectCalls; }
217 
218 private:
219   /// Holds the indirect calls we encounter during the analysis. We will attach
220   /// metadata to these calls after the analysis indicating the functions the
221   /// calls can possibly target.
222   SmallPtrSet<Instruction *, 32> IndirectCalls;
223 
224   /// Compute a new lattice value for the given constant. The constant, after
225   /// stripping any pointer casts, should be a Function. We ignore null
226   /// pointers as an optimization, since calling these values is undefined
227   /// behavior.
228   CVPLatticeVal computeConstant(Constant *C) {
229     if (isa<ConstantPointerNull>(C))
230       return CVPLatticeVal(CVPLatticeVal::FunctionSet);
231     if (auto *F = dyn_cast<Function>(C->stripPointerCasts()))
232       return CVPLatticeVal({F});
233     return getOverdefinedVal();
234   }
235 
236   /// Handle return instructions. The function's return state is the merge of
237   /// the returned value state and the function's return state.
238   void visitReturn(ReturnInst &I,
239                    DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
240                    SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
241     Function *F = I.getParent()->getParent();
242     if (F->getReturnType()->isVoidTy())
243       return;
244     auto RegI = CVPLatticeKey(I.getReturnValue(), IPOGrouping::Register);
245     auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
246     ChangedValues[RetF] =
247         MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
248   }
249 
250   /// Handle call sites. The state of a called function's formal arguments is
251   /// the merge of the argument state with the call sites corresponding actual
252   /// argument state. The call site state is the merge of the call site state
253   /// with the returned value state of the called function.
254   void visitCallSite(CallSite CS,
255                      DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
256                      SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
257     Function *F = CS.getCalledFunction();
258     Instruction *I = CS.getInstruction();
259     auto RegI = CVPLatticeKey(I, IPOGrouping::Register);
260 
261     // If this is an indirect call, save it so we can quickly revisit it when
262     // attaching metadata.
263     if (!F)
264       IndirectCalls.insert(I);
265 
266     // If we can't track the function's return values, there's nothing to do.
267     if (!F || !canTrackReturnsInterprocedurally(F)) {
268       ChangedValues[RegI] = getOverdefinedVal();
269       return;
270     }
271 
272     // Inform the solver that the called function is executable, and perform
273     // the merges for the arguments and return value.
274     SS.MarkBlockExecutable(&F->front());
275     auto RetF = CVPLatticeKey(F, IPOGrouping::Return);
276     for (Argument &A : F->args()) {
277       auto RegFormal = CVPLatticeKey(&A, IPOGrouping::Register);
278       auto RegActual =
279           CVPLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register);
280       ChangedValues[RegFormal] =
281           MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
282     }
283     ChangedValues[RegI] =
284         MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
285   }
286 
287   /// Handle select instructions. The select instruction state is the merge the
288   /// true and false value states.
289   void visitSelect(SelectInst &I,
290                    DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
291                    SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
292     auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
293     auto RegT = CVPLatticeKey(I.getTrueValue(), IPOGrouping::Register);
294     auto RegF = CVPLatticeKey(I.getFalseValue(), IPOGrouping::Register);
295     ChangedValues[RegI] =
296         MergeValues(SS.getValueState(RegT), SS.getValueState(RegF));
297   }
298 
299   /// Handle load instructions. If the pointer operand of the load is a global
300   /// variable, we attempt to track the value. The loaded value state is the
301   /// merge of the loaded value state with the global variable state.
302   void visitLoad(LoadInst &I,
303                  DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
304                  SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
305     auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
306     if (auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand())) {
307       auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
308       ChangedValues[RegI] =
309           MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
310     } else {
311       ChangedValues[RegI] = getOverdefinedVal();
312     }
313   }
314 
315   /// Handle store instructions. If the pointer operand of the store is a
316   /// global variable, we attempt to track the value. The global variable state
317   /// is the merge of the stored value state with the global variable state.
318   void visitStore(StoreInst &I,
319                   DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
320                   SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
321     auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
322     if (!GV)
323       return;
324     auto RegI = CVPLatticeKey(I.getValueOperand(), IPOGrouping::Register);
325     auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory);
326     ChangedValues[MemGV] =
327         MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV));
328   }
329 
330   /// Handle all other instructions. All other instructions are marked
331   /// overdefined.
332   void visitInst(Instruction &I,
333                  DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues,
334                  SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) {
335     auto RegI = CVPLatticeKey(&I, IPOGrouping::Register);
336     ChangedValues[RegI] = getOverdefinedVal();
337   }
338 };
339 } // namespace
340 
341 namespace llvm {
342 /// A specialization of LatticeKeyInfo for CVPLatticeKeys. The generic solver
343 /// must translate between LatticeKeys and LLVM Values when adding Values to
344 /// its work list and inspecting the state of control-flow related values.
345 template <> struct LatticeKeyInfo<CVPLatticeKey> {
346   static inline Value *getValueFromLatticeKey(CVPLatticeKey Key) {
347     return Key.getPointer();
348   }
349   static inline CVPLatticeKey getLatticeKeyFromValue(Value *V) {
350     return CVPLatticeKey(V, IPOGrouping::Register);
351   }
352 };
353 } // namespace llvm
354 
355 static bool runCVP(Module &M) {
356   // Our custom lattice function and generic sparse propagation solver.
357   CVPLatticeFunc Lattice;
358   SparseSolver<CVPLatticeKey, CVPLatticeVal> Solver(&Lattice);
359 
360   // For each function in the module, if we can't track its arguments, let the
361   // generic solver assume it is executable.
362   for (Function &F : M)
363     if (!F.isDeclaration() && !canTrackArgumentsInterprocedurally(&F))
364       Solver.MarkBlockExecutable(&F.front());
365 
366   // Solver our custom lattice. In doing so, we will also build a set of
367   // indirect call sites.
368   Solver.Solve();
369 
370   // Attach metadata to the indirect call sites that were collected indicating
371   // the set of functions they can possibly target.
372   bool Changed = false;
373   MDBuilder MDB(M.getContext());
374   for (Instruction *C : Lattice.getIndirectCalls()) {
375     CallSite CS(C);
376     auto RegI = CVPLatticeKey(CS.getCalledValue(), IPOGrouping::Register);
377     CVPLatticeVal LV = Solver.getExistingValueState(RegI);
378     if (!LV.isFunctionSet() || LV.getFunctions().empty())
379       continue;
380     MDNode *Callees = MDB.createCallees(SmallVector<Function *, 4>(
381         LV.getFunctions().begin(), LV.getFunctions().end()));
382     C->setMetadata(LLVMContext::MD_callees, Callees);
383     Changed = true;
384   }
385 
386   return Changed;
387 }
388 
389 PreservedAnalyses CalledValuePropagationPass::run(Module &M,
390                                                   ModuleAnalysisManager &) {
391   runCVP(M);
392   return PreservedAnalyses::all();
393 }
394 
395 namespace {
396 class CalledValuePropagationLegacyPass : public ModulePass {
397 public:
398   static char ID;
399 
400   void getAnalysisUsage(AnalysisUsage &AU) const override {
401     AU.setPreservesAll();
402   }
403 
404   CalledValuePropagationLegacyPass() : ModulePass(ID) {
405     initializeCalledValuePropagationLegacyPassPass(
406         *PassRegistry::getPassRegistry());
407   }
408 
409   bool runOnModule(Module &M) override {
410     if (skipModule(M))
411       return false;
412     return runCVP(M);
413   }
414 };
415 } // namespace
416 
// Definition of the legacy pass's static identifier, followed by the
// registration of the pass under the "called-value-propagation" argument with
// the human-readable name "Called Value Propagation".
char CalledValuePropagationLegacyPass::ID = 0;
INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation",
                "Called Value Propagation", false, false)
420 
421 ModulePass *llvm::createCalledValuePropagationPass() {
422   return new CalledValuePropagationLegacyPass();
423 }
424