1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2df49d1bbSPeter Collingbourne //
3df49d1bbSPeter Collingbourne //                     The LLVM Compiler Infrastructure
4df49d1bbSPeter Collingbourne //
5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source
6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details.
7df49d1bbSPeter Collingbourne //
8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
9df49d1bbSPeter Collingbourne //
10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases
117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This
12df49d1bbSPeter Collingbourne // includes the following:
13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single
14df49d1bbSPeter Collingbourne //   possible callee, replace all calls with a direct call to that callee.
15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an
16df49d1bbSPeter Collingbourne //   integer <=64 bits and all possible callees are readnone, for each class and
17df49d1bbSPeter Collingbourne //   each list of constant arguments: evaluate the function, store the return
18df49d1bbSPeter Collingbourne //   value alongside the virtual table, and rewrite each virtual call as a load
19df49d1bbSPeter Collingbourne //   from the virtual table.
20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant
21df49d1bbSPeter Collingbourne //   propagation hold and each function returns the same constant value, replace
22df49d1bbSPeter Collingbourne //   each virtual call with that constant.
23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions
24df49d1bbSPeter Collingbourne //   for virtual constant propagation hold and a single vtable's function
25df49d1bbSPeter Collingbourne //   returns 0, or a single vtable's function returns 1, replace each virtual
26df49d1bbSPeter Collingbourne //   call with a comparison of the vptr against that vtable's address.
27df49d1bbSPeter Collingbourne //
28df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
29df49d1bbSPeter Collingbourne 
30df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
31b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h"
32cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h"
33cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h"
34df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h"
35cdc71612SEugene Zelenko #include "llvm/ADT/iterator_range.h"
36df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h"
37cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h"
3837317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h"
3937317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h"
407efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
41df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h"
42df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h"
43df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h"
44b05e06e4SIvan Krasin #include "llvm/IR/DebugInfoMetadata.h"
45cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h"
46cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h"
475474645dSIvan Krasin #include "llvm/IR/DiagnosticInfo.h"
48cdc71612SEugene Zelenko #include "llvm/IR/Function.h"
49cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h"
50cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h"
51df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h"
52cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h"
53cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h"
54df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h"
55df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h"
56cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h"
57cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h"
58df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h"
592b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h"
60df49d1bbSPeter Collingbourne #include "llvm/Pass.h"
61cdc71612SEugene Zelenko #include "llvm/PassRegistry.h"
62cdc71612SEugene Zelenko #include "llvm/PassSupport.h"
63cdc71612SEugene Zelenko #include "llvm/Support/Casting.h"
642b33f653SPeter Collingbourne #include "llvm/Support/Error.h"
652b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h"
66cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h"
67b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h"
6837317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h"
69df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h"
70cdc71612SEugene Zelenko #include <algorithm>
71cdc71612SEugene Zelenko #include <cstddef>
72cdc71612SEugene Zelenko #include <map>
73df49d1bbSPeter Collingbourne #include <set>
74cdc71612SEugene Zelenko #include <string>
75df49d1bbSPeter Collingbourne 
76df49d1bbSPeter Collingbourne using namespace llvm;
77df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt;
78df49d1bbSPeter Collingbourne 
79df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt"
80df49d1bbSPeter Collingbourne 
812b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction(
822b33f653SPeter Collingbourne     "wholeprogramdevirt-summary-action",
832b33f653SPeter Collingbourne     cl::desc("What to do with the summary when running this pass"),
842b33f653SPeter Collingbourne     cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
852b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Import, "import",
862b33f653SPeter Collingbourne                           "Import typeid resolutions from summary and globals"),
872b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Export, "export",
882b33f653SPeter Collingbourne                           "Export typeid resolutions to summary and globals")),
892b33f653SPeter Collingbourne     cl::Hidden);
902b33f653SPeter Collingbourne 
912b33f653SPeter Collingbourne static cl::opt<std::string> ClReadSummary(
922b33f653SPeter Collingbourne     "wholeprogramdevirt-read-summary",
932b33f653SPeter Collingbourne     cl::desc("Read summary from given YAML file before running pass"),
942b33f653SPeter Collingbourne     cl::Hidden);
952b33f653SPeter Collingbourne 
962b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary(
972b33f653SPeter Collingbourne     "wholeprogramdevirt-write-summary",
982b33f653SPeter Collingbourne     cl::desc("Write summary to given YAML file after running pass"),
992b33f653SPeter Collingbourne     cl::Hidden);
1002b33f653SPeter Collingbourne 
101df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If
102df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an
103df49d1bbSPeter Collingbourne // offset after the object.
104df49d1bbSPeter Collingbourne uint64_t
105df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
106df49d1bbSPeter Collingbourne                                      bool IsAfter, uint64_t Size) {
107df49d1bbSPeter Collingbourne   // Find a minimum offset taking into account only vtable sizes.
108df49d1bbSPeter Collingbourne   uint64_t MinByte = 0;
109df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
110df49d1bbSPeter Collingbourne     if (IsAfter)
111df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minAfterBytes());
112df49d1bbSPeter Collingbourne     else
113df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minBeforeBytes());
114df49d1bbSPeter Collingbourne   }
115df49d1bbSPeter Collingbourne 
116df49d1bbSPeter Collingbourne   // Build a vector of arrays of bytes covering, for each target, a slice of the
117df49d1bbSPeter Collingbourne   // used region (see AccumBitVector::BytesUsed in
118df49d1bbSPeter Collingbourne   // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
119df49d1bbSPeter Collingbourne   // this aligns the used regions to start at MinByte.
120df49d1bbSPeter Collingbourne   //
121df49d1bbSPeter Collingbourne   // In this example, A, B and C are vtables, # is a byte already allocated for
122df49d1bbSPeter Collingbourne   // a virtual function pointer, AAAA... (etc.) are the used regions for the
123df49d1bbSPeter Collingbourne   // vtables and Offset(X) is the value computed for the Offset variable below
124df49d1bbSPeter Collingbourne   // for X.
125df49d1bbSPeter Collingbourne   //
126df49d1bbSPeter Collingbourne   //                    Offset(A)
127df49d1bbSPeter Collingbourne   //                    |       |
128df49d1bbSPeter Collingbourne   //                            |MinByte
129df49d1bbSPeter Collingbourne   // A: ################AAAAAAAA|AAAAAAAA
130df49d1bbSPeter Collingbourne   // B: ########BBBBBBBBBBBBBBBB|BBBB
131df49d1bbSPeter Collingbourne   // C: ########################|CCCCCCCCCCCCCCCC
132df49d1bbSPeter Collingbourne   //            |   Offset(B)   |
133df49d1bbSPeter Collingbourne   //
134df49d1bbSPeter Collingbourne   // This code produces the slices of A, B and C that appear after the divider
135df49d1bbSPeter Collingbourne   // at MinByte.
136df49d1bbSPeter Collingbourne   std::vector<ArrayRef<uint8_t>> Used;
137df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
1387efd7506SPeter Collingbourne     ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
1397efd7506SPeter Collingbourne                                        : Target.TM->Bits->Before.BytesUsed;
140df49d1bbSPeter Collingbourne     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
141df49d1bbSPeter Collingbourne                               : MinByte - Target.minBeforeBytes();
142df49d1bbSPeter Collingbourne 
143df49d1bbSPeter Collingbourne     // Disregard used regions that are smaller than Offset. These are
144df49d1bbSPeter Collingbourne     // effectively all-free regions that do not need to be checked.
145df49d1bbSPeter Collingbourne     if (VTUsed.size() > Offset)
146df49d1bbSPeter Collingbourne       Used.push_back(VTUsed.slice(Offset));
147df49d1bbSPeter Collingbourne   }
148df49d1bbSPeter Collingbourne 
149df49d1bbSPeter Collingbourne   if (Size == 1) {
150df49d1bbSPeter Collingbourne     // Find a free bit in each member of Used.
151df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
152df49d1bbSPeter Collingbourne       uint8_t BitsUsed = 0;
153df49d1bbSPeter Collingbourne       for (auto &&B : Used)
154df49d1bbSPeter Collingbourne         if (I < B.size())
155df49d1bbSPeter Collingbourne           BitsUsed |= B[I];
156df49d1bbSPeter Collingbourne       if (BitsUsed != 0xff)
157df49d1bbSPeter Collingbourne         return (MinByte + I) * 8 +
158df49d1bbSPeter Collingbourne                countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
159df49d1bbSPeter Collingbourne     }
160df49d1bbSPeter Collingbourne   } else {
161df49d1bbSPeter Collingbourne     // Find a free (Size/8) byte region in each member of Used.
162df49d1bbSPeter Collingbourne     // FIXME: see if alignment helps.
163df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
164df49d1bbSPeter Collingbourne       for (auto &&B : Used) {
165df49d1bbSPeter Collingbourne         unsigned Byte = 0;
166df49d1bbSPeter Collingbourne         while ((I + Byte) < B.size() && Byte < (Size / 8)) {
167df49d1bbSPeter Collingbourne           if (B[I + Byte])
168df49d1bbSPeter Collingbourne             goto NextI;
169df49d1bbSPeter Collingbourne           ++Byte;
170df49d1bbSPeter Collingbourne         }
171df49d1bbSPeter Collingbourne       }
172df49d1bbSPeter Collingbourne       return (MinByte + I) * 8;
173df49d1bbSPeter Collingbourne     NextI:;
174df49d1bbSPeter Collingbourne     }
175df49d1bbSPeter Collingbourne   }
176df49d1bbSPeter Collingbourne }
177df49d1bbSPeter Collingbourne 
178df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues(
179df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
180df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
181df49d1bbSPeter Collingbourne   if (BitWidth == 1)
182df49d1bbSPeter Collingbourne     OffsetByte = -(AllocBefore / 8 + 1);
183df49d1bbSPeter Collingbourne   else
184df49d1bbSPeter Collingbourne     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
185df49d1bbSPeter Collingbourne   OffsetBit = AllocBefore % 8;
186df49d1bbSPeter Collingbourne 
187df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
188df49d1bbSPeter Collingbourne     if (BitWidth == 1)
189df49d1bbSPeter Collingbourne       Target.setBeforeBit(AllocBefore);
190df49d1bbSPeter Collingbourne     else
191df49d1bbSPeter Collingbourne       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
192df49d1bbSPeter Collingbourne   }
193df49d1bbSPeter Collingbourne }
194df49d1bbSPeter Collingbourne 
195df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues(
196df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
197df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
198df49d1bbSPeter Collingbourne   if (BitWidth == 1)
199df49d1bbSPeter Collingbourne     OffsetByte = AllocAfter / 8;
200df49d1bbSPeter Collingbourne   else
201df49d1bbSPeter Collingbourne     OffsetByte = (AllocAfter + 7) / 8;
202df49d1bbSPeter Collingbourne   OffsetBit = AllocAfter % 8;
203df49d1bbSPeter Collingbourne 
204df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
205df49d1bbSPeter Collingbourne     if (BitWidth == 1)
206df49d1bbSPeter Collingbourne       Target.setAfterBit(AllocAfter);
207df49d1bbSPeter Collingbourne     else
208df49d1bbSPeter Collingbourne       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
209df49d1bbSPeter Collingbourne   }
210df49d1bbSPeter Collingbourne }
211df49d1bbSPeter Collingbourne 
2127efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
2137efd7506SPeter Collingbourne     : Fn(Fn), TM(TM),
21489439a79SIvan Krasin       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
215df49d1bbSPeter Collingbourne 
216df49d1bbSPeter Collingbourne namespace {
217df49d1bbSPeter Collingbourne 
2187efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual
219df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to
220df49d1bbSPeter Collingbourne // the virtual function pointer.
221df49d1bbSPeter Collingbourne struct VTableSlot {
2227efd7506SPeter Collingbourne   Metadata *TypeID;
223df49d1bbSPeter Collingbourne   uint64_t ByteOffset;
224df49d1bbSPeter Collingbourne };
225df49d1bbSPeter Collingbourne 
226cdc71612SEugene Zelenko } // end anonymous namespace
227df49d1bbSPeter Collingbourne 
2289b656527SPeter Collingbourne namespace llvm {
2299b656527SPeter Collingbourne 
230df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> {
231df49d1bbSPeter Collingbourne   static VTableSlot getEmptyKey() {
232df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getEmptyKey(),
233df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getEmptyKey()};
234df49d1bbSPeter Collingbourne   }
235df49d1bbSPeter Collingbourne   static VTableSlot getTombstoneKey() {
236df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
237df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getTombstoneKey()};
238df49d1bbSPeter Collingbourne   }
239df49d1bbSPeter Collingbourne   static unsigned getHashValue(const VTableSlot &I) {
2407efd7506SPeter Collingbourne     return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
241df49d1bbSPeter Collingbourne            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
242df49d1bbSPeter Collingbourne   }
243df49d1bbSPeter Collingbourne   static bool isEqual(const VTableSlot &LHS,
244df49d1bbSPeter Collingbourne                       const VTableSlot &RHS) {
2457efd7506SPeter Collingbourne     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
246df49d1bbSPeter Collingbourne   }
247df49d1bbSPeter Collingbourne };
248df49d1bbSPeter Collingbourne 
249cdc71612SEugene Zelenko } // end namespace llvm
2509b656527SPeter Collingbourne 
251df49d1bbSPeter Collingbourne namespace {
252df49d1bbSPeter Collingbourne 
253df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is
254df49d1bbSPeter Collingbourne // the indirect virtual call.
255df49d1bbSPeter Collingbourne struct VirtualCallSite {
256df49d1bbSPeter Collingbourne   Value *VTable;
257df49d1bbSPeter Collingbourne   CallSite CS;
258df49d1bbSPeter Collingbourne 
2590312f614SPeter Collingbourne   // If non-null, this field points to the associated unsafe use count stored in
2600312f614SPeter Collingbourne   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
2610312f614SPeter Collingbourne   // of that field for details.
2620312f614SPeter Collingbourne   unsigned *NumUnsafeUses;
2630312f614SPeter Collingbourne 
264f3403fd2SIvan Krasin   void emitRemark(const Twine &OptName, const Twine &TargetName) {
2655474645dSIvan Krasin     Function *F = CS.getCaller();
266f3403fd2SIvan Krasin     emitOptimizationRemark(
267f3403fd2SIvan Krasin         F->getContext(), DEBUG_TYPE, *F,
2685474645dSIvan Krasin         CS.getInstruction()->getDebugLoc(),
269f3403fd2SIvan Krasin         OptName + ": devirtualized a call to " + TargetName);
2705474645dSIvan Krasin   }
2715474645dSIvan Krasin 
272f3403fd2SIvan Krasin   void replaceAndErase(const Twine &OptName, const Twine &TargetName,
273f3403fd2SIvan Krasin                        bool RemarksEnabled, Value *New) {
274f3403fd2SIvan Krasin     if (RemarksEnabled)
275f3403fd2SIvan Krasin       emitRemark(OptName, TargetName);
276df49d1bbSPeter Collingbourne     CS->replaceAllUsesWith(New);
277df49d1bbSPeter Collingbourne     if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
278df49d1bbSPeter Collingbourne       BranchInst::Create(II->getNormalDest(), CS.getInstruction());
279df49d1bbSPeter Collingbourne       II->getUnwindDest()->removePredecessor(II->getParent());
280df49d1bbSPeter Collingbourne     }
281df49d1bbSPeter Collingbourne     CS->eraseFromParent();
2820312f614SPeter Collingbourne     // This use is no longer unsafe.
2830312f614SPeter Collingbourne     if (NumUnsafeUses)
2840312f614SPeter Collingbourne       --*NumUnsafeUses;
285df49d1bbSPeter Collingbourne   }
286df49d1bbSPeter Collingbourne };
287df49d1bbSPeter Collingbourne 
28850cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list
28950cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the
29050cbd7ccSPeter Collingbourne // VTableSlotInfo class.
29150cbd7ccSPeter Collingbourne struct CallSiteInfo {
29250cbd7ccSPeter Collingbourne   std::vector<VirtualCallSite> CallSites;
29350cbd7ccSPeter Collingbourne };
29450cbd7ccSPeter Collingbourne 
29550cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot.
29650cbd7ccSPeter Collingbourne struct VTableSlotInfo {
29750cbd7ccSPeter Collingbourne   // The set of call sites which do not have all constant integer arguments
29850cbd7ccSPeter Collingbourne   // (excluding "this").
29950cbd7ccSPeter Collingbourne   CallSiteInfo CSInfo;
30050cbd7ccSPeter Collingbourne 
30150cbd7ccSPeter Collingbourne   // The set of call sites with all constant integer arguments (excluding
30250cbd7ccSPeter Collingbourne   // "this"), grouped by argument list.
30350cbd7ccSPeter Collingbourne   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
30450cbd7ccSPeter Collingbourne 
30550cbd7ccSPeter Collingbourne   void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
30650cbd7ccSPeter Collingbourne 
30750cbd7ccSPeter Collingbourne private:
30850cbd7ccSPeter Collingbourne   CallSiteInfo &findCallSiteInfo(CallSite CS);
30950cbd7ccSPeter Collingbourne };
31050cbd7ccSPeter Collingbourne 
31150cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
31250cbd7ccSPeter Collingbourne   std::vector<uint64_t> Args;
31350cbd7ccSPeter Collingbourne   auto *CI = dyn_cast<IntegerType>(CS.getType());
31450cbd7ccSPeter Collingbourne   if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
31550cbd7ccSPeter Collingbourne     return CSInfo;
31650cbd7ccSPeter Collingbourne   for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
31750cbd7ccSPeter Collingbourne     auto *CI = dyn_cast<ConstantInt>(Arg);
31850cbd7ccSPeter Collingbourne     if (!CI || CI->getBitWidth() > 64)
31950cbd7ccSPeter Collingbourne       return CSInfo;
32050cbd7ccSPeter Collingbourne     Args.push_back(CI->getZExtValue());
32150cbd7ccSPeter Collingbourne   }
32250cbd7ccSPeter Collingbourne   return ConstCSInfo[Args];
32350cbd7ccSPeter Collingbourne }
32450cbd7ccSPeter Collingbourne 
32550cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
32650cbd7ccSPeter Collingbourne                                  unsigned *NumUnsafeUses) {
32750cbd7ccSPeter Collingbourne   findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
32850cbd7ccSPeter Collingbourne }
32950cbd7ccSPeter Collingbourne 
330df49d1bbSPeter Collingbourne struct DevirtModule {
331df49d1bbSPeter Collingbourne   Module &M;
33237317f12SPeter Collingbourne   function_ref<AAResults &(Function &)> AARGetter;
3332b33f653SPeter Collingbourne 
3342b33f653SPeter Collingbourne   PassSummaryAction Action;
3352b33f653SPeter Collingbourne   ModuleSummaryIndex *Summary;
3362b33f653SPeter Collingbourne 
337df49d1bbSPeter Collingbourne   IntegerType *Int8Ty;
338df49d1bbSPeter Collingbourne   PointerType *Int8PtrTy;
339df49d1bbSPeter Collingbourne   IntegerType *Int32Ty;
34050cbd7ccSPeter Collingbourne   IntegerType *Int64Ty;
341df49d1bbSPeter Collingbourne 
342f3403fd2SIvan Krasin   bool RemarksEnabled;
343f3403fd2SIvan Krasin 
34450cbd7ccSPeter Collingbourne   MapVector<VTableSlot, VTableSlotInfo> CallSlots;
345df49d1bbSPeter Collingbourne 
3460312f614SPeter Collingbourne   // This map keeps track of the number of "unsafe" uses of a loaded function
3470312f614SPeter Collingbourne   // pointer. The key is the associated llvm.type.test intrinsic call generated
3480312f614SPeter Collingbourne   // by this pass. An unsafe use is one that calls the loaded function pointer
3490312f614SPeter Collingbourne   // directly. Every time we eliminate an unsafe use (for example, by
3500312f614SPeter Collingbourne   // devirtualizing it or by applying virtual constant propagation), we
3510312f614SPeter Collingbourne   // decrement the value stored in this map. If a value reaches zero, we can
3520312f614SPeter Collingbourne   // eliminate the type check by RAUWing the associated llvm.type.test call with
3530312f614SPeter Collingbourne   // true.
3540312f614SPeter Collingbourne   std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
3550312f614SPeter Collingbourne 
35637317f12SPeter Collingbourne   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
35737317f12SPeter Collingbourne                PassSummaryAction Action, ModuleSummaryIndex *Summary)
35837317f12SPeter Collingbourne       : M(M), AARGetter(AARGetter), Action(Action), Summary(Summary),
3592b33f653SPeter Collingbourne         Int8Ty(Type::getInt8Ty(M.getContext())),
360df49d1bbSPeter Collingbourne         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
361f3403fd2SIvan Krasin         Int32Ty(Type::getInt32Ty(M.getContext())),
36250cbd7ccSPeter Collingbourne         Int64Ty(Type::getInt64Ty(M.getContext())),
363f3403fd2SIvan Krasin         RemarksEnabled(areRemarksEnabled()) {}
364f3403fd2SIvan Krasin 
365f3403fd2SIvan Krasin   bool areRemarksEnabled();
366df49d1bbSPeter Collingbourne 
3670312f614SPeter Collingbourne   void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
3680312f614SPeter Collingbourne   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
3690312f614SPeter Collingbourne 
3707efd7506SPeter Collingbourne   void buildTypeIdentifierMap(
3717efd7506SPeter Collingbourne       std::vector<VTableBits> &Bits,
3727efd7506SPeter Collingbourne       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
3738786754cSPeter Collingbourne   Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
3747efd7506SPeter Collingbourne   bool
3757efd7506SPeter Collingbourne   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
3767efd7506SPeter Collingbourne                             const std::set<TypeMemberInfo> &TypeMemberInfos,
377df49d1bbSPeter Collingbourne                             uint64_t ByteOffset);
37850cbd7ccSPeter Collingbourne 
37950cbd7ccSPeter Collingbourne   void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn);
380f3403fd2SIvan Krasin   bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
38150cbd7ccSPeter Collingbourne                            VTableSlotInfo &SlotInfo);
38250cbd7ccSPeter Collingbourne 
383df49d1bbSPeter Collingbourne   bool tryEvaluateFunctionsWithArgs(
384df49d1bbSPeter Collingbourne       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
38550cbd7ccSPeter Collingbourne       ArrayRef<uint64_t> Args);
38650cbd7ccSPeter Collingbourne 
38750cbd7ccSPeter Collingbourne   void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
38850cbd7ccSPeter Collingbourne                              uint64_t TheRetVal);
38950cbd7ccSPeter Collingbourne   bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
39050cbd7ccSPeter Collingbourne                            CallSiteInfo &CSInfo);
39150cbd7ccSPeter Collingbourne 
39250cbd7ccSPeter Collingbourne   void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
39350cbd7ccSPeter Collingbourne                             Constant *UniqueMemberAddr);
394df49d1bbSPeter Collingbourne   bool tryUniqueRetValOpt(unsigned BitWidth,
395f3403fd2SIvan Krasin                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
39650cbd7ccSPeter Collingbourne                           CallSiteInfo &CSInfo);
39750cbd7ccSPeter Collingbourne 
39850cbd7ccSPeter Collingbourne   void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
39950cbd7ccSPeter Collingbourne                              Constant *Byte, Constant *Bit);
400df49d1bbSPeter Collingbourne   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
40150cbd7ccSPeter Collingbourne                            VTableSlotInfo &SlotInfo);
402df49d1bbSPeter Collingbourne 
403df49d1bbSPeter Collingbourne   void rebuildGlobal(VTableBits &B);
404df49d1bbSPeter Collingbourne 
405df49d1bbSPeter Collingbourne   bool run();
4062b33f653SPeter Collingbourne 
4072b33f653SPeter Collingbourne   // Lower the module using the action and summary passed as command line
4082b33f653SPeter Collingbourne   // arguments. For testing purposes only.
40937317f12SPeter Collingbourne   static bool runForTesting(Module &M,
41037317f12SPeter Collingbourne                             function_ref<AAResults &(Function &)> AARGetter);
411df49d1bbSPeter Collingbourne };
412df49d1bbSPeter Collingbourne 
413df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass {
414df49d1bbSPeter Collingbourne   static char ID;
415cdc71612SEugene Zelenko 
4162b33f653SPeter Collingbourne   bool UseCommandLine = false;
4172b33f653SPeter Collingbourne 
4182b33f653SPeter Collingbourne   PassSummaryAction Action;
4192b33f653SPeter Collingbourne   ModuleSummaryIndex *Summary;
4202b33f653SPeter Collingbourne 
4212b33f653SPeter Collingbourne   WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
4222b33f653SPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
4232b33f653SPeter Collingbourne   }
4242b33f653SPeter Collingbourne 
4252b33f653SPeter Collingbourne   WholeProgramDevirt(PassSummaryAction Action, ModuleSummaryIndex *Summary)
4262b33f653SPeter Collingbourne       : ModulePass(ID), Action(Action), Summary(Summary) {
427df49d1bbSPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
428df49d1bbSPeter Collingbourne   }
429cdc71612SEugene Zelenko 
430cdc71612SEugene Zelenko   bool runOnModule(Module &M) override {
431aa641a51SAndrew Kaylor     if (skipModule(M))
432aa641a51SAndrew Kaylor       return false;
4332b33f653SPeter Collingbourne     if (UseCommandLine)
43437317f12SPeter Collingbourne       return DevirtModule::runForTesting(M, LegacyAARGetter(*this));
43537317f12SPeter Collingbourne     return DevirtModule(M, LegacyAARGetter(*this), Action, Summary).run();
43637317f12SPeter Collingbourne   }
43737317f12SPeter Collingbourne 
43837317f12SPeter Collingbourne   void getAnalysisUsage(AnalysisUsage &AU) const override {
43937317f12SPeter Collingbourne     AU.addRequired<AssumptionCacheTracker>();
44037317f12SPeter Collingbourne     AU.addRequired<TargetLibraryInfoWrapperPass>();
441aa641a51SAndrew Kaylor   }
442df49d1bbSPeter Collingbourne };
443df49d1bbSPeter Collingbourne 
444cdc71612SEugene Zelenko } // end anonymous namespace
445df49d1bbSPeter Collingbourne 
44637317f12SPeter Collingbourne INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
44737317f12SPeter Collingbourne                       "Whole program devirtualization", false, false)
44837317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
44937317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
45037317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
451df49d1bbSPeter Collingbourne                     "Whole program devirtualization", false, false)
452df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0;
453df49d1bbSPeter Collingbourne 
4542b33f653SPeter Collingbourne ModulePass *llvm::createWholeProgramDevirtPass(PassSummaryAction Action,
4552b33f653SPeter Collingbourne                                                ModuleSummaryIndex *Summary) {
4562b33f653SPeter Collingbourne   return new WholeProgramDevirt(Action, Summary);
457df49d1bbSPeter Collingbourne }
458df49d1bbSPeter Collingbourne 
459164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
46037317f12SPeter Collingbourne                                               ModuleAnalysisManager &AM) {
46137317f12SPeter Collingbourne   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
46237317f12SPeter Collingbourne   auto AARGetter = [&](Function &F) -> AAResults & {
46337317f12SPeter Collingbourne     return FAM.getResult<AAManager>(F);
46437317f12SPeter Collingbourne   };
46537317f12SPeter Collingbourne   if (!DevirtModule(M, AARGetter, PassSummaryAction::None, nullptr).run())
466d737dd2eSDavide Italiano     return PreservedAnalyses::all();
467d737dd2eSDavide Italiano   return PreservedAnalyses::none();
468d737dd2eSDavide Italiano }
469d737dd2eSDavide Italiano 
47037317f12SPeter Collingbourne bool DevirtModule::runForTesting(
47137317f12SPeter Collingbourne     Module &M, function_ref<AAResults &(Function &)> AARGetter) {
4722b33f653SPeter Collingbourne   ModuleSummaryIndex Summary;
4732b33f653SPeter Collingbourne 
4742b33f653SPeter Collingbourne   // Handle the command-line summary arguments. This code is for testing
4752b33f653SPeter Collingbourne   // purposes only, so we handle errors directly.
4762b33f653SPeter Collingbourne   if (!ClReadSummary.empty()) {
4772b33f653SPeter Collingbourne     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
4782b33f653SPeter Collingbourne                           ": ");
4792b33f653SPeter Collingbourne     auto ReadSummaryFile =
4802b33f653SPeter Collingbourne         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
4812b33f653SPeter Collingbourne 
4822b33f653SPeter Collingbourne     yaml::Input In(ReadSummaryFile->getBuffer());
4832b33f653SPeter Collingbourne     In >> Summary;
4842b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(In.error()));
4852b33f653SPeter Collingbourne   }
4862b33f653SPeter Collingbourne 
48737317f12SPeter Collingbourne   bool Changed = DevirtModule(M, AARGetter, ClSummaryAction, &Summary).run();
4882b33f653SPeter Collingbourne 
4892b33f653SPeter Collingbourne   if (!ClWriteSummary.empty()) {
4902b33f653SPeter Collingbourne     ExitOnError ExitOnErr(
4912b33f653SPeter Collingbourne         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
4922b33f653SPeter Collingbourne     std::error_code EC;
4932b33f653SPeter Collingbourne     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
4942b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(EC));
4952b33f653SPeter Collingbourne 
4962b33f653SPeter Collingbourne     yaml::Output Out(OS);
4972b33f653SPeter Collingbourne     Out << Summary;
4982b33f653SPeter Collingbourne   }
4992b33f653SPeter Collingbourne 
5002b33f653SPeter Collingbourne   return Changed;
5012b33f653SPeter Collingbourne }
5022b33f653SPeter Collingbourne 
5037efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap(
504df49d1bbSPeter Collingbourne     std::vector<VTableBits> &Bits,
5057efd7506SPeter Collingbourne     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
506df49d1bbSPeter Collingbourne   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
5077efd7506SPeter Collingbourne   Bits.reserve(M.getGlobalList().size());
5087efd7506SPeter Collingbourne   SmallVector<MDNode *, 2> Types;
5097efd7506SPeter Collingbourne   for (GlobalVariable &GV : M.globals()) {
5107efd7506SPeter Collingbourne     Types.clear();
5117efd7506SPeter Collingbourne     GV.getMetadata(LLVMContext::MD_type, Types);
5127efd7506SPeter Collingbourne     if (Types.empty())
513df49d1bbSPeter Collingbourne       continue;
514df49d1bbSPeter Collingbourne 
5157efd7506SPeter Collingbourne     VTableBits *&BitsPtr = GVToBits[&GV];
5167efd7506SPeter Collingbourne     if (!BitsPtr) {
5177efd7506SPeter Collingbourne       Bits.emplace_back();
5187efd7506SPeter Collingbourne       Bits.back().GV = &GV;
5197efd7506SPeter Collingbourne       Bits.back().ObjectSize =
5207efd7506SPeter Collingbourne           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
5217efd7506SPeter Collingbourne       BitsPtr = &Bits.back();
5227efd7506SPeter Collingbourne     }
5237efd7506SPeter Collingbourne 
5247efd7506SPeter Collingbourne     for (MDNode *Type : Types) {
5257efd7506SPeter Collingbourne       auto TypeID = Type->getOperand(1).get();
526df49d1bbSPeter Collingbourne 
527df49d1bbSPeter Collingbourne       uint64_t Offset =
528df49d1bbSPeter Collingbourne           cast<ConstantInt>(
5297efd7506SPeter Collingbourne               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
530df49d1bbSPeter Collingbourne               ->getZExtValue();
531df49d1bbSPeter Collingbourne 
5327efd7506SPeter Collingbourne       TypeIdMap[TypeID].insert({BitsPtr, Offset});
533df49d1bbSPeter Collingbourne     }
534df49d1bbSPeter Collingbourne   }
535df49d1bbSPeter Collingbourne }
536df49d1bbSPeter Collingbourne 
5378786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
5388786754cSPeter Collingbourne   if (I->getType()->isPointerTy()) {
5398786754cSPeter Collingbourne     if (Offset == 0)
5408786754cSPeter Collingbourne       return I;
5418786754cSPeter Collingbourne     return nullptr;
5428786754cSPeter Collingbourne   }
5438786754cSPeter Collingbourne 
5447a1e5bbeSPeter Collingbourne   const DataLayout &DL = M.getDataLayout();
5457a1e5bbeSPeter Collingbourne 
5467a1e5bbeSPeter Collingbourne   if (auto *C = dyn_cast<ConstantStruct>(I)) {
5477a1e5bbeSPeter Collingbourne     const StructLayout *SL = DL.getStructLayout(C->getType());
5487a1e5bbeSPeter Collingbourne     if (Offset >= SL->getSizeInBytes())
5497a1e5bbeSPeter Collingbourne       return nullptr;
5507a1e5bbeSPeter Collingbourne 
5518786754cSPeter Collingbourne     unsigned Op = SL->getElementContainingOffset(Offset);
5528786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
5538786754cSPeter Collingbourne                               Offset - SL->getElementOffset(Op));
5548786754cSPeter Collingbourne   }
5558786754cSPeter Collingbourne   if (auto *C = dyn_cast<ConstantArray>(I)) {
5567a1e5bbeSPeter Collingbourne     ArrayType *VTableTy = C->getType();
5577a1e5bbeSPeter Collingbourne     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
5587a1e5bbeSPeter Collingbourne 
5598786754cSPeter Collingbourne     unsigned Op = Offset / ElemSize;
5607a1e5bbeSPeter Collingbourne     if (Op >= C->getNumOperands())
5617a1e5bbeSPeter Collingbourne       return nullptr;
5627a1e5bbeSPeter Collingbourne 
5638786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
5648786754cSPeter Collingbourne                               Offset % ElemSize);
5658786754cSPeter Collingbourne   }
5668786754cSPeter Collingbourne   return nullptr;
5677a1e5bbeSPeter Collingbourne }
5687a1e5bbeSPeter Collingbourne 
569df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets(
570df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> &TargetsForSlot,
5717efd7506SPeter Collingbourne     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
5727efd7506SPeter Collingbourne   for (const TypeMemberInfo &TM : TypeMemberInfos) {
5737efd7506SPeter Collingbourne     if (!TM.Bits->GV->isConstant())
574df49d1bbSPeter Collingbourne       return false;
575df49d1bbSPeter Collingbourne 
5768786754cSPeter Collingbourne     Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
5778786754cSPeter Collingbourne                                        TM.Offset + ByteOffset);
5788786754cSPeter Collingbourne     if (!Ptr)
579df49d1bbSPeter Collingbourne       return false;
580df49d1bbSPeter Collingbourne 
5818786754cSPeter Collingbourne     auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
582df49d1bbSPeter Collingbourne     if (!Fn)
583df49d1bbSPeter Collingbourne       return false;
584df49d1bbSPeter Collingbourne 
585df49d1bbSPeter Collingbourne     // We can disregard __cxa_pure_virtual as a possible call target, as
586df49d1bbSPeter Collingbourne     // calls to pure virtuals are UB.
587df49d1bbSPeter Collingbourne     if (Fn->getName() == "__cxa_pure_virtual")
588df49d1bbSPeter Collingbourne       continue;
589df49d1bbSPeter Collingbourne 
5907efd7506SPeter Collingbourne     TargetsForSlot.push_back({Fn, &TM});
591df49d1bbSPeter Collingbourne   }
592df49d1bbSPeter Collingbourne 
593df49d1bbSPeter Collingbourne   // Give up if we couldn't find any targets.
594df49d1bbSPeter Collingbourne   return !TargetsForSlot.empty();
595df49d1bbSPeter Collingbourne }
596df49d1bbSPeter Collingbourne 
59750cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
59850cbd7ccSPeter Collingbourne                                          Constant *TheFn) {
59950cbd7ccSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
60050cbd7ccSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
601f3403fd2SIvan Krasin       if (RemarksEnabled)
602f3403fd2SIvan Krasin         VCallSite.emitRemark("single-impl", TheFn->getName());
603df49d1bbSPeter Collingbourne       VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
604df49d1bbSPeter Collingbourne           TheFn, VCallSite.CS.getCalledValue()->getType()));
6050312f614SPeter Collingbourne       // This use is no longer unsafe.
6060312f614SPeter Collingbourne       if (VCallSite.NumUnsafeUses)
6070312f614SPeter Collingbourne         --*VCallSite.NumUnsafeUses;
608df49d1bbSPeter Collingbourne     }
60950cbd7ccSPeter Collingbourne   };
61050cbd7ccSPeter Collingbourne   Apply(SlotInfo.CSInfo);
61150cbd7ccSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
61250cbd7ccSPeter Collingbourne     Apply(P.second);
61350cbd7ccSPeter Collingbourne }
61450cbd7ccSPeter Collingbourne 
61550cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt(
61650cbd7ccSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
61750cbd7ccSPeter Collingbourne     VTableSlotInfo &SlotInfo) {
61850cbd7ccSPeter Collingbourne   // See if the program contains a single implementation of this virtual
61950cbd7ccSPeter Collingbourne   // function.
62050cbd7ccSPeter Collingbourne   Function *TheFn = TargetsForSlot[0].Fn;
62150cbd7ccSPeter Collingbourne   for (auto &&Target : TargetsForSlot)
62250cbd7ccSPeter Collingbourne     if (TheFn != Target.Fn)
62350cbd7ccSPeter Collingbourne       return false;
62450cbd7ccSPeter Collingbourne 
62550cbd7ccSPeter Collingbourne   // If so, update each call site to call that implementation directly.
62650cbd7ccSPeter Collingbourne   if (RemarksEnabled)
62750cbd7ccSPeter Collingbourne     TargetsForSlot[0].WasDevirt = true;
62850cbd7ccSPeter Collingbourne   applySingleImplDevirt(SlotInfo, TheFn);
629df49d1bbSPeter Collingbourne   return true;
630df49d1bbSPeter Collingbourne }
631df49d1bbSPeter Collingbourne 
632df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs(
633df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
63450cbd7ccSPeter Collingbourne     ArrayRef<uint64_t> Args) {
635df49d1bbSPeter Collingbourne   // Evaluate each function and store the result in each target's RetVal
636df49d1bbSPeter Collingbourne   // field.
637df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
638df49d1bbSPeter Collingbourne     if (Target.Fn->arg_size() != Args.size() + 1)
639df49d1bbSPeter Collingbourne       return false;
640df49d1bbSPeter Collingbourne 
641df49d1bbSPeter Collingbourne     Evaluator Eval(M.getDataLayout(), nullptr);
642df49d1bbSPeter Collingbourne     SmallVector<Constant *, 2> EvalArgs;
643df49d1bbSPeter Collingbourne     EvalArgs.push_back(
644df49d1bbSPeter Collingbourne         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
64550cbd7ccSPeter Collingbourne     for (unsigned I = 0; I != Args.size(); ++I) {
64650cbd7ccSPeter Collingbourne       auto *ArgTy = dyn_cast<IntegerType>(
64750cbd7ccSPeter Collingbourne           Target.Fn->getFunctionType()->getParamType(I + 1));
64850cbd7ccSPeter Collingbourne       if (!ArgTy)
64950cbd7ccSPeter Collingbourne         return false;
65050cbd7ccSPeter Collingbourne       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
65150cbd7ccSPeter Collingbourne     }
65250cbd7ccSPeter Collingbourne 
653df49d1bbSPeter Collingbourne     Constant *RetVal;
654df49d1bbSPeter Collingbourne     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
655df49d1bbSPeter Collingbourne         !isa<ConstantInt>(RetVal))
656df49d1bbSPeter Collingbourne       return false;
657df49d1bbSPeter Collingbourne     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
658df49d1bbSPeter Collingbourne   }
659df49d1bbSPeter Collingbourne   return true;
660df49d1bbSPeter Collingbourne }
661df49d1bbSPeter Collingbourne 
66250cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
66350cbd7ccSPeter Collingbourne                                          uint64_t TheRetVal) {
66450cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites)
66550cbd7ccSPeter Collingbourne     Call.replaceAndErase(
66650cbd7ccSPeter Collingbourne         "uniform-ret-val", FnName, RemarksEnabled,
66750cbd7ccSPeter Collingbourne         ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
66850cbd7ccSPeter Collingbourne }
66950cbd7ccSPeter Collingbourne 
670df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt(
67150cbd7ccSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo) {
672df49d1bbSPeter Collingbourne   // Uniform return value optimization. If all functions return the same
673df49d1bbSPeter Collingbourne   // constant, replace all calls with that constant.
674df49d1bbSPeter Collingbourne   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
675df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : TargetsForSlot)
676df49d1bbSPeter Collingbourne     if (Target.RetVal != TheRetVal)
677df49d1bbSPeter Collingbourne       return false;
678df49d1bbSPeter Collingbourne 
67950cbd7ccSPeter Collingbourne   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
680f3403fd2SIvan Krasin   if (RemarksEnabled)
681f3403fd2SIvan Krasin     for (auto &&Target : TargetsForSlot)
682f3403fd2SIvan Krasin       Target.WasDevirt = true;
683df49d1bbSPeter Collingbourne   return true;
684df49d1bbSPeter Collingbourne }
685df49d1bbSPeter Collingbourne 
68650cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
68750cbd7ccSPeter Collingbourne                                         bool IsOne,
68850cbd7ccSPeter Collingbourne                                         Constant *UniqueMemberAddr) {
68950cbd7ccSPeter Collingbourne   for (auto &&Call : CSInfo.CallSites) {
69050cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
69150cbd7ccSPeter Collingbourne     Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
69250cbd7ccSPeter Collingbourne                               Call.VTable, UniqueMemberAddr);
69350cbd7ccSPeter Collingbourne     Cmp = B.CreateZExt(Cmp, Call.CS->getType());
69450cbd7ccSPeter Collingbourne     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp);
69550cbd7ccSPeter Collingbourne   }
69650cbd7ccSPeter Collingbourne }
69750cbd7ccSPeter Collingbourne 
698df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt(
699f3403fd2SIvan Krasin     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
70050cbd7ccSPeter Collingbourne     CallSiteInfo &CSInfo) {
701df49d1bbSPeter Collingbourne   // IsOne controls whether we look for a 0 or a 1.
702df49d1bbSPeter Collingbourne   auto tryUniqueRetValOptFor = [&](bool IsOne) {
703cdc71612SEugene Zelenko     const TypeMemberInfo *UniqueMember = nullptr;
704df49d1bbSPeter Collingbourne     for (const VirtualCallTarget &Target : TargetsForSlot) {
7053866cc5fSPeter Collingbourne       if (Target.RetVal == (IsOne ? 1 : 0)) {
7067efd7506SPeter Collingbourne         if (UniqueMember)
707df49d1bbSPeter Collingbourne           return false;
7087efd7506SPeter Collingbourne         UniqueMember = Target.TM;
709df49d1bbSPeter Collingbourne       }
710df49d1bbSPeter Collingbourne     }
711df49d1bbSPeter Collingbourne 
7127efd7506SPeter Collingbourne     // We should have found a unique member or bailed out by now. We already
713df49d1bbSPeter Collingbourne     // checked for a uniform return value in tryUniformRetValOpt.
7147efd7506SPeter Collingbourne     assert(UniqueMember);
715df49d1bbSPeter Collingbourne 
716df49d1bbSPeter Collingbourne     // Replace each call with the comparison.
71750cbd7ccSPeter Collingbourne     Constant *UniqueMemberAddr =
71850cbd7ccSPeter Collingbourne         ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
71950cbd7ccSPeter Collingbourne     UniqueMemberAddr = ConstantExpr::getGetElementPtr(
72050cbd7ccSPeter Collingbourne         Int8Ty, UniqueMemberAddr,
72150cbd7ccSPeter Collingbourne         ConstantInt::get(Int64Ty, UniqueMember->Offset));
72250cbd7ccSPeter Collingbourne 
72350cbd7ccSPeter Collingbourne     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
72450cbd7ccSPeter Collingbourne                          UniqueMemberAddr);
72550cbd7ccSPeter Collingbourne 
726f3403fd2SIvan Krasin     // Update devirtualization statistics for targets.
727f3403fd2SIvan Krasin     if (RemarksEnabled)
728f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
729f3403fd2SIvan Krasin         Target.WasDevirt = true;
730f3403fd2SIvan Krasin 
731df49d1bbSPeter Collingbourne     return true;
732df49d1bbSPeter Collingbourne   };
733df49d1bbSPeter Collingbourne 
734df49d1bbSPeter Collingbourne   if (BitWidth == 1) {
735df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(true))
736df49d1bbSPeter Collingbourne       return true;
737df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(false))
738df49d1bbSPeter Collingbourne       return true;
739df49d1bbSPeter Collingbourne   }
740df49d1bbSPeter Collingbourne   return false;
741df49d1bbSPeter Collingbourne }
742df49d1bbSPeter Collingbourne 
74350cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
74450cbd7ccSPeter Collingbourne                                          Constant *Byte, Constant *Bit) {
74550cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites) {
74650cbd7ccSPeter Collingbourne     auto *RetType = cast<IntegerType>(Call.CS.getType());
74750cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
74850cbd7ccSPeter Collingbourne     Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
74950cbd7ccSPeter Collingbourne     if (RetType->getBitWidth() == 1) {
75050cbd7ccSPeter Collingbourne       Value *Bits = B.CreateLoad(Addr);
75150cbd7ccSPeter Collingbourne       Value *BitsAndBit = B.CreateAnd(Bits, Bit);
75250cbd7ccSPeter Collingbourne       auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
75350cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
75450cbd7ccSPeter Collingbourne                            IsBitSet);
75550cbd7ccSPeter Collingbourne     } else {
75650cbd7ccSPeter Collingbourne       Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
75750cbd7ccSPeter Collingbourne       Value *Val = B.CreateLoad(RetType, ValAddr);
75850cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val);
75950cbd7ccSPeter Collingbourne     }
76050cbd7ccSPeter Collingbourne   }
76150cbd7ccSPeter Collingbourne }
76250cbd7ccSPeter Collingbourne 
763df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp(
764df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
76550cbd7ccSPeter Collingbourne     VTableSlotInfo &SlotInfo) {
766df49d1bbSPeter Collingbourne   // This only works if the function returns an integer.
767df49d1bbSPeter Collingbourne   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
768df49d1bbSPeter Collingbourne   if (!RetType)
769df49d1bbSPeter Collingbourne     return false;
770df49d1bbSPeter Collingbourne   unsigned BitWidth = RetType->getBitWidth();
771df49d1bbSPeter Collingbourne   if (BitWidth > 64)
772df49d1bbSPeter Collingbourne     return false;
773df49d1bbSPeter Collingbourne 
77417febdbbSPeter Collingbourne   // Make sure that each function is defined, does not access memory, takes at
77517febdbbSPeter Collingbourne   // least one argument, does not use its first argument (which we assume is
77617febdbbSPeter Collingbourne   // 'this'), and has the same return type.
77737317f12SPeter Collingbourne   //
77837317f12SPeter Collingbourne   // Note that we test whether this copy of the function is readnone, rather
77937317f12SPeter Collingbourne   // than testing function attributes, which must hold for any copy of the
78037317f12SPeter Collingbourne   // function, even a less optimized version substituted at link time. This is
78137317f12SPeter Collingbourne   // sound because the virtual constant propagation optimizations effectively
78237317f12SPeter Collingbourne   // inline all implementations of the virtual function into each call site,
78337317f12SPeter Collingbourne   // rather than using function attributes to perform local optimization.
784df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
78537317f12SPeter Collingbourne     if (Target.Fn->isDeclaration() ||
78637317f12SPeter Collingbourne         computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
78737317f12SPeter Collingbourne             MAK_ReadNone ||
78817febdbbSPeter Collingbourne         Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
789df49d1bbSPeter Collingbourne         Target.Fn->getReturnType() != RetType)
790df49d1bbSPeter Collingbourne       return false;
791df49d1bbSPeter Collingbourne   }
792df49d1bbSPeter Collingbourne 
79350cbd7ccSPeter Collingbourne   for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
794df49d1bbSPeter Collingbourne     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
795df49d1bbSPeter Collingbourne       continue;
796df49d1bbSPeter Collingbourne 
79750cbd7ccSPeter Collingbourne     if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second))
798df49d1bbSPeter Collingbourne       continue;
799df49d1bbSPeter Collingbourne 
800df49d1bbSPeter Collingbourne     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
801df49d1bbSPeter Collingbourne       continue;
802df49d1bbSPeter Collingbourne 
8037efd7506SPeter Collingbourne     // Find an allocation offset in bits in all vtables associated with the
8047efd7506SPeter Collingbourne     // type.
805df49d1bbSPeter Collingbourne     uint64_t AllocBefore =
806df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
807df49d1bbSPeter Collingbourne     uint64_t AllocAfter =
808df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
809df49d1bbSPeter Collingbourne 
810df49d1bbSPeter Collingbourne     // Calculate the total amount of padding needed to store a value at both
811df49d1bbSPeter Collingbourne     // ends of the object.
812df49d1bbSPeter Collingbourne     uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
813df49d1bbSPeter Collingbourne     for (auto &&Target : TargetsForSlot) {
814df49d1bbSPeter Collingbourne       TotalPaddingBefore += std::max<int64_t>(
815df49d1bbSPeter Collingbourne           (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
816df49d1bbSPeter Collingbourne       TotalPaddingAfter += std::max<int64_t>(
817df49d1bbSPeter Collingbourne           (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
818df49d1bbSPeter Collingbourne     }
819df49d1bbSPeter Collingbourne 
820df49d1bbSPeter Collingbourne     // If the amount of padding is too large, give up.
821df49d1bbSPeter Collingbourne     // FIXME: do something smarter here.
822df49d1bbSPeter Collingbourne     if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
823df49d1bbSPeter Collingbourne       continue;
824df49d1bbSPeter Collingbourne 
825df49d1bbSPeter Collingbourne     // Calculate the offset to the value as a (possibly negative) byte offset
826df49d1bbSPeter Collingbourne     // and (if applicable) a bit offset, and store the values in the targets.
827df49d1bbSPeter Collingbourne     int64_t OffsetByte;
828df49d1bbSPeter Collingbourne     uint64_t OffsetBit;
829df49d1bbSPeter Collingbourne     if (TotalPaddingBefore <= TotalPaddingAfter)
830df49d1bbSPeter Collingbourne       setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
831df49d1bbSPeter Collingbourne                             OffsetBit);
832df49d1bbSPeter Collingbourne     else
833df49d1bbSPeter Collingbourne       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
834df49d1bbSPeter Collingbourne                            OffsetBit);
835df49d1bbSPeter Collingbourne 
836f3403fd2SIvan Krasin     if (RemarksEnabled)
837f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
838f3403fd2SIvan Krasin         Target.WasDevirt = true;
839f3403fd2SIvan Krasin 
840df49d1bbSPeter Collingbourne     // Rewrite each call to a load from OffsetByte/OffsetBit.
841184773d8SPeter Collingbourne     Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
84250cbd7ccSPeter Collingbourne     Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
84350cbd7ccSPeter Collingbourne     applyVirtualConstProp(CSByConstantArg.second,
84450cbd7ccSPeter Collingbourne                           TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
845df49d1bbSPeter Collingbourne   }
846df49d1bbSPeter Collingbourne   return true;
847df49d1bbSPeter Collingbourne }
848df49d1bbSPeter Collingbourne 
849df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) {
850df49d1bbSPeter Collingbourne   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
851df49d1bbSPeter Collingbourne     return;
852df49d1bbSPeter Collingbourne 
853df49d1bbSPeter Collingbourne   // Align each byte array to pointer width.
854df49d1bbSPeter Collingbourne   unsigned PointerSize = M.getDataLayout().getPointerSize();
855df49d1bbSPeter Collingbourne   B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
856df49d1bbSPeter Collingbourne   B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));
857df49d1bbSPeter Collingbourne 
858df49d1bbSPeter Collingbourne   // Before was stored in reverse order; flip it now.
859df49d1bbSPeter Collingbourne   for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
860df49d1bbSPeter Collingbourne     std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
861df49d1bbSPeter Collingbourne 
862df49d1bbSPeter Collingbourne   // Build an anonymous global containing the before bytes, followed by the
863df49d1bbSPeter Collingbourne   // original initializer, followed by the after bytes.
864df49d1bbSPeter Collingbourne   auto NewInit = ConstantStruct::getAnon(
865df49d1bbSPeter Collingbourne       {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
866df49d1bbSPeter Collingbourne        B.GV->getInitializer(),
867df49d1bbSPeter Collingbourne        ConstantDataArray::get(M.getContext(), B.After.Bytes)});
868df49d1bbSPeter Collingbourne   auto NewGV =
869df49d1bbSPeter Collingbourne       new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
870df49d1bbSPeter Collingbourne                          GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
871df49d1bbSPeter Collingbourne   NewGV->setSection(B.GV->getSection());
872df49d1bbSPeter Collingbourne   NewGV->setComdat(B.GV->getComdat());
873df49d1bbSPeter Collingbourne 
8740312f614SPeter Collingbourne   // Copy the original vtable's metadata to the anonymous global, adjusting
8750312f614SPeter Collingbourne   // offsets as required.
8760312f614SPeter Collingbourne   NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
8770312f614SPeter Collingbourne 
878df49d1bbSPeter Collingbourne   // Build an alias named after the original global, pointing at the second
879df49d1bbSPeter Collingbourne   // element (the original initializer).
880df49d1bbSPeter Collingbourne   auto Alias = GlobalAlias::create(
881df49d1bbSPeter Collingbourne       B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
882df49d1bbSPeter Collingbourne       ConstantExpr::getGetElementPtr(
883df49d1bbSPeter Collingbourne           NewInit->getType(), NewGV,
884df49d1bbSPeter Collingbourne           ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
885df49d1bbSPeter Collingbourne                                ConstantInt::get(Int32Ty, 1)}),
886df49d1bbSPeter Collingbourne       &M);
887df49d1bbSPeter Collingbourne   Alias->setVisibility(B.GV->getVisibility());
888df49d1bbSPeter Collingbourne   Alias->takeName(B.GV);
889df49d1bbSPeter Collingbourne 
890df49d1bbSPeter Collingbourne   B.GV->replaceAllUsesWith(Alias);
891df49d1bbSPeter Collingbourne   B.GV->eraseFromParent();
892df49d1bbSPeter Collingbourne }
893df49d1bbSPeter Collingbourne 
894f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() {
895f3403fd2SIvan Krasin   const auto &FL = M.getFunctionList();
896f3403fd2SIvan Krasin   if (FL.empty())
897f3403fd2SIvan Krasin     return false;
898f3403fd2SIvan Krasin   const Function &Fn = FL.front();
89904758ba3SAdam Nemet   auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), "");
900f3403fd2SIvan Krasin   return DI.isEnabled();
901f3403fd2SIvan Krasin }
902f3403fd2SIvan Krasin 
9030312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
9040312f614SPeter Collingbourne                                      Function *AssumeFunc) {
905df49d1bbSPeter Collingbourne   // Find all virtual calls via a virtual table pointer %p under an assumption
9067efd7506SPeter Collingbourne   // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
9077efd7506SPeter Collingbourne   // points to a member of the type identifier %md. Group calls by (type ID,
9087efd7506SPeter Collingbourne   // offset) pair (effectively the identity of the virtual function) and store
9097efd7506SPeter Collingbourne   // to CallSlots.
910df49d1bbSPeter Collingbourne   DenseSet<Value *> SeenPtrs;
9117efd7506SPeter Collingbourne   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
912df49d1bbSPeter Collingbourne        I != E;) {
913df49d1bbSPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
914df49d1bbSPeter Collingbourne     ++I;
915df49d1bbSPeter Collingbourne     if (!CI)
916df49d1bbSPeter Collingbourne       continue;
917df49d1bbSPeter Collingbourne 
918ccdc225cSPeter Collingbourne     // Search for virtual calls based on %p and add them to DevirtCalls.
919ccdc225cSPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
920df49d1bbSPeter Collingbourne     SmallVector<CallInst *, 1> Assumes;
9210312f614SPeter Collingbourne     findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
922df49d1bbSPeter Collingbourne 
923ccdc225cSPeter Collingbourne     // If we found any, add them to CallSlots. Only do this if we haven't seen
924ccdc225cSPeter Collingbourne     // the vtable pointer before, as it may have been CSE'd with pointers from
925ccdc225cSPeter Collingbourne     // other call sites, and we don't want to process call sites multiple times.
926df49d1bbSPeter Collingbourne     if (!Assumes.empty()) {
9277efd7506SPeter Collingbourne       Metadata *TypeId =
928df49d1bbSPeter Collingbourne           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
929df49d1bbSPeter Collingbourne       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
930ccdc225cSPeter Collingbourne       if (SeenPtrs.insert(Ptr).second) {
931ccdc225cSPeter Collingbourne         for (DevirtCallSite Call : DevirtCalls) {
93250cbd7ccSPeter Collingbourne           CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
93350cbd7ccSPeter Collingbourne                                                        Call.CS, nullptr);
934ccdc225cSPeter Collingbourne         }
935ccdc225cSPeter Collingbourne       }
936df49d1bbSPeter Collingbourne     }
937df49d1bbSPeter Collingbourne 
9387efd7506SPeter Collingbourne     // We no longer need the assumes or the type test.
939df49d1bbSPeter Collingbourne     for (auto Assume : Assumes)
940df49d1bbSPeter Collingbourne       Assume->eraseFromParent();
941df49d1bbSPeter Collingbourne     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
942df49d1bbSPeter Collingbourne     // may use the vtable argument later.
943df49d1bbSPeter Collingbourne     if (CI->use_empty())
944df49d1bbSPeter Collingbourne       CI->eraseFromParent();
945df49d1bbSPeter Collingbourne   }
9460312f614SPeter Collingbourne }
9470312f614SPeter Collingbourne 
9480312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
9490312f614SPeter Collingbourne   Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
9500312f614SPeter Collingbourne 
9510312f614SPeter Collingbourne   for (auto I = TypeCheckedLoadFunc->use_begin(),
9520312f614SPeter Collingbourne             E = TypeCheckedLoadFunc->use_end();
9530312f614SPeter Collingbourne        I != E;) {
9540312f614SPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
9550312f614SPeter Collingbourne     ++I;
9560312f614SPeter Collingbourne     if (!CI)
9570312f614SPeter Collingbourne       continue;
9580312f614SPeter Collingbourne 
9590312f614SPeter Collingbourne     Value *Ptr = CI->getArgOperand(0);
9600312f614SPeter Collingbourne     Value *Offset = CI->getArgOperand(1);
9610312f614SPeter Collingbourne     Value *TypeIdValue = CI->getArgOperand(2);
9620312f614SPeter Collingbourne     Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
9630312f614SPeter Collingbourne 
9640312f614SPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
9650312f614SPeter Collingbourne     SmallVector<Instruction *, 1> LoadedPtrs;
9660312f614SPeter Collingbourne     SmallVector<Instruction *, 1> Preds;
9670312f614SPeter Collingbourne     bool HasNonCallUses = false;
9680312f614SPeter Collingbourne     findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
9690312f614SPeter Collingbourne                                                HasNonCallUses, CI);
9700312f614SPeter Collingbourne 
9710312f614SPeter Collingbourne     // Start by generating "pessimistic" code that explicitly loads the function
9720312f614SPeter Collingbourne     // pointer from the vtable and performs the type check. If possible, we will
9730312f614SPeter Collingbourne     // eliminate the load and the type check later.
9740312f614SPeter Collingbourne 
9750312f614SPeter Collingbourne     // If possible, only generate the load at the point where it is used.
9760312f614SPeter Collingbourne     // This helps avoid unnecessary spills.
9770312f614SPeter Collingbourne     IRBuilder<> LoadB(
9780312f614SPeter Collingbourne         (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
9790312f614SPeter Collingbourne     Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
9800312f614SPeter Collingbourne     Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
9810312f614SPeter Collingbourne     Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
9820312f614SPeter Collingbourne 
9830312f614SPeter Collingbourne     for (Instruction *LoadedPtr : LoadedPtrs) {
9840312f614SPeter Collingbourne       LoadedPtr->replaceAllUsesWith(LoadedValue);
9850312f614SPeter Collingbourne       LoadedPtr->eraseFromParent();
9860312f614SPeter Collingbourne     }
9870312f614SPeter Collingbourne 
9880312f614SPeter Collingbourne     // Likewise for the type test.
9890312f614SPeter Collingbourne     IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
9900312f614SPeter Collingbourne     CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
9910312f614SPeter Collingbourne 
9920312f614SPeter Collingbourne     for (Instruction *Pred : Preds) {
9930312f614SPeter Collingbourne       Pred->replaceAllUsesWith(TypeTestCall);
9940312f614SPeter Collingbourne       Pred->eraseFromParent();
9950312f614SPeter Collingbourne     }
9960312f614SPeter Collingbourne 
9970312f614SPeter Collingbourne     // We have already erased any extractvalue instructions that refer to the
9980312f614SPeter Collingbourne     // intrinsic call, but the intrinsic may have other non-extractvalue uses
9990312f614SPeter Collingbourne     // (although this is unlikely). In that case, explicitly build a pair and
10000312f614SPeter Collingbourne     // RAUW it.
10010312f614SPeter Collingbourne     if (!CI->use_empty()) {
10020312f614SPeter Collingbourne       Value *Pair = UndefValue::get(CI->getType());
10030312f614SPeter Collingbourne       IRBuilder<> B(CI);
10040312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
10050312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
10060312f614SPeter Collingbourne       CI->replaceAllUsesWith(Pair);
10070312f614SPeter Collingbourne     }
10080312f614SPeter Collingbourne 
10090312f614SPeter Collingbourne     // The number of unsafe uses is initially the number of uses.
10100312f614SPeter Collingbourne     auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
10110312f614SPeter Collingbourne     NumUnsafeUses = DevirtCalls.size();
10120312f614SPeter Collingbourne 
10130312f614SPeter Collingbourne     // If the function pointer has a non-call user, we cannot eliminate the type
10140312f614SPeter Collingbourne     // check, as one of those users may eventually call the pointer. Increment
10150312f614SPeter Collingbourne     // the unsafe use count to make sure it cannot reach zero.
10160312f614SPeter Collingbourne     if (HasNonCallUses)
10170312f614SPeter Collingbourne       ++NumUnsafeUses;
10180312f614SPeter Collingbourne     for (DevirtCallSite Call : DevirtCalls) {
101950cbd7ccSPeter Collingbourne       CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
102050cbd7ccSPeter Collingbourne                                                    &NumUnsafeUses);
10210312f614SPeter Collingbourne     }
10220312f614SPeter Collingbourne 
10230312f614SPeter Collingbourne     CI->eraseFromParent();
10240312f614SPeter Collingbourne   }
10250312f614SPeter Collingbourne }
10260312f614SPeter Collingbourne 
10270312f614SPeter Collingbourne bool DevirtModule::run() {
10280312f614SPeter Collingbourne   Function *TypeTestFunc =
10290312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
10300312f614SPeter Collingbourne   Function *TypeCheckedLoadFunc =
10310312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
10320312f614SPeter Collingbourne   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
10330312f614SPeter Collingbourne 
10340312f614SPeter Collingbourne   if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
10350312f614SPeter Collingbourne        AssumeFunc->use_empty()) &&
10360312f614SPeter Collingbourne       (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
10370312f614SPeter Collingbourne     return false;
10380312f614SPeter Collingbourne 
10390312f614SPeter Collingbourne   if (TypeTestFunc && AssumeFunc)
10400312f614SPeter Collingbourne     scanTypeTestUsers(TypeTestFunc, AssumeFunc);
10410312f614SPeter Collingbourne 
10420312f614SPeter Collingbourne   if (TypeCheckedLoadFunc)
10430312f614SPeter Collingbourne     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
1044df49d1bbSPeter Collingbourne 
10457efd7506SPeter Collingbourne   // Rebuild type metadata into a map for easy lookup.
1046df49d1bbSPeter Collingbourne   std::vector<VTableBits> Bits;
10477efd7506SPeter Collingbourne   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
10487efd7506SPeter Collingbourne   buildTypeIdentifierMap(Bits, TypeIdMap);
10497efd7506SPeter Collingbourne   if (TypeIdMap.empty())
1050df49d1bbSPeter Collingbourne     return true;
1051df49d1bbSPeter Collingbourne 
10527efd7506SPeter Collingbourne   // For each (type, offset) pair:
1053df49d1bbSPeter Collingbourne   bool DidVirtualConstProp = false;
1054f3403fd2SIvan Krasin   std::map<std::string, Function*> DevirtTargets;
1055df49d1bbSPeter Collingbourne   for (auto &S : CallSlots) {
10567efd7506SPeter Collingbourne     // Search each of the members of the type identifier for the virtual
10577efd7506SPeter Collingbourne     // function implementation at offset S.first.ByteOffset, and add to
10587efd7506SPeter Collingbourne     // TargetsForSlot.
1059df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> TargetsForSlot;
10607efd7506SPeter Collingbourne     if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
1061df49d1bbSPeter Collingbourne                                    S.first.ByteOffset))
1062df49d1bbSPeter Collingbourne       continue;
1063df49d1bbSPeter Collingbourne 
1064f3403fd2SIvan Krasin     if (!trySingleImplDevirt(TargetsForSlot, S.second) &&
1065f3403fd2SIvan Krasin         tryVirtualConstProp(TargetsForSlot, S.second))
1066f3403fd2SIvan Krasin         DidVirtualConstProp = true;
1067f3403fd2SIvan Krasin 
1068f3403fd2SIvan Krasin     // Collect functions devirtualized at least for one call site for stats.
1069f3403fd2SIvan Krasin     if (RemarksEnabled)
1070f3403fd2SIvan Krasin       for (const auto &T : TargetsForSlot)
1071f3403fd2SIvan Krasin         if (T.WasDevirt)
1072f3403fd2SIvan Krasin           DevirtTargets[T.Fn->getName()] = T.Fn;
1073b05e06e4SIvan Krasin   }
1074df49d1bbSPeter Collingbourne 
1075f3403fd2SIvan Krasin   if (RemarksEnabled) {
1076f3403fd2SIvan Krasin     // Generate remarks for each devirtualized function.
1077f3403fd2SIvan Krasin     for (const auto &DT : DevirtTargets) {
1078f3403fd2SIvan Krasin       Function *F = DT.second;
1079f3403fd2SIvan Krasin       DISubprogram *SP = F->getSubprogram();
1080*7bc978b5SJustin Bogner       emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,
1081f3403fd2SIvan Krasin                              Twine("devirtualized ") + F->getName());
1082b05e06e4SIvan Krasin     }
1083df49d1bbSPeter Collingbourne   }
1084df49d1bbSPeter Collingbourne 
10850312f614SPeter Collingbourne   // If we were able to eliminate all unsafe uses for a type checked load,
10860312f614SPeter Collingbourne   // eliminate the type test by replacing it with true.
10870312f614SPeter Collingbourne   if (TypeCheckedLoadFunc) {
10880312f614SPeter Collingbourne     auto True = ConstantInt::getTrue(M.getContext());
10890312f614SPeter Collingbourne     for (auto &&U : NumUnsafeUsesForTypeTest) {
10900312f614SPeter Collingbourne       if (U.second == 0) {
10910312f614SPeter Collingbourne         U.first->replaceAllUsesWith(True);
10920312f614SPeter Collingbourne         U.first->eraseFromParent();
10930312f614SPeter Collingbourne       }
10940312f614SPeter Collingbourne     }
10950312f614SPeter Collingbourne   }
10960312f614SPeter Collingbourne 
1097df49d1bbSPeter Collingbourne   // Rebuild each global we touched as part of virtual constant propagation to
1098df49d1bbSPeter Collingbourne   // include the before and after bytes.
1099df49d1bbSPeter Collingbourne   if (DidVirtualConstProp)
1100df49d1bbSPeter Collingbourne     for (VTableBits &B : Bits)
1101df49d1bbSPeter Collingbourne       rebuildGlobal(B);
1102df49d1bbSPeter Collingbourne 
1103df49d1bbSPeter Collingbourne   return true;
1104df49d1bbSPeter Collingbourne }
1105