1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2df49d1bbSPeter Collingbourne //
3df49d1bbSPeter Collingbourne //                     The LLVM Compiler Infrastructure
4df49d1bbSPeter Collingbourne //
5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source
6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details.
7df49d1bbSPeter Collingbourne //
8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
9df49d1bbSPeter Collingbourne //
10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases
117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This
12df49d1bbSPeter Collingbourne // includes the following:
13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single
14df49d1bbSPeter Collingbourne //   possible callee, replace all calls with a direct call to that callee.
15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an
16df49d1bbSPeter Collingbourne //   integer <=64 bits and all possible callees are readnone, for each class and
17df49d1bbSPeter Collingbourne //   each list of constant arguments: evaluate the function, store the return
18df49d1bbSPeter Collingbourne //   value alongside the virtual table, and rewrite each virtual call as a load
19df49d1bbSPeter Collingbourne //   from the virtual table.
20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant
21df49d1bbSPeter Collingbourne //   propagation hold and each function returns the same constant value, replace
22df49d1bbSPeter Collingbourne //   each virtual call with that constant.
23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions
24df49d1bbSPeter Collingbourne //   for virtual constant propagation hold and a single vtable's function
25df49d1bbSPeter Collingbourne //   returns 0, or a single vtable's function returns 1, replace each virtual
26df49d1bbSPeter Collingbourne //   call with a comparison of the vptr against that vtable's address.
27df49d1bbSPeter Collingbourne //
28b406baaeSPeter Collingbourne // This pass is intended to be used during the regular and thin LTO pipelines.
29b406baaeSPeter Collingbourne // During regular LTO, the pass determines the best optimization for each
30b406baaeSPeter Collingbourne // virtual call and applies the resolutions directly to virtual calls that are
31b406baaeSPeter Collingbourne // eligible for virtual call optimization (i.e. calls that use either of the
32b406baaeSPeter Collingbourne // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During
33b406baaeSPeter Collingbourne // ThinLTO, the pass operates in two phases:
34b406baaeSPeter Collingbourne // - Export phase: this is run during the thin link over a single merged module
35b406baaeSPeter Collingbourne //   that contains all vtables with !type metadata that participate in the link.
36b406baaeSPeter Collingbourne //   The pass computes a resolution for each virtual call and stores it in the
37b406baaeSPeter Collingbourne //   type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
41b406baaeSPeter Collingbourne //
42df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
43df49d1bbSPeter Collingbourne 
44df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
45b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h"
46cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h"
47cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h"
48df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h"
49cdc71612SEugene Zelenko #include "llvm/ADT/iterator_range.h"
50df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h"
51cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h"
5237317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h"
5337317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h"
547efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
55df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h"
56df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h"
57df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h"
58b05e06e4SIvan Krasin #include "llvm/IR/DebugInfoMetadata.h"
59cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h"
60cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h"
615474645dSIvan Krasin #include "llvm/IR/DiagnosticInfo.h"
62cdc71612SEugene Zelenko #include "llvm/IR/Function.h"
63cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h"
64cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h"
65df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h"
66cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h"
67cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h"
68df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h"
69df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h"
70cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h"
71cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h"
72df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h"
732b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h"
74df49d1bbSPeter Collingbourne #include "llvm/Pass.h"
75cdc71612SEugene Zelenko #include "llvm/PassRegistry.h"
76cdc71612SEugene Zelenko #include "llvm/PassSupport.h"
77cdc71612SEugene Zelenko #include "llvm/Support/Casting.h"
782b33f653SPeter Collingbourne #include "llvm/Support/Error.h"
792b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h"
80cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h"
81b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h"
8237317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h"
83df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h"
84cdc71612SEugene Zelenko #include <algorithm>
85cdc71612SEugene Zelenko #include <cstddef>
86cdc71612SEugene Zelenko #include <map>
87df49d1bbSPeter Collingbourne #include <set>
88cdc71612SEugene Zelenko #include <string>
89df49d1bbSPeter Collingbourne 
90df49d1bbSPeter Collingbourne using namespace llvm;
91df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt;
92df49d1bbSPeter Collingbourne 
93df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt"
94df49d1bbSPeter Collingbourne 
952b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction(
962b33f653SPeter Collingbourne     "wholeprogramdevirt-summary-action",
972b33f653SPeter Collingbourne     cl::desc("What to do with the summary when running this pass"),
982b33f653SPeter Collingbourne     cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
992b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Import, "import",
1002b33f653SPeter Collingbourne                           "Import typeid resolutions from summary and globals"),
1012b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Export, "export",
1022b33f653SPeter Collingbourne                           "Export typeid resolutions to summary and globals")),
1032b33f653SPeter Collingbourne     cl::Hidden);
1042b33f653SPeter Collingbourne 
1052b33f653SPeter Collingbourne static cl::opt<std::string> ClReadSummary(
1062b33f653SPeter Collingbourne     "wholeprogramdevirt-read-summary",
1072b33f653SPeter Collingbourne     cl::desc("Read summary from given YAML file before running pass"),
1082b33f653SPeter Collingbourne     cl::Hidden);
1092b33f653SPeter Collingbourne 
1102b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary(
1112b33f653SPeter Collingbourne     "wholeprogramdevirt-write-summary",
1122b33f653SPeter Collingbourne     cl::desc("Write summary to given YAML file after running pass"),
1132b33f653SPeter Collingbourne     cl::Hidden);
1142b33f653SPeter Collingbourne 
115df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If
116df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an
117df49d1bbSPeter Collingbourne // offset after the object.
118df49d1bbSPeter Collingbourne uint64_t
119df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
120df49d1bbSPeter Collingbourne                                      bool IsAfter, uint64_t Size) {
121df49d1bbSPeter Collingbourne   // Find a minimum offset taking into account only vtable sizes.
122df49d1bbSPeter Collingbourne   uint64_t MinByte = 0;
123df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
124df49d1bbSPeter Collingbourne     if (IsAfter)
125df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minAfterBytes());
126df49d1bbSPeter Collingbourne     else
127df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minBeforeBytes());
128df49d1bbSPeter Collingbourne   }
129df49d1bbSPeter Collingbourne 
130df49d1bbSPeter Collingbourne   // Build a vector of arrays of bytes covering, for each target, a slice of the
131df49d1bbSPeter Collingbourne   // used region (see AccumBitVector::BytesUsed in
132df49d1bbSPeter Collingbourne   // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
133df49d1bbSPeter Collingbourne   // this aligns the used regions to start at MinByte.
134df49d1bbSPeter Collingbourne   //
135df49d1bbSPeter Collingbourne   // In this example, A, B and C are vtables, # is a byte already allocated for
136df49d1bbSPeter Collingbourne   // a virtual function pointer, AAAA... (etc.) are the used regions for the
137df49d1bbSPeter Collingbourne   // vtables and Offset(X) is the value computed for the Offset variable below
138df49d1bbSPeter Collingbourne   // for X.
139df49d1bbSPeter Collingbourne   //
140df49d1bbSPeter Collingbourne   //                    Offset(A)
141df49d1bbSPeter Collingbourne   //                    |       |
142df49d1bbSPeter Collingbourne   //                            |MinByte
143df49d1bbSPeter Collingbourne   // A: ################AAAAAAAA|AAAAAAAA
144df49d1bbSPeter Collingbourne   // B: ########BBBBBBBBBBBBBBBB|BBBB
145df49d1bbSPeter Collingbourne   // C: ########################|CCCCCCCCCCCCCCCC
146df49d1bbSPeter Collingbourne   //            |   Offset(B)   |
147df49d1bbSPeter Collingbourne   //
148df49d1bbSPeter Collingbourne   // This code produces the slices of A, B and C that appear after the divider
149df49d1bbSPeter Collingbourne   // at MinByte.
150df49d1bbSPeter Collingbourne   std::vector<ArrayRef<uint8_t>> Used;
151df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
1527efd7506SPeter Collingbourne     ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
1537efd7506SPeter Collingbourne                                        : Target.TM->Bits->Before.BytesUsed;
154df49d1bbSPeter Collingbourne     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
155df49d1bbSPeter Collingbourne                               : MinByte - Target.minBeforeBytes();
156df49d1bbSPeter Collingbourne 
157df49d1bbSPeter Collingbourne     // Disregard used regions that are smaller than Offset. These are
158df49d1bbSPeter Collingbourne     // effectively all-free regions that do not need to be checked.
159df49d1bbSPeter Collingbourne     if (VTUsed.size() > Offset)
160df49d1bbSPeter Collingbourne       Used.push_back(VTUsed.slice(Offset));
161df49d1bbSPeter Collingbourne   }
162df49d1bbSPeter Collingbourne 
163df49d1bbSPeter Collingbourne   if (Size == 1) {
164df49d1bbSPeter Collingbourne     // Find a free bit in each member of Used.
165df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
166df49d1bbSPeter Collingbourne       uint8_t BitsUsed = 0;
167df49d1bbSPeter Collingbourne       for (auto &&B : Used)
168df49d1bbSPeter Collingbourne         if (I < B.size())
169df49d1bbSPeter Collingbourne           BitsUsed |= B[I];
170df49d1bbSPeter Collingbourne       if (BitsUsed != 0xff)
171df49d1bbSPeter Collingbourne         return (MinByte + I) * 8 +
172df49d1bbSPeter Collingbourne                countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
173df49d1bbSPeter Collingbourne     }
174df49d1bbSPeter Collingbourne   } else {
175df49d1bbSPeter Collingbourne     // Find a free (Size/8) byte region in each member of Used.
176df49d1bbSPeter Collingbourne     // FIXME: see if alignment helps.
177df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
178df49d1bbSPeter Collingbourne       for (auto &&B : Used) {
179df49d1bbSPeter Collingbourne         unsigned Byte = 0;
180df49d1bbSPeter Collingbourne         while ((I + Byte) < B.size() && Byte < (Size / 8)) {
181df49d1bbSPeter Collingbourne           if (B[I + Byte])
182df49d1bbSPeter Collingbourne             goto NextI;
183df49d1bbSPeter Collingbourne           ++Byte;
184df49d1bbSPeter Collingbourne         }
185df49d1bbSPeter Collingbourne       }
186df49d1bbSPeter Collingbourne       return (MinByte + I) * 8;
187df49d1bbSPeter Collingbourne     NextI:;
188df49d1bbSPeter Collingbourne     }
189df49d1bbSPeter Collingbourne   }
190df49d1bbSPeter Collingbourne }
191df49d1bbSPeter Collingbourne 
192df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues(
193df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
194df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
195df49d1bbSPeter Collingbourne   if (BitWidth == 1)
196df49d1bbSPeter Collingbourne     OffsetByte = -(AllocBefore / 8 + 1);
197df49d1bbSPeter Collingbourne   else
198df49d1bbSPeter Collingbourne     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
199df49d1bbSPeter Collingbourne   OffsetBit = AllocBefore % 8;
200df49d1bbSPeter Collingbourne 
201df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
202df49d1bbSPeter Collingbourne     if (BitWidth == 1)
203df49d1bbSPeter Collingbourne       Target.setBeforeBit(AllocBefore);
204df49d1bbSPeter Collingbourne     else
205df49d1bbSPeter Collingbourne       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
206df49d1bbSPeter Collingbourne   }
207df49d1bbSPeter Collingbourne }
208df49d1bbSPeter Collingbourne 
209df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues(
210df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
211df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
212df49d1bbSPeter Collingbourne   if (BitWidth == 1)
213df49d1bbSPeter Collingbourne     OffsetByte = AllocAfter / 8;
214df49d1bbSPeter Collingbourne   else
215df49d1bbSPeter Collingbourne     OffsetByte = (AllocAfter + 7) / 8;
216df49d1bbSPeter Collingbourne   OffsetBit = AllocAfter % 8;
217df49d1bbSPeter Collingbourne 
218df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
219df49d1bbSPeter Collingbourne     if (BitWidth == 1)
220df49d1bbSPeter Collingbourne       Target.setAfterBit(AllocAfter);
221df49d1bbSPeter Collingbourne     else
222df49d1bbSPeter Collingbourne       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
223df49d1bbSPeter Collingbourne   }
224df49d1bbSPeter Collingbourne }
225df49d1bbSPeter Collingbourne 
2267efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
2277efd7506SPeter Collingbourne     : Fn(Fn), TM(TM),
22889439a79SIvan Krasin       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
229df49d1bbSPeter Collingbourne 
230df49d1bbSPeter Collingbourne namespace {
231df49d1bbSPeter Collingbourne 
2327efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual
233df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to
234df49d1bbSPeter Collingbourne // the virtual function pointer.
235df49d1bbSPeter Collingbourne struct VTableSlot {
2367efd7506SPeter Collingbourne   Metadata *TypeID;
237df49d1bbSPeter Collingbourne   uint64_t ByteOffset;
238df49d1bbSPeter Collingbourne };
239df49d1bbSPeter Collingbourne 
240cdc71612SEugene Zelenko } // end anonymous namespace
241df49d1bbSPeter Collingbourne 
2429b656527SPeter Collingbourne namespace llvm {
2439b656527SPeter Collingbourne 
244df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> {
245df49d1bbSPeter Collingbourne   static VTableSlot getEmptyKey() {
246df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getEmptyKey(),
247df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getEmptyKey()};
248df49d1bbSPeter Collingbourne   }
249df49d1bbSPeter Collingbourne   static VTableSlot getTombstoneKey() {
250df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
251df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getTombstoneKey()};
252df49d1bbSPeter Collingbourne   }
253df49d1bbSPeter Collingbourne   static unsigned getHashValue(const VTableSlot &I) {
2547efd7506SPeter Collingbourne     return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
255df49d1bbSPeter Collingbourne            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
256df49d1bbSPeter Collingbourne   }
257df49d1bbSPeter Collingbourne   static bool isEqual(const VTableSlot &LHS,
258df49d1bbSPeter Collingbourne                       const VTableSlot &RHS) {
2597efd7506SPeter Collingbourne     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
260df49d1bbSPeter Collingbourne   }
261df49d1bbSPeter Collingbourne };
262df49d1bbSPeter Collingbourne 
263cdc71612SEugene Zelenko } // end namespace llvm
2649b656527SPeter Collingbourne 
265df49d1bbSPeter Collingbourne namespace {
266df49d1bbSPeter Collingbourne 
267df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is
268df49d1bbSPeter Collingbourne // the indirect virtual call.
269df49d1bbSPeter Collingbourne struct VirtualCallSite {
270df49d1bbSPeter Collingbourne   Value *VTable;
271df49d1bbSPeter Collingbourne   CallSite CS;
272df49d1bbSPeter Collingbourne 
2730312f614SPeter Collingbourne   // If non-null, this field points to the associated unsafe use count stored in
2740312f614SPeter Collingbourne   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
2750312f614SPeter Collingbourne   // of that field for details.
2760312f614SPeter Collingbourne   unsigned *NumUnsafeUses;
2770312f614SPeter Collingbourne 
278f3403fd2SIvan Krasin   void emitRemark(const Twine &OptName, const Twine &TargetName) {
2795474645dSIvan Krasin     Function *F = CS.getCaller();
280f3403fd2SIvan Krasin     emitOptimizationRemark(
281f3403fd2SIvan Krasin         F->getContext(), DEBUG_TYPE, *F,
2825474645dSIvan Krasin         CS.getInstruction()->getDebugLoc(),
283f3403fd2SIvan Krasin         OptName + ": devirtualized a call to " + TargetName);
2845474645dSIvan Krasin   }
2855474645dSIvan Krasin 
286f3403fd2SIvan Krasin   void replaceAndErase(const Twine &OptName, const Twine &TargetName,
287f3403fd2SIvan Krasin                        bool RemarksEnabled, Value *New) {
288f3403fd2SIvan Krasin     if (RemarksEnabled)
289f3403fd2SIvan Krasin       emitRemark(OptName, TargetName);
290df49d1bbSPeter Collingbourne     CS->replaceAllUsesWith(New);
291df49d1bbSPeter Collingbourne     if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
292df49d1bbSPeter Collingbourne       BranchInst::Create(II->getNormalDest(), CS.getInstruction());
293df49d1bbSPeter Collingbourne       II->getUnwindDest()->removePredecessor(II->getParent());
294df49d1bbSPeter Collingbourne     }
295df49d1bbSPeter Collingbourne     CS->eraseFromParent();
2960312f614SPeter Collingbourne     // This use is no longer unsafe.
2970312f614SPeter Collingbourne     if (NumUnsafeUses)
2980312f614SPeter Collingbourne       --*NumUnsafeUses;
299df49d1bbSPeter Collingbourne   }
300df49d1bbSPeter Collingbourne };
301df49d1bbSPeter Collingbourne 
30250cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list
30350cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the
30450cbd7ccSPeter Collingbourne // VTableSlotInfo class.
30550cbd7ccSPeter Collingbourne struct CallSiteInfo {
306b406baaeSPeter Collingbourne   /// The set of call sites for this slot. Used during regular LTO and the
307b406baaeSPeter Collingbourne   /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
308b406baaeSPeter Collingbourne   /// call sites that appear in the merged module itself); in each of these
309b406baaeSPeter Collingbourne   /// cases we are directly operating on the call sites at the IR level.
31050cbd7ccSPeter Collingbourne   std::vector<VirtualCallSite> CallSites;
311b406baaeSPeter Collingbourne 
312b406baaeSPeter Collingbourne   // These fields are used during the export phase of ThinLTO and reflect
313b406baaeSPeter Collingbourne   // information collected from function summaries.
314b406baaeSPeter Collingbourne 
3152325bb34SPeter Collingbourne   /// Whether any function summary contains an llvm.assume(llvm.type.test) for
3162325bb34SPeter Collingbourne   /// this slot.
3172325bb34SPeter Collingbourne   bool SummaryHasTypeTestAssumeUsers;
3182325bb34SPeter Collingbourne 
319b406baaeSPeter Collingbourne   /// CFI-specific: a vector containing the list of function summaries that use
320b406baaeSPeter Collingbourne   /// the llvm.type.checked.load intrinsic and therefore will require
321b406baaeSPeter Collingbourne   /// resolutions for llvm.type.test in order to implement CFI checks if
322b406baaeSPeter Collingbourne   /// devirtualization was unsuccessful. If devirtualization was successful, the
32359675ba0SPeter Collingbourne   /// pass will clear this vector by calling markDevirt(). If at the end of the
32459675ba0SPeter Collingbourne   /// pass the vector is non-empty, we will need to add a use of llvm.type.test
32559675ba0SPeter Collingbourne   /// to each of the function summaries in the vector.
326b406baaeSPeter Collingbourne   std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
3272325bb34SPeter Collingbourne 
3282325bb34SPeter Collingbourne   bool isExported() const {
3292325bb34SPeter Collingbourne     return SummaryHasTypeTestAssumeUsers ||
3302325bb34SPeter Collingbourne            !SummaryTypeCheckedLoadUsers.empty();
3312325bb34SPeter Collingbourne   }
33259675ba0SPeter Collingbourne 
33359675ba0SPeter Collingbourne   /// As explained in the comment for SummaryTypeCheckedLoadUsers.
33459675ba0SPeter Collingbourne   void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
33550cbd7ccSPeter Collingbourne };
33650cbd7ccSPeter Collingbourne 
33750cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot.
33850cbd7ccSPeter Collingbourne struct VTableSlotInfo {
33950cbd7ccSPeter Collingbourne   // The set of call sites which do not have all constant integer arguments
34050cbd7ccSPeter Collingbourne   // (excluding "this").
34150cbd7ccSPeter Collingbourne   CallSiteInfo CSInfo;
34250cbd7ccSPeter Collingbourne 
34350cbd7ccSPeter Collingbourne   // The set of call sites with all constant integer arguments (excluding
34450cbd7ccSPeter Collingbourne   // "this"), grouped by argument list.
34550cbd7ccSPeter Collingbourne   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
34650cbd7ccSPeter Collingbourne 
34750cbd7ccSPeter Collingbourne   void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
34850cbd7ccSPeter Collingbourne 
34950cbd7ccSPeter Collingbourne private:
35050cbd7ccSPeter Collingbourne   CallSiteInfo &findCallSiteInfo(CallSite CS);
35150cbd7ccSPeter Collingbourne };
35250cbd7ccSPeter Collingbourne 
35350cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
35450cbd7ccSPeter Collingbourne   std::vector<uint64_t> Args;
35550cbd7ccSPeter Collingbourne   auto *CI = dyn_cast<IntegerType>(CS.getType());
35650cbd7ccSPeter Collingbourne   if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
35750cbd7ccSPeter Collingbourne     return CSInfo;
35850cbd7ccSPeter Collingbourne   for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
35950cbd7ccSPeter Collingbourne     auto *CI = dyn_cast<ConstantInt>(Arg);
36050cbd7ccSPeter Collingbourne     if (!CI || CI->getBitWidth() > 64)
36150cbd7ccSPeter Collingbourne       return CSInfo;
36250cbd7ccSPeter Collingbourne     Args.push_back(CI->getZExtValue());
36350cbd7ccSPeter Collingbourne   }
36450cbd7ccSPeter Collingbourne   return ConstCSInfo[Args];
36550cbd7ccSPeter Collingbourne }
36650cbd7ccSPeter Collingbourne 
36750cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
36850cbd7ccSPeter Collingbourne                                  unsigned *NumUnsafeUses) {
36950cbd7ccSPeter Collingbourne   findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
37050cbd7ccSPeter Collingbourne }
37150cbd7ccSPeter Collingbourne 
372df49d1bbSPeter Collingbourne struct DevirtModule {
373df49d1bbSPeter Collingbourne   Module &M;
37437317f12SPeter Collingbourne   function_ref<AAResults &(Function &)> AARGetter;
3752b33f653SPeter Collingbourne 
376f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
377f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
3782b33f653SPeter Collingbourne 
379df49d1bbSPeter Collingbourne   IntegerType *Int8Ty;
380df49d1bbSPeter Collingbourne   PointerType *Int8PtrTy;
381df49d1bbSPeter Collingbourne   IntegerType *Int32Ty;
38250cbd7ccSPeter Collingbourne   IntegerType *Int64Ty;
38314dcf02fSPeter Collingbourne   IntegerType *IntPtrTy;
384df49d1bbSPeter Collingbourne 
385f3403fd2SIvan Krasin   bool RemarksEnabled;
386f3403fd2SIvan Krasin 
38750cbd7ccSPeter Collingbourne   MapVector<VTableSlot, VTableSlotInfo> CallSlots;
388df49d1bbSPeter Collingbourne 
3890312f614SPeter Collingbourne   // This map keeps track of the number of "unsafe" uses of a loaded function
3900312f614SPeter Collingbourne   // pointer. The key is the associated llvm.type.test intrinsic call generated
3910312f614SPeter Collingbourne   // by this pass. An unsafe use is one that calls the loaded function pointer
3920312f614SPeter Collingbourne   // directly. Every time we eliminate an unsafe use (for example, by
3930312f614SPeter Collingbourne   // devirtualizing it or by applying virtual constant propagation), we
3940312f614SPeter Collingbourne   // decrement the value stored in this map. If a value reaches zero, we can
3950312f614SPeter Collingbourne   // eliminate the type check by RAUWing the associated llvm.type.test call with
3960312f614SPeter Collingbourne   // true.
3970312f614SPeter Collingbourne   std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
3980312f614SPeter Collingbourne 
39937317f12SPeter Collingbourne   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
400f7691d8bSPeter Collingbourne                ModuleSummaryIndex *ExportSummary,
401f7691d8bSPeter Collingbourne                const ModuleSummaryIndex *ImportSummary)
402f7691d8bSPeter Collingbourne       : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
403f7691d8bSPeter Collingbourne         ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
404df49d1bbSPeter Collingbourne         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
405f3403fd2SIvan Krasin         Int32Ty(Type::getInt32Ty(M.getContext())),
40650cbd7ccSPeter Collingbourne         Int64Ty(Type::getInt64Ty(M.getContext())),
40714dcf02fSPeter Collingbourne         IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
408f7691d8bSPeter Collingbourne         RemarksEnabled(areRemarksEnabled()) {
409f7691d8bSPeter Collingbourne     assert(!(ExportSummary && ImportSummary));
410f7691d8bSPeter Collingbourne   }
411f3403fd2SIvan Krasin 
412f3403fd2SIvan Krasin   bool areRemarksEnabled();
413df49d1bbSPeter Collingbourne 
4140312f614SPeter Collingbourne   void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
4150312f614SPeter Collingbourne   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
4160312f614SPeter Collingbourne 
4177efd7506SPeter Collingbourne   void buildTypeIdentifierMap(
4187efd7506SPeter Collingbourne       std::vector<VTableBits> &Bits,
4197efd7506SPeter Collingbourne       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
4208786754cSPeter Collingbourne   Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
4217efd7506SPeter Collingbourne   bool
4227efd7506SPeter Collingbourne   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
4237efd7506SPeter Collingbourne                             const std::set<TypeMemberInfo> &TypeMemberInfos,
424df49d1bbSPeter Collingbourne                             uint64_t ByteOffset);
42550cbd7ccSPeter Collingbourne 
4262325bb34SPeter Collingbourne   void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
4272325bb34SPeter Collingbourne                              bool &IsExported);
428f3403fd2SIvan Krasin   bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
4292325bb34SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
4302325bb34SPeter Collingbourne                            WholeProgramDevirtResolution *Res);
43150cbd7ccSPeter Collingbourne 
432df49d1bbSPeter Collingbourne   bool tryEvaluateFunctionsWithArgs(
433df49d1bbSPeter Collingbourne       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
43450cbd7ccSPeter Collingbourne       ArrayRef<uint64_t> Args);
43550cbd7ccSPeter Collingbourne 
43650cbd7ccSPeter Collingbourne   void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
43750cbd7ccSPeter Collingbourne                              uint64_t TheRetVal);
43850cbd7ccSPeter Collingbourne   bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
43977a8d563SPeter Collingbourne                            CallSiteInfo &CSInfo,
44077a8d563SPeter Collingbourne                            WholeProgramDevirtResolution::ByArg *Res);
44150cbd7ccSPeter Collingbourne 
44259675ba0SPeter Collingbourne   // Returns the global symbol name that is used to export information about the
44359675ba0SPeter Collingbourne   // given vtable slot and list of arguments.
44459675ba0SPeter Collingbourne   std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
44559675ba0SPeter Collingbourne                             StringRef Name);
44659675ba0SPeter Collingbourne 
44759675ba0SPeter Collingbourne   // This function is called during the export phase to create a symbol
44859675ba0SPeter Collingbourne   // definition containing information about the given vtable slot and list of
44959675ba0SPeter Collingbourne   // arguments.
45059675ba0SPeter Collingbourne   void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
45159675ba0SPeter Collingbourne                     Constant *C);
45259675ba0SPeter Collingbourne 
45359675ba0SPeter Collingbourne   // This function is called during the import phase to create a reference to
45459675ba0SPeter Collingbourne   // the symbol definition created during the export phase.
45559675ba0SPeter Collingbourne   Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
45614dcf02fSPeter Collingbourne                          StringRef Name, unsigned AbsWidth = 0);
45759675ba0SPeter Collingbourne 
45850cbd7ccSPeter Collingbourne   void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
45950cbd7ccSPeter Collingbourne                             Constant *UniqueMemberAddr);
460df49d1bbSPeter Collingbourne   bool tryUniqueRetValOpt(unsigned BitWidth,
461f3403fd2SIvan Krasin                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
46259675ba0SPeter Collingbourne                           CallSiteInfo &CSInfo,
46359675ba0SPeter Collingbourne                           WholeProgramDevirtResolution::ByArg *Res,
46459675ba0SPeter Collingbourne                           VTableSlot Slot, ArrayRef<uint64_t> Args);
46550cbd7ccSPeter Collingbourne 
46650cbd7ccSPeter Collingbourne   void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
46750cbd7ccSPeter Collingbourne                              Constant *Byte, Constant *Bit);
468df49d1bbSPeter Collingbourne   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
46977a8d563SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
47059675ba0SPeter Collingbourne                            WholeProgramDevirtResolution *Res, VTableSlot Slot);
471df49d1bbSPeter Collingbourne 
472df49d1bbSPeter Collingbourne   void rebuildGlobal(VTableBits &B);
473df49d1bbSPeter Collingbourne 
4746d284fabSPeter Collingbourne   // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
4756d284fabSPeter Collingbourne   void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);
4766d284fabSPeter Collingbourne 
4776d284fabSPeter Collingbourne   // If we were able to eliminate all unsafe uses for a type checked load,
4786d284fabSPeter Collingbourne   // eliminate the associated type tests by replacing them with true.
4796d284fabSPeter Collingbourne   void removeRedundantTypeTests();
4806d284fabSPeter Collingbourne 
481df49d1bbSPeter Collingbourne   bool run();
4822b33f653SPeter Collingbourne 
4832b33f653SPeter Collingbourne   // Lower the module using the action and summary passed as command line
4842b33f653SPeter Collingbourne   // arguments. For testing purposes only.
48537317f12SPeter Collingbourne   static bool runForTesting(Module &M,
48637317f12SPeter Collingbourne                             function_ref<AAResults &(Function &)> AARGetter);
487df49d1bbSPeter Collingbourne };
488df49d1bbSPeter Collingbourne 
489df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass {
490df49d1bbSPeter Collingbourne   static char ID;
491cdc71612SEugene Zelenko 
4922b33f653SPeter Collingbourne   bool UseCommandLine = false;
4932b33f653SPeter Collingbourne 
494f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
495f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
4962b33f653SPeter Collingbourne 
4972b33f653SPeter Collingbourne   WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
4982b33f653SPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
4992b33f653SPeter Collingbourne   }
5002b33f653SPeter Collingbourne 
501f7691d8bSPeter Collingbourne   WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
502f7691d8bSPeter Collingbourne                      const ModuleSummaryIndex *ImportSummary)
503f7691d8bSPeter Collingbourne       : ModulePass(ID), ExportSummary(ExportSummary),
504f7691d8bSPeter Collingbourne         ImportSummary(ImportSummary) {
505df49d1bbSPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
506df49d1bbSPeter Collingbourne   }
507cdc71612SEugene Zelenko 
508cdc71612SEugene Zelenko   bool runOnModule(Module &M) override {
509aa641a51SAndrew Kaylor     if (skipModule(M))
510aa641a51SAndrew Kaylor       return false;
5112b33f653SPeter Collingbourne     if (UseCommandLine)
51237317f12SPeter Collingbourne       return DevirtModule::runForTesting(M, LegacyAARGetter(*this));
513f7691d8bSPeter Collingbourne     return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary)
514f7691d8bSPeter Collingbourne         .run();
51537317f12SPeter Collingbourne   }
51637317f12SPeter Collingbourne 
51737317f12SPeter Collingbourne   void getAnalysisUsage(AnalysisUsage &AU) const override {
51837317f12SPeter Collingbourne     AU.addRequired<AssumptionCacheTracker>();
51937317f12SPeter Collingbourne     AU.addRequired<TargetLibraryInfoWrapperPass>();
520aa641a51SAndrew Kaylor   }
521df49d1bbSPeter Collingbourne };
522df49d1bbSPeter Collingbourne 
523cdc71612SEugene Zelenko } // end anonymous namespace
524df49d1bbSPeter Collingbourne 
// Register the legacy WholeProgramDevirt pass under the name
// "wholeprogramdevirt", declaring the same analyses that
// WholeProgramDevirt::getAnalysisUsage requires.
INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
                      "Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                    "Whole program devirtualization", false, false)
// Out-of-line definition of the pass identifier declared in the class.
char WholeProgramDevirt::ID = 0;
532df49d1bbSPeter Collingbourne 
533f7691d8bSPeter Collingbourne ModulePass *
534f7691d8bSPeter Collingbourne llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
535f7691d8bSPeter Collingbourne                                    const ModuleSummaryIndex *ImportSummary) {
536f7691d8bSPeter Collingbourne   return new WholeProgramDevirt(ExportSummary, ImportSummary);
537df49d1bbSPeter Collingbourne }
538df49d1bbSPeter Collingbourne 
539164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
54037317f12SPeter Collingbourne                                               ModuleAnalysisManager &AM) {
54137317f12SPeter Collingbourne   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
54237317f12SPeter Collingbourne   auto AARGetter = [&](Function &F) -> AAResults & {
54337317f12SPeter Collingbourne     return FAM.getResult<AAManager>(F);
54437317f12SPeter Collingbourne   };
545f7691d8bSPeter Collingbourne   if (!DevirtModule(M, AARGetter, nullptr, nullptr).run())
546d737dd2eSDavide Italiano     return PreservedAnalyses::all();
547d737dd2eSDavide Italiano   return PreservedAnalyses::none();
548d737dd2eSDavide Italiano }
549d737dd2eSDavide Italiano 
55037317f12SPeter Collingbourne bool DevirtModule::runForTesting(
55137317f12SPeter Collingbourne     Module &M, function_ref<AAResults &(Function &)> AARGetter) {
5522b33f653SPeter Collingbourne   ModuleSummaryIndex Summary;
5532b33f653SPeter Collingbourne 
5542b33f653SPeter Collingbourne   // Handle the command-line summary arguments. This code is for testing
5552b33f653SPeter Collingbourne   // purposes only, so we handle errors directly.
5562b33f653SPeter Collingbourne   if (!ClReadSummary.empty()) {
5572b33f653SPeter Collingbourne     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
5582b33f653SPeter Collingbourne                           ": ");
5592b33f653SPeter Collingbourne     auto ReadSummaryFile =
5602b33f653SPeter Collingbourne         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
5612b33f653SPeter Collingbourne 
5622b33f653SPeter Collingbourne     yaml::Input In(ReadSummaryFile->getBuffer());
5632b33f653SPeter Collingbourne     In >> Summary;
5642b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(In.error()));
5652b33f653SPeter Collingbourne   }
5662b33f653SPeter Collingbourne 
567f7691d8bSPeter Collingbourne   bool Changed =
568f7691d8bSPeter Collingbourne       DevirtModule(
569f7691d8bSPeter Collingbourne           M, AARGetter,
570f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
571f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
572f7691d8bSPeter Collingbourne           .run();
5732b33f653SPeter Collingbourne 
5742b33f653SPeter Collingbourne   if (!ClWriteSummary.empty()) {
5752b33f653SPeter Collingbourne     ExitOnError ExitOnErr(
5762b33f653SPeter Collingbourne         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
5772b33f653SPeter Collingbourne     std::error_code EC;
5782b33f653SPeter Collingbourne     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
5792b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(EC));
5802b33f653SPeter Collingbourne 
5812b33f653SPeter Collingbourne     yaml::Output Out(OS);
5822b33f653SPeter Collingbourne     Out << Summary;
5832b33f653SPeter Collingbourne   }
5842b33f653SPeter Collingbourne 
5852b33f653SPeter Collingbourne   return Changed;
5862b33f653SPeter Collingbourne }
5872b33f653SPeter Collingbourne 
5887efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap(
589df49d1bbSPeter Collingbourne     std::vector<VTableBits> &Bits,
5907efd7506SPeter Collingbourne     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
591df49d1bbSPeter Collingbourne   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
5927efd7506SPeter Collingbourne   Bits.reserve(M.getGlobalList().size());
5937efd7506SPeter Collingbourne   SmallVector<MDNode *, 2> Types;
5947efd7506SPeter Collingbourne   for (GlobalVariable &GV : M.globals()) {
5957efd7506SPeter Collingbourne     Types.clear();
5967efd7506SPeter Collingbourne     GV.getMetadata(LLVMContext::MD_type, Types);
5977efd7506SPeter Collingbourne     if (Types.empty())
598df49d1bbSPeter Collingbourne       continue;
599df49d1bbSPeter Collingbourne 
6007efd7506SPeter Collingbourne     VTableBits *&BitsPtr = GVToBits[&GV];
6017efd7506SPeter Collingbourne     if (!BitsPtr) {
6027efd7506SPeter Collingbourne       Bits.emplace_back();
6037efd7506SPeter Collingbourne       Bits.back().GV = &GV;
6047efd7506SPeter Collingbourne       Bits.back().ObjectSize =
6057efd7506SPeter Collingbourne           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
6067efd7506SPeter Collingbourne       BitsPtr = &Bits.back();
6077efd7506SPeter Collingbourne     }
6087efd7506SPeter Collingbourne 
6097efd7506SPeter Collingbourne     for (MDNode *Type : Types) {
6107efd7506SPeter Collingbourne       auto TypeID = Type->getOperand(1).get();
611df49d1bbSPeter Collingbourne 
612df49d1bbSPeter Collingbourne       uint64_t Offset =
613df49d1bbSPeter Collingbourne           cast<ConstantInt>(
6147efd7506SPeter Collingbourne               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
615df49d1bbSPeter Collingbourne               ->getZExtValue();
616df49d1bbSPeter Collingbourne 
6177efd7506SPeter Collingbourne       TypeIdMap[TypeID].insert({BitsPtr, Offset});
618df49d1bbSPeter Collingbourne     }
619df49d1bbSPeter Collingbourne   }
620df49d1bbSPeter Collingbourne }
621df49d1bbSPeter Collingbourne 
6228786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
6238786754cSPeter Collingbourne   if (I->getType()->isPointerTy()) {
6248786754cSPeter Collingbourne     if (Offset == 0)
6258786754cSPeter Collingbourne       return I;
6268786754cSPeter Collingbourne     return nullptr;
6278786754cSPeter Collingbourne   }
6288786754cSPeter Collingbourne 
6297a1e5bbeSPeter Collingbourne   const DataLayout &DL = M.getDataLayout();
6307a1e5bbeSPeter Collingbourne 
6317a1e5bbeSPeter Collingbourne   if (auto *C = dyn_cast<ConstantStruct>(I)) {
6327a1e5bbeSPeter Collingbourne     const StructLayout *SL = DL.getStructLayout(C->getType());
6337a1e5bbeSPeter Collingbourne     if (Offset >= SL->getSizeInBytes())
6347a1e5bbeSPeter Collingbourne       return nullptr;
6357a1e5bbeSPeter Collingbourne 
6368786754cSPeter Collingbourne     unsigned Op = SL->getElementContainingOffset(Offset);
6378786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
6388786754cSPeter Collingbourne                               Offset - SL->getElementOffset(Op));
6398786754cSPeter Collingbourne   }
6408786754cSPeter Collingbourne   if (auto *C = dyn_cast<ConstantArray>(I)) {
6417a1e5bbeSPeter Collingbourne     ArrayType *VTableTy = C->getType();
6427a1e5bbeSPeter Collingbourne     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
6437a1e5bbeSPeter Collingbourne 
6448786754cSPeter Collingbourne     unsigned Op = Offset / ElemSize;
6457a1e5bbeSPeter Collingbourne     if (Op >= C->getNumOperands())
6467a1e5bbeSPeter Collingbourne       return nullptr;
6477a1e5bbeSPeter Collingbourne 
6488786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
6498786754cSPeter Collingbourne                               Offset % ElemSize);
6508786754cSPeter Collingbourne   }
6518786754cSPeter Collingbourne   return nullptr;
6527a1e5bbeSPeter Collingbourne }
6537a1e5bbeSPeter Collingbourne 
654df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets(
655df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> &TargetsForSlot,
6567efd7506SPeter Collingbourne     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
6577efd7506SPeter Collingbourne   for (const TypeMemberInfo &TM : TypeMemberInfos) {
6587efd7506SPeter Collingbourne     if (!TM.Bits->GV->isConstant())
659df49d1bbSPeter Collingbourne       return false;
660df49d1bbSPeter Collingbourne 
6618786754cSPeter Collingbourne     Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
6628786754cSPeter Collingbourne                                        TM.Offset + ByteOffset);
6638786754cSPeter Collingbourne     if (!Ptr)
664df49d1bbSPeter Collingbourne       return false;
665df49d1bbSPeter Collingbourne 
6668786754cSPeter Collingbourne     auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
667df49d1bbSPeter Collingbourne     if (!Fn)
668df49d1bbSPeter Collingbourne       return false;
669df49d1bbSPeter Collingbourne 
670df49d1bbSPeter Collingbourne     // We can disregard __cxa_pure_virtual as a possible call target, as
671df49d1bbSPeter Collingbourne     // calls to pure virtuals are UB.
672df49d1bbSPeter Collingbourne     if (Fn->getName() == "__cxa_pure_virtual")
673df49d1bbSPeter Collingbourne       continue;
674df49d1bbSPeter Collingbourne 
6757efd7506SPeter Collingbourne     TargetsForSlot.push_back({Fn, &TM});
676df49d1bbSPeter Collingbourne   }
677df49d1bbSPeter Collingbourne 
678df49d1bbSPeter Collingbourne   // Give up if we couldn't find any targets.
679df49d1bbSPeter Collingbourne   return !TargetsForSlot.empty();
680df49d1bbSPeter Collingbourne }
681df49d1bbSPeter Collingbourne 
68250cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
6832325bb34SPeter Collingbourne                                          Constant *TheFn, bool &IsExported) {
68450cbd7ccSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
68550cbd7ccSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
686f3403fd2SIvan Krasin       if (RemarksEnabled)
687f3403fd2SIvan Krasin         VCallSite.emitRemark("single-impl", TheFn->getName());
688df49d1bbSPeter Collingbourne       VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
689df49d1bbSPeter Collingbourne           TheFn, VCallSite.CS.getCalledValue()->getType()));
6900312f614SPeter Collingbourne       // This use is no longer unsafe.
6910312f614SPeter Collingbourne       if (VCallSite.NumUnsafeUses)
6920312f614SPeter Collingbourne         --*VCallSite.NumUnsafeUses;
693df49d1bbSPeter Collingbourne     }
6942325bb34SPeter Collingbourne     if (CSInfo.isExported()) {
6952325bb34SPeter Collingbourne       IsExported = true;
69659675ba0SPeter Collingbourne       CSInfo.markDevirt();
6972325bb34SPeter Collingbourne     }
69850cbd7ccSPeter Collingbourne   };
69950cbd7ccSPeter Collingbourne   Apply(SlotInfo.CSInfo);
70050cbd7ccSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
70150cbd7ccSPeter Collingbourne     Apply(P.second);
70250cbd7ccSPeter Collingbourne }
70350cbd7ccSPeter Collingbourne 
70450cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt(
70550cbd7ccSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
7062325bb34SPeter Collingbourne     VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
70750cbd7ccSPeter Collingbourne   // See if the program contains a single implementation of this virtual
70850cbd7ccSPeter Collingbourne   // function.
70950cbd7ccSPeter Collingbourne   Function *TheFn = TargetsForSlot[0].Fn;
71050cbd7ccSPeter Collingbourne   for (auto &&Target : TargetsForSlot)
71150cbd7ccSPeter Collingbourne     if (TheFn != Target.Fn)
71250cbd7ccSPeter Collingbourne       return false;
71350cbd7ccSPeter Collingbourne 
71450cbd7ccSPeter Collingbourne   // If so, update each call site to call that implementation directly.
71550cbd7ccSPeter Collingbourne   if (RemarksEnabled)
71650cbd7ccSPeter Collingbourne     TargetsForSlot[0].WasDevirt = true;
7172325bb34SPeter Collingbourne 
7182325bb34SPeter Collingbourne   bool IsExported = false;
7192325bb34SPeter Collingbourne   applySingleImplDevirt(SlotInfo, TheFn, IsExported);
7202325bb34SPeter Collingbourne   if (!IsExported)
7212325bb34SPeter Collingbourne     return false;
7222325bb34SPeter Collingbourne 
7232325bb34SPeter Collingbourne   // If the only implementation has local linkage, we must promote to external
7242325bb34SPeter Collingbourne   // to make it visible to thin LTO objects. We can only get here during the
7252325bb34SPeter Collingbourne   // ThinLTO export phase.
7262325bb34SPeter Collingbourne   if (TheFn->hasLocalLinkage()) {
7272325bb34SPeter Collingbourne     TheFn->setLinkage(GlobalValue::ExternalLinkage);
7282325bb34SPeter Collingbourne     TheFn->setVisibility(GlobalValue::HiddenVisibility);
7292325bb34SPeter Collingbourne     TheFn->setName(TheFn->getName() + "$merged");
7302325bb34SPeter Collingbourne   }
7312325bb34SPeter Collingbourne 
7322325bb34SPeter Collingbourne   Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
7332325bb34SPeter Collingbourne   Res->SingleImplName = TheFn->getName();
7342325bb34SPeter Collingbourne 
735df49d1bbSPeter Collingbourne   return true;
736df49d1bbSPeter Collingbourne }
737df49d1bbSPeter Collingbourne 
738df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs(
739df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
74050cbd7ccSPeter Collingbourne     ArrayRef<uint64_t> Args) {
741df49d1bbSPeter Collingbourne   // Evaluate each function and store the result in each target's RetVal
742df49d1bbSPeter Collingbourne   // field.
743df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
744df49d1bbSPeter Collingbourne     if (Target.Fn->arg_size() != Args.size() + 1)
745df49d1bbSPeter Collingbourne       return false;
746df49d1bbSPeter Collingbourne 
747df49d1bbSPeter Collingbourne     Evaluator Eval(M.getDataLayout(), nullptr);
748df49d1bbSPeter Collingbourne     SmallVector<Constant *, 2> EvalArgs;
749df49d1bbSPeter Collingbourne     EvalArgs.push_back(
750df49d1bbSPeter Collingbourne         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
75150cbd7ccSPeter Collingbourne     for (unsigned I = 0; I != Args.size(); ++I) {
75250cbd7ccSPeter Collingbourne       auto *ArgTy = dyn_cast<IntegerType>(
75350cbd7ccSPeter Collingbourne           Target.Fn->getFunctionType()->getParamType(I + 1));
75450cbd7ccSPeter Collingbourne       if (!ArgTy)
75550cbd7ccSPeter Collingbourne         return false;
75650cbd7ccSPeter Collingbourne       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
75750cbd7ccSPeter Collingbourne     }
75850cbd7ccSPeter Collingbourne 
759df49d1bbSPeter Collingbourne     Constant *RetVal;
760df49d1bbSPeter Collingbourne     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
761df49d1bbSPeter Collingbourne         !isa<ConstantInt>(RetVal))
762df49d1bbSPeter Collingbourne       return false;
763df49d1bbSPeter Collingbourne     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
764df49d1bbSPeter Collingbourne   }
765df49d1bbSPeter Collingbourne   return true;
766df49d1bbSPeter Collingbourne }
767df49d1bbSPeter Collingbourne 
76850cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
76950cbd7ccSPeter Collingbourne                                          uint64_t TheRetVal) {
77050cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites)
77150cbd7ccSPeter Collingbourne     Call.replaceAndErase(
77250cbd7ccSPeter Collingbourne         "uniform-ret-val", FnName, RemarksEnabled,
77350cbd7ccSPeter Collingbourne         ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
77459675ba0SPeter Collingbourne   CSInfo.markDevirt();
77550cbd7ccSPeter Collingbourne }
77650cbd7ccSPeter Collingbourne 
777df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt(
77877a8d563SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
77977a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *Res) {
780df49d1bbSPeter Collingbourne   // Uniform return value optimization. If all functions return the same
781df49d1bbSPeter Collingbourne   // constant, replace all calls with that constant.
782df49d1bbSPeter Collingbourne   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
783df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : TargetsForSlot)
784df49d1bbSPeter Collingbourne     if (Target.RetVal != TheRetVal)
785df49d1bbSPeter Collingbourne       return false;
786df49d1bbSPeter Collingbourne 
78777a8d563SPeter Collingbourne   if (CSInfo.isExported()) {
78877a8d563SPeter Collingbourne     Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
78977a8d563SPeter Collingbourne     Res->Info = TheRetVal;
79077a8d563SPeter Collingbourne   }
79177a8d563SPeter Collingbourne 
79250cbd7ccSPeter Collingbourne   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
793f3403fd2SIvan Krasin   if (RemarksEnabled)
794f3403fd2SIvan Krasin     for (auto &&Target : TargetsForSlot)
795f3403fd2SIvan Krasin       Target.WasDevirt = true;
796df49d1bbSPeter Collingbourne   return true;
797df49d1bbSPeter Collingbourne }
798df49d1bbSPeter Collingbourne 
79959675ba0SPeter Collingbourne std::string DevirtModule::getGlobalName(VTableSlot Slot,
80059675ba0SPeter Collingbourne                                         ArrayRef<uint64_t> Args,
80159675ba0SPeter Collingbourne                                         StringRef Name) {
80259675ba0SPeter Collingbourne   std::string FullName = "__typeid_";
80359675ba0SPeter Collingbourne   raw_string_ostream OS(FullName);
80459675ba0SPeter Collingbourne   OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
80559675ba0SPeter Collingbourne   for (uint64_t Arg : Args)
80659675ba0SPeter Collingbourne     OS << '_' << Arg;
80759675ba0SPeter Collingbourne   OS << '_' << Name;
80859675ba0SPeter Collingbourne   return OS.str();
80959675ba0SPeter Collingbourne }
81059675ba0SPeter Collingbourne 
81159675ba0SPeter Collingbourne void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
81259675ba0SPeter Collingbourne                                 StringRef Name, Constant *C) {
81359675ba0SPeter Collingbourne   GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
81459675ba0SPeter Collingbourne                                         getGlobalName(Slot, Args, Name), C, &M);
81559675ba0SPeter Collingbourne   GA->setVisibility(GlobalValue::HiddenVisibility);
81659675ba0SPeter Collingbourne }
81759675ba0SPeter Collingbourne 
81859675ba0SPeter Collingbourne Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
81914dcf02fSPeter Collingbourne                                      StringRef Name, unsigned AbsWidth) {
82059675ba0SPeter Collingbourne   Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
82159675ba0SPeter Collingbourne   auto *GV = dyn_cast<GlobalVariable>(C);
82214dcf02fSPeter Collingbourne   // We only need to set metadata if the global is newly created, in which
82314dcf02fSPeter Collingbourne   // case it would not have hidden visibility.
82414dcf02fSPeter Collingbourne   if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility)
82559675ba0SPeter Collingbourne     return C;
82614dcf02fSPeter Collingbourne 
82759675ba0SPeter Collingbourne   GV->setVisibility(GlobalValue::HiddenVisibility);
82814dcf02fSPeter Collingbourne   auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
82914dcf02fSPeter Collingbourne     auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
83014dcf02fSPeter Collingbourne     auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
83114dcf02fSPeter Collingbourne     GV->setMetadata(LLVMContext::MD_absolute_symbol,
83214dcf02fSPeter Collingbourne                     MDNode::get(M.getContext(), {MinC, MaxC}));
83314dcf02fSPeter Collingbourne   };
83414dcf02fSPeter Collingbourne   if (AbsWidth == IntPtrTy->getBitWidth())
83514dcf02fSPeter Collingbourne     SetAbsRange(~0ull, ~0ull); // Full set.
83614dcf02fSPeter Collingbourne   else if (AbsWidth)
83714dcf02fSPeter Collingbourne     SetAbsRange(0, 1ull << AbsWidth);
83859675ba0SPeter Collingbourne   return GV;
83959675ba0SPeter Collingbourne }
84059675ba0SPeter Collingbourne 
84150cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
84250cbd7ccSPeter Collingbourne                                         bool IsOne,
84350cbd7ccSPeter Collingbourne                                         Constant *UniqueMemberAddr) {
84450cbd7ccSPeter Collingbourne   for (auto &&Call : CSInfo.CallSites) {
84550cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
84650cbd7ccSPeter Collingbourne     Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
84750cbd7ccSPeter Collingbourne                               Call.VTable, UniqueMemberAddr);
84850cbd7ccSPeter Collingbourne     Cmp = B.CreateZExt(Cmp, Call.CS->getType());
84950cbd7ccSPeter Collingbourne     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp);
85050cbd7ccSPeter Collingbourne   }
85159675ba0SPeter Collingbourne   CSInfo.markDevirt();
85250cbd7ccSPeter Collingbourne }
85350cbd7ccSPeter Collingbourne 
854df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt(
855f3403fd2SIvan Krasin     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
85659675ba0SPeter Collingbourne     CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
85759675ba0SPeter Collingbourne     VTableSlot Slot, ArrayRef<uint64_t> Args) {
858df49d1bbSPeter Collingbourne   // IsOne controls whether we look for a 0 or a 1.
859df49d1bbSPeter Collingbourne   auto tryUniqueRetValOptFor = [&](bool IsOne) {
860cdc71612SEugene Zelenko     const TypeMemberInfo *UniqueMember = nullptr;
861df49d1bbSPeter Collingbourne     for (const VirtualCallTarget &Target : TargetsForSlot) {
8623866cc5fSPeter Collingbourne       if (Target.RetVal == (IsOne ? 1 : 0)) {
8637efd7506SPeter Collingbourne         if (UniqueMember)
864df49d1bbSPeter Collingbourne           return false;
8657efd7506SPeter Collingbourne         UniqueMember = Target.TM;
866df49d1bbSPeter Collingbourne       }
867df49d1bbSPeter Collingbourne     }
868df49d1bbSPeter Collingbourne 
8697efd7506SPeter Collingbourne     // We should have found a unique member or bailed out by now. We already
870df49d1bbSPeter Collingbourne     // checked for a uniform return value in tryUniformRetValOpt.
8717efd7506SPeter Collingbourne     assert(UniqueMember);
872df49d1bbSPeter Collingbourne 
87350cbd7ccSPeter Collingbourne     Constant *UniqueMemberAddr =
87450cbd7ccSPeter Collingbourne         ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
87550cbd7ccSPeter Collingbourne     UniqueMemberAddr = ConstantExpr::getGetElementPtr(
87650cbd7ccSPeter Collingbourne         Int8Ty, UniqueMemberAddr,
87750cbd7ccSPeter Collingbourne         ConstantInt::get(Int64Ty, UniqueMember->Offset));
87850cbd7ccSPeter Collingbourne 
87959675ba0SPeter Collingbourne     if (CSInfo.isExported()) {
88059675ba0SPeter Collingbourne       Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
88159675ba0SPeter Collingbourne       Res->Info = IsOne;
88259675ba0SPeter Collingbourne 
88359675ba0SPeter Collingbourne       exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
88459675ba0SPeter Collingbourne     }
88559675ba0SPeter Collingbourne 
88659675ba0SPeter Collingbourne     // Replace each call with the comparison.
88750cbd7ccSPeter Collingbourne     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
88850cbd7ccSPeter Collingbourne                          UniqueMemberAddr);
88950cbd7ccSPeter Collingbourne 
890f3403fd2SIvan Krasin     // Update devirtualization statistics for targets.
891f3403fd2SIvan Krasin     if (RemarksEnabled)
892f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
893f3403fd2SIvan Krasin         Target.WasDevirt = true;
894f3403fd2SIvan Krasin 
895df49d1bbSPeter Collingbourne     return true;
896df49d1bbSPeter Collingbourne   };
897df49d1bbSPeter Collingbourne 
898df49d1bbSPeter Collingbourne   if (BitWidth == 1) {
899df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(true))
900df49d1bbSPeter Collingbourne       return true;
901df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(false))
902df49d1bbSPeter Collingbourne       return true;
903df49d1bbSPeter Collingbourne   }
904df49d1bbSPeter Collingbourne   return false;
905df49d1bbSPeter Collingbourne }
906df49d1bbSPeter Collingbourne 
90750cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
90850cbd7ccSPeter Collingbourne                                          Constant *Byte, Constant *Bit) {
90950cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites) {
91050cbd7ccSPeter Collingbourne     auto *RetType = cast<IntegerType>(Call.CS.getType());
91150cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
91250cbd7ccSPeter Collingbourne     Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
91350cbd7ccSPeter Collingbourne     if (RetType->getBitWidth() == 1) {
91450cbd7ccSPeter Collingbourne       Value *Bits = B.CreateLoad(Addr);
91550cbd7ccSPeter Collingbourne       Value *BitsAndBit = B.CreateAnd(Bits, Bit);
91650cbd7ccSPeter Collingbourne       auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
91750cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
91850cbd7ccSPeter Collingbourne                            IsBitSet);
91950cbd7ccSPeter Collingbourne     } else {
92050cbd7ccSPeter Collingbourne       Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
92150cbd7ccSPeter Collingbourne       Value *Val = B.CreateLoad(RetType, ValAddr);
92250cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val);
92350cbd7ccSPeter Collingbourne     }
92450cbd7ccSPeter Collingbourne   }
92514dcf02fSPeter Collingbourne   CSInfo.markDevirt();
92650cbd7ccSPeter Collingbourne }
92750cbd7ccSPeter Collingbourne 
928df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp(
92959675ba0SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
93059675ba0SPeter Collingbourne     WholeProgramDevirtResolution *Res, VTableSlot Slot) {
931df49d1bbSPeter Collingbourne   // This only works if the function returns an integer.
932df49d1bbSPeter Collingbourne   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
933df49d1bbSPeter Collingbourne   if (!RetType)
934df49d1bbSPeter Collingbourne     return false;
935df49d1bbSPeter Collingbourne   unsigned BitWidth = RetType->getBitWidth();
936df49d1bbSPeter Collingbourne   if (BitWidth > 64)
937df49d1bbSPeter Collingbourne     return false;
938df49d1bbSPeter Collingbourne 
93917febdbbSPeter Collingbourne   // Make sure that each function is defined, does not access memory, takes at
94017febdbbSPeter Collingbourne   // least one argument, does not use its first argument (which we assume is
94117febdbbSPeter Collingbourne   // 'this'), and has the same return type.
94237317f12SPeter Collingbourne   //
94337317f12SPeter Collingbourne   // Note that we test whether this copy of the function is readnone, rather
94437317f12SPeter Collingbourne   // than testing function attributes, which must hold for any copy of the
94537317f12SPeter Collingbourne   // function, even a less optimized version substituted at link time. This is
94637317f12SPeter Collingbourne   // sound because the virtual constant propagation optimizations effectively
94737317f12SPeter Collingbourne   // inline all implementations of the virtual function into each call site,
94837317f12SPeter Collingbourne   // rather than using function attributes to perform local optimization.
949df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
95037317f12SPeter Collingbourne     if (Target.Fn->isDeclaration() ||
95137317f12SPeter Collingbourne         computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
95237317f12SPeter Collingbourne             MAK_ReadNone ||
95317febdbbSPeter Collingbourne         Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
954df49d1bbSPeter Collingbourne         Target.Fn->getReturnType() != RetType)
955df49d1bbSPeter Collingbourne       return false;
956df49d1bbSPeter Collingbourne   }
957df49d1bbSPeter Collingbourne 
95850cbd7ccSPeter Collingbourne   for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
959df49d1bbSPeter Collingbourne     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
960df49d1bbSPeter Collingbourne       continue;
961df49d1bbSPeter Collingbourne 
96277a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
96377a8d563SPeter Collingbourne     if (Res)
96477a8d563SPeter Collingbourne       ResByArg = &Res->ResByArg[CSByConstantArg.first];
96577a8d563SPeter Collingbourne 
96677a8d563SPeter Collingbourne     if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
967df49d1bbSPeter Collingbourne       continue;
968df49d1bbSPeter Collingbourne 
96959675ba0SPeter Collingbourne     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
97059675ba0SPeter Collingbourne                            ResByArg, Slot, CSByConstantArg.first))
971df49d1bbSPeter Collingbourne       continue;
972df49d1bbSPeter Collingbourne 
9737efd7506SPeter Collingbourne     // Find an allocation offset in bits in all vtables associated with the
9747efd7506SPeter Collingbourne     // type.
975df49d1bbSPeter Collingbourne     uint64_t AllocBefore =
976df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
977df49d1bbSPeter Collingbourne     uint64_t AllocAfter =
978df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
979df49d1bbSPeter Collingbourne 
980df49d1bbSPeter Collingbourne     // Calculate the total amount of padding needed to store a value at both
981df49d1bbSPeter Collingbourne     // ends of the object.
982df49d1bbSPeter Collingbourne     uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
983df49d1bbSPeter Collingbourne     for (auto &&Target : TargetsForSlot) {
984df49d1bbSPeter Collingbourne       TotalPaddingBefore += std::max<int64_t>(
985df49d1bbSPeter Collingbourne           (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
986df49d1bbSPeter Collingbourne       TotalPaddingAfter += std::max<int64_t>(
987df49d1bbSPeter Collingbourne           (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
988df49d1bbSPeter Collingbourne     }
989df49d1bbSPeter Collingbourne 
990df49d1bbSPeter Collingbourne     // If the amount of padding is too large, give up.
991df49d1bbSPeter Collingbourne     // FIXME: do something smarter here.
992df49d1bbSPeter Collingbourne     if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
993df49d1bbSPeter Collingbourne       continue;
994df49d1bbSPeter Collingbourne 
995df49d1bbSPeter Collingbourne     // Calculate the offset to the value as a (possibly negative) byte offset
996df49d1bbSPeter Collingbourne     // and (if applicable) a bit offset, and store the values in the targets.
997df49d1bbSPeter Collingbourne     int64_t OffsetByte;
998df49d1bbSPeter Collingbourne     uint64_t OffsetBit;
999df49d1bbSPeter Collingbourne     if (TotalPaddingBefore <= TotalPaddingAfter)
1000df49d1bbSPeter Collingbourne       setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
1001df49d1bbSPeter Collingbourne                             OffsetBit);
1002df49d1bbSPeter Collingbourne     else
1003df49d1bbSPeter Collingbourne       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
1004df49d1bbSPeter Collingbourne                            OffsetBit);
1005df49d1bbSPeter Collingbourne 
1006f3403fd2SIvan Krasin     if (RemarksEnabled)
1007f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
1008f3403fd2SIvan Krasin         Target.WasDevirt = true;
1009f3403fd2SIvan Krasin 
1010184773d8SPeter Collingbourne     Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
101150cbd7ccSPeter Collingbourne     Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
101214dcf02fSPeter Collingbourne 
101314dcf02fSPeter Collingbourne     if (CSByConstantArg.second.isExported()) {
101414dcf02fSPeter Collingbourne       ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
101514dcf02fSPeter Collingbourne       exportGlobal(Slot, CSByConstantArg.first, "byte",
101614dcf02fSPeter Collingbourne                    ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy));
101714dcf02fSPeter Collingbourne       exportGlobal(Slot, CSByConstantArg.first, "bit",
101814dcf02fSPeter Collingbourne                    ConstantExpr::getIntToPtr(BitConst, Int8PtrTy));
101914dcf02fSPeter Collingbourne     }
102014dcf02fSPeter Collingbourne 
102114dcf02fSPeter Collingbourne     // Rewrite each call to a load from OffsetByte/OffsetBit.
102250cbd7ccSPeter Collingbourne     applyVirtualConstProp(CSByConstantArg.second,
102350cbd7ccSPeter Collingbourne                           TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
1024df49d1bbSPeter Collingbourne   }
1025df49d1bbSPeter Collingbourne   return true;
1026df49d1bbSPeter Collingbourne }
1027df49d1bbSPeter Collingbourne 
1028df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) {
1029df49d1bbSPeter Collingbourne   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
1030df49d1bbSPeter Collingbourne     return;
1031df49d1bbSPeter Collingbourne 
1032df49d1bbSPeter Collingbourne   // Align each byte array to pointer width.
1033df49d1bbSPeter Collingbourne   unsigned PointerSize = M.getDataLayout().getPointerSize();
1034df49d1bbSPeter Collingbourne   B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
1035df49d1bbSPeter Collingbourne   B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));
1036df49d1bbSPeter Collingbourne 
1037df49d1bbSPeter Collingbourne   // Before was stored in reverse order; flip it now.
1038df49d1bbSPeter Collingbourne   for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
1039df49d1bbSPeter Collingbourne     std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
1040df49d1bbSPeter Collingbourne 
1041df49d1bbSPeter Collingbourne   // Build an anonymous global containing the before bytes, followed by the
1042df49d1bbSPeter Collingbourne   // original initializer, followed by the after bytes.
1043df49d1bbSPeter Collingbourne   auto NewInit = ConstantStruct::getAnon(
1044df49d1bbSPeter Collingbourne       {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
1045df49d1bbSPeter Collingbourne        B.GV->getInitializer(),
1046df49d1bbSPeter Collingbourne        ConstantDataArray::get(M.getContext(), B.After.Bytes)});
1047df49d1bbSPeter Collingbourne   auto NewGV =
1048df49d1bbSPeter Collingbourne       new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
1049df49d1bbSPeter Collingbourne                          GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
1050df49d1bbSPeter Collingbourne   NewGV->setSection(B.GV->getSection());
1051df49d1bbSPeter Collingbourne   NewGV->setComdat(B.GV->getComdat());
1052df49d1bbSPeter Collingbourne 
10530312f614SPeter Collingbourne   // Copy the original vtable's metadata to the anonymous global, adjusting
10540312f614SPeter Collingbourne   // offsets as required.
10550312f614SPeter Collingbourne   NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
10560312f614SPeter Collingbourne 
1057df49d1bbSPeter Collingbourne   // Build an alias named after the original global, pointing at the second
1058df49d1bbSPeter Collingbourne   // element (the original initializer).
1059df49d1bbSPeter Collingbourne   auto Alias = GlobalAlias::create(
1060df49d1bbSPeter Collingbourne       B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
1061df49d1bbSPeter Collingbourne       ConstantExpr::getGetElementPtr(
1062df49d1bbSPeter Collingbourne           NewInit->getType(), NewGV,
1063df49d1bbSPeter Collingbourne           ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
1064df49d1bbSPeter Collingbourne                                ConstantInt::get(Int32Ty, 1)}),
1065df49d1bbSPeter Collingbourne       &M);
1066df49d1bbSPeter Collingbourne   Alias->setVisibility(B.GV->getVisibility());
1067df49d1bbSPeter Collingbourne   Alias->takeName(B.GV);
1068df49d1bbSPeter Collingbourne 
1069df49d1bbSPeter Collingbourne   B.GV->replaceAllUsesWith(Alias);
1070df49d1bbSPeter Collingbourne   B.GV->eraseFromParent();
1071df49d1bbSPeter Collingbourne }
1072df49d1bbSPeter Collingbourne 
1073f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() {
1074f3403fd2SIvan Krasin   const auto &FL = M.getFunctionList();
1075f3403fd2SIvan Krasin   if (FL.empty())
1076f3403fd2SIvan Krasin     return false;
1077f3403fd2SIvan Krasin   const Function &Fn = FL.front();
1078de53bfb9SAdam Nemet 
1079de53bfb9SAdam Nemet   const auto &BBL = Fn.getBasicBlockList();
1080de53bfb9SAdam Nemet   if (BBL.empty())
1081de53bfb9SAdam Nemet     return false;
1082de53bfb9SAdam Nemet   auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
1083f3403fd2SIvan Krasin   return DI.isEnabled();
1084f3403fd2SIvan Krasin }
1085f3403fd2SIvan Krasin 
10860312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
10870312f614SPeter Collingbourne                                      Function *AssumeFunc) {
1088df49d1bbSPeter Collingbourne   // Find all virtual calls via a virtual table pointer %p under an assumption
10897efd7506SPeter Collingbourne   // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
10907efd7506SPeter Collingbourne   // points to a member of the type identifier %md. Group calls by (type ID,
10917efd7506SPeter Collingbourne   // offset) pair (effectively the identity of the virtual function) and store
10927efd7506SPeter Collingbourne   // to CallSlots.
1093df49d1bbSPeter Collingbourne   DenseSet<Value *> SeenPtrs;
10947efd7506SPeter Collingbourne   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
1095df49d1bbSPeter Collingbourne        I != E;) {
1096df49d1bbSPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
1097df49d1bbSPeter Collingbourne     ++I;
1098df49d1bbSPeter Collingbourne     if (!CI)
1099df49d1bbSPeter Collingbourne       continue;
1100df49d1bbSPeter Collingbourne 
1101ccdc225cSPeter Collingbourne     // Search for virtual calls based on %p and add them to DevirtCalls.
1102ccdc225cSPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
1103df49d1bbSPeter Collingbourne     SmallVector<CallInst *, 1> Assumes;
11040312f614SPeter Collingbourne     findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
1105df49d1bbSPeter Collingbourne 
1106ccdc225cSPeter Collingbourne     // If we found any, add them to CallSlots. Only do this if we haven't seen
1107ccdc225cSPeter Collingbourne     // the vtable pointer before, as it may have been CSE'd with pointers from
1108ccdc225cSPeter Collingbourne     // other call sites, and we don't want to process call sites multiple times.
1109df49d1bbSPeter Collingbourne     if (!Assumes.empty()) {
11107efd7506SPeter Collingbourne       Metadata *TypeId =
1111df49d1bbSPeter Collingbourne           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
1112df49d1bbSPeter Collingbourne       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
1113ccdc225cSPeter Collingbourne       if (SeenPtrs.insert(Ptr).second) {
1114ccdc225cSPeter Collingbourne         for (DevirtCallSite Call : DevirtCalls) {
111550cbd7ccSPeter Collingbourne           CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
111650cbd7ccSPeter Collingbourne                                                        Call.CS, nullptr);
1117ccdc225cSPeter Collingbourne         }
1118ccdc225cSPeter Collingbourne       }
1119df49d1bbSPeter Collingbourne     }
1120df49d1bbSPeter Collingbourne 
11217efd7506SPeter Collingbourne     // We no longer need the assumes or the type test.
1122df49d1bbSPeter Collingbourne     for (auto Assume : Assumes)
1123df49d1bbSPeter Collingbourne       Assume->eraseFromParent();
1124df49d1bbSPeter Collingbourne     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
1125df49d1bbSPeter Collingbourne     // may use the vtable argument later.
1126df49d1bbSPeter Collingbourne     if (CI->use_empty())
1127df49d1bbSPeter Collingbourne       CI->eraseFromParent();
1128df49d1bbSPeter Collingbourne   }
11290312f614SPeter Collingbourne }
11300312f614SPeter Collingbourne 
11310312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
11320312f614SPeter Collingbourne   Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
11330312f614SPeter Collingbourne 
11340312f614SPeter Collingbourne   for (auto I = TypeCheckedLoadFunc->use_begin(),
11350312f614SPeter Collingbourne             E = TypeCheckedLoadFunc->use_end();
11360312f614SPeter Collingbourne        I != E;) {
11370312f614SPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
11380312f614SPeter Collingbourne     ++I;
11390312f614SPeter Collingbourne     if (!CI)
11400312f614SPeter Collingbourne       continue;
11410312f614SPeter Collingbourne 
11420312f614SPeter Collingbourne     Value *Ptr = CI->getArgOperand(0);
11430312f614SPeter Collingbourne     Value *Offset = CI->getArgOperand(1);
11440312f614SPeter Collingbourne     Value *TypeIdValue = CI->getArgOperand(2);
11450312f614SPeter Collingbourne     Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
11460312f614SPeter Collingbourne 
11470312f614SPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
11480312f614SPeter Collingbourne     SmallVector<Instruction *, 1> LoadedPtrs;
11490312f614SPeter Collingbourne     SmallVector<Instruction *, 1> Preds;
11500312f614SPeter Collingbourne     bool HasNonCallUses = false;
11510312f614SPeter Collingbourne     findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
11520312f614SPeter Collingbourne                                                HasNonCallUses, CI);
11530312f614SPeter Collingbourne 
11540312f614SPeter Collingbourne     // Start by generating "pessimistic" code that explicitly loads the function
11550312f614SPeter Collingbourne     // pointer from the vtable and performs the type check. If possible, we will
11560312f614SPeter Collingbourne     // eliminate the load and the type check later.
11570312f614SPeter Collingbourne 
11580312f614SPeter Collingbourne     // If possible, only generate the load at the point where it is used.
11590312f614SPeter Collingbourne     // This helps avoid unnecessary spills.
11600312f614SPeter Collingbourne     IRBuilder<> LoadB(
11610312f614SPeter Collingbourne         (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
11620312f614SPeter Collingbourne     Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
11630312f614SPeter Collingbourne     Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
11640312f614SPeter Collingbourne     Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
11650312f614SPeter Collingbourne 
11660312f614SPeter Collingbourne     for (Instruction *LoadedPtr : LoadedPtrs) {
11670312f614SPeter Collingbourne       LoadedPtr->replaceAllUsesWith(LoadedValue);
11680312f614SPeter Collingbourne       LoadedPtr->eraseFromParent();
11690312f614SPeter Collingbourne     }
11700312f614SPeter Collingbourne 
11710312f614SPeter Collingbourne     // Likewise for the type test.
11720312f614SPeter Collingbourne     IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
11730312f614SPeter Collingbourne     CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
11740312f614SPeter Collingbourne 
11750312f614SPeter Collingbourne     for (Instruction *Pred : Preds) {
11760312f614SPeter Collingbourne       Pred->replaceAllUsesWith(TypeTestCall);
11770312f614SPeter Collingbourne       Pred->eraseFromParent();
11780312f614SPeter Collingbourne     }
11790312f614SPeter Collingbourne 
11800312f614SPeter Collingbourne     // We have already erased any extractvalue instructions that refer to the
11810312f614SPeter Collingbourne     // intrinsic call, but the intrinsic may have other non-extractvalue uses
11820312f614SPeter Collingbourne     // (although this is unlikely). In that case, explicitly build a pair and
11830312f614SPeter Collingbourne     // RAUW it.
11840312f614SPeter Collingbourne     if (!CI->use_empty()) {
11850312f614SPeter Collingbourne       Value *Pair = UndefValue::get(CI->getType());
11860312f614SPeter Collingbourne       IRBuilder<> B(CI);
11870312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
11880312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
11890312f614SPeter Collingbourne       CI->replaceAllUsesWith(Pair);
11900312f614SPeter Collingbourne     }
11910312f614SPeter Collingbourne 
11920312f614SPeter Collingbourne     // The number of unsafe uses is initially the number of uses.
11930312f614SPeter Collingbourne     auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
11940312f614SPeter Collingbourne     NumUnsafeUses = DevirtCalls.size();
11950312f614SPeter Collingbourne 
11960312f614SPeter Collingbourne     // If the function pointer has a non-call user, we cannot eliminate the type
11970312f614SPeter Collingbourne     // check, as one of those users may eventually call the pointer. Increment
11980312f614SPeter Collingbourne     // the unsafe use count to make sure it cannot reach zero.
11990312f614SPeter Collingbourne     if (HasNonCallUses)
12000312f614SPeter Collingbourne       ++NumUnsafeUses;
12010312f614SPeter Collingbourne     for (DevirtCallSite Call : DevirtCalls) {
120250cbd7ccSPeter Collingbourne       CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
120350cbd7ccSPeter Collingbourne                                                    &NumUnsafeUses);
12040312f614SPeter Collingbourne     }
12050312f614SPeter Collingbourne 
12060312f614SPeter Collingbourne     CI->eraseFromParent();
12070312f614SPeter Collingbourne   }
12080312f614SPeter Collingbourne }
12090312f614SPeter Collingbourne 
12106d284fabSPeter Collingbourne void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
12119a3f9797SPeter Collingbourne   const TypeIdSummary *TidSummary =
1212f7691d8bSPeter Collingbourne       ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString());
12139a3f9797SPeter Collingbourne   if (!TidSummary)
12149a3f9797SPeter Collingbourne     return;
12159a3f9797SPeter Collingbourne   auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
12169a3f9797SPeter Collingbourne   if (ResI == TidSummary->WPDRes.end())
12179a3f9797SPeter Collingbourne     return;
12189a3f9797SPeter Collingbourne   const WholeProgramDevirtResolution &Res = ResI->second;
12196d284fabSPeter Collingbourne 
12206d284fabSPeter Collingbourne   if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
12216d284fabSPeter Collingbourne     // The type of the function in the declaration is irrelevant because every
12226d284fabSPeter Collingbourne     // call site will cast it to the correct type.
1223*db11fdfdSMehdi Amini     auto *SingleImpl = M.getOrInsertFunction(
1224*db11fdfdSMehdi Amini         Res.SingleImplName, Type::getVoidTy(M.getContext()), nullptr);
12256d284fabSPeter Collingbourne 
12266d284fabSPeter Collingbourne     // This is the import phase so we should not be exporting anything.
12276d284fabSPeter Collingbourne     bool IsExported = false;
12286d284fabSPeter Collingbourne     applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
12296d284fabSPeter Collingbourne     assert(!IsExported);
12306d284fabSPeter Collingbourne   }
12310152c815SPeter Collingbourne 
12320152c815SPeter Collingbourne   for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
12330152c815SPeter Collingbourne     auto I = Res.ResByArg.find(CSByConstantArg.first);
12340152c815SPeter Collingbourne     if (I == Res.ResByArg.end())
12350152c815SPeter Collingbourne       continue;
12360152c815SPeter Collingbourne     auto &ResByArg = I->second;
12370152c815SPeter Collingbourne     // FIXME: We should figure out what to do about the "function name" argument
12380152c815SPeter Collingbourne     // to the apply* functions, as the function names are unavailable during the
12390152c815SPeter Collingbourne     // importing phase. For now we just pass the empty string. This does not
12400152c815SPeter Collingbourne     // impact correctness because the function names are just used for remarks.
12410152c815SPeter Collingbourne     switch (ResByArg.TheKind) {
12420152c815SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniformRetVal:
12430152c815SPeter Collingbourne       applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
12440152c815SPeter Collingbourne       break;
124559675ba0SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
124659675ba0SPeter Collingbourne       Constant *UniqueMemberAddr =
124759675ba0SPeter Collingbourne           importGlobal(Slot, CSByConstantArg.first, "unique_member");
124859675ba0SPeter Collingbourne       applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
124959675ba0SPeter Collingbourne                            UniqueMemberAddr);
125059675ba0SPeter Collingbourne       break;
125159675ba0SPeter Collingbourne     }
125214dcf02fSPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
125314dcf02fSPeter Collingbourne       Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32);
125414dcf02fSPeter Collingbourne       Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty);
125514dcf02fSPeter Collingbourne       Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8);
125614dcf02fSPeter Collingbourne       Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty);
125714dcf02fSPeter Collingbourne       applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
125814dcf02fSPeter Collingbourne     }
12590152c815SPeter Collingbourne     default:
12600152c815SPeter Collingbourne       break;
12610152c815SPeter Collingbourne     }
12620152c815SPeter Collingbourne   }
12636d284fabSPeter Collingbourne }
12646d284fabSPeter Collingbourne 
12656d284fabSPeter Collingbourne void DevirtModule::removeRedundantTypeTests() {
12666d284fabSPeter Collingbourne   auto True = ConstantInt::getTrue(M.getContext());
12676d284fabSPeter Collingbourne   for (auto &&U : NumUnsafeUsesForTypeTest) {
12686d284fabSPeter Collingbourne     if (U.second == 0) {
12696d284fabSPeter Collingbourne       U.first->replaceAllUsesWith(True);
12706d284fabSPeter Collingbourne       U.first->eraseFromParent();
12716d284fabSPeter Collingbourne     }
12726d284fabSPeter Collingbourne   }
12736d284fabSPeter Collingbourne }
12746d284fabSPeter Collingbourne 
12750312f614SPeter Collingbourne bool DevirtModule::run() {
12760312f614SPeter Collingbourne   Function *TypeTestFunc =
12770312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
12780312f614SPeter Collingbourne   Function *TypeCheckedLoadFunc =
12790312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
12800312f614SPeter Collingbourne   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
12810312f614SPeter Collingbourne 
1282b406baaeSPeter Collingbourne   // Normally if there are no users of the devirtualization intrinsics in the
1283b406baaeSPeter Collingbourne   // module, this pass has nothing to do. But if we are exporting, we also need
1284b406baaeSPeter Collingbourne   // to handle any users that appear only in the function summaries.
1285f7691d8bSPeter Collingbourne   if (!ExportSummary &&
1286b406baaeSPeter Collingbourne       (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
12870312f614SPeter Collingbourne        AssumeFunc->use_empty()) &&
12880312f614SPeter Collingbourne       (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
12890312f614SPeter Collingbourne     return false;
12900312f614SPeter Collingbourne 
12910312f614SPeter Collingbourne   if (TypeTestFunc && AssumeFunc)
12920312f614SPeter Collingbourne     scanTypeTestUsers(TypeTestFunc, AssumeFunc);
12930312f614SPeter Collingbourne 
12940312f614SPeter Collingbourne   if (TypeCheckedLoadFunc)
12950312f614SPeter Collingbourne     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
1296df49d1bbSPeter Collingbourne 
1297f7691d8bSPeter Collingbourne   if (ImportSummary) {
12986d284fabSPeter Collingbourne     for (auto &S : CallSlots)
12996d284fabSPeter Collingbourne       importResolution(S.first, S.second);
13006d284fabSPeter Collingbourne 
13016d284fabSPeter Collingbourne     removeRedundantTypeTests();
13026d284fabSPeter Collingbourne 
13036d284fabSPeter Collingbourne     // The rest of the code is only necessary when exporting or during regular
13046d284fabSPeter Collingbourne     // LTO, so we are done.
13056d284fabSPeter Collingbourne     return true;
13066d284fabSPeter Collingbourne   }
13076d284fabSPeter Collingbourne 
13087efd7506SPeter Collingbourne   // Rebuild type metadata into a map for easy lookup.
1309df49d1bbSPeter Collingbourne   std::vector<VTableBits> Bits;
13107efd7506SPeter Collingbourne   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
13117efd7506SPeter Collingbourne   buildTypeIdentifierMap(Bits, TypeIdMap);
13127efd7506SPeter Collingbourne   if (TypeIdMap.empty())
1313df49d1bbSPeter Collingbourne     return true;
1314df49d1bbSPeter Collingbourne 
1315b406baaeSPeter Collingbourne   // Collect information from summary about which calls to try to devirtualize.
1316f7691d8bSPeter Collingbourne   if (ExportSummary) {
1317b406baaeSPeter Collingbourne     DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
1318b406baaeSPeter Collingbourne     for (auto &P : TypeIdMap) {
1319b406baaeSPeter Collingbourne       if (auto *TypeId = dyn_cast<MDString>(P.first))
1320b406baaeSPeter Collingbourne         MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
1321b406baaeSPeter Collingbourne             TypeId);
1322b406baaeSPeter Collingbourne     }
1323b406baaeSPeter Collingbourne 
1324f7691d8bSPeter Collingbourne     for (auto &P : *ExportSummary) {
1325b406baaeSPeter Collingbourne       for (auto &S : P.second) {
1326b406baaeSPeter Collingbourne         auto *FS = dyn_cast<FunctionSummary>(S.get());
1327b406baaeSPeter Collingbourne         if (!FS)
1328b406baaeSPeter Collingbourne           continue;
1329b406baaeSPeter Collingbourne         // FIXME: Only add live functions.
13305d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
13315d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
13322325bb34SPeter Collingbourne             CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
13332325bb34SPeter Collingbourne                 true;
13345d8aea10SGeorge Rimar           }
13355d8aea10SGeorge Rimar         }
13365d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
13375d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
1338b406baaeSPeter Collingbourne             CallSlots[{MD, VF.Offset}]
1339b406baaeSPeter Collingbourne                 .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
13405d8aea10SGeorge Rimar           }
13415d8aea10SGeorge Rimar         }
1342b406baaeSPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
13435d8aea10SGeorge Rimar              FS->type_test_assume_const_vcalls()) {
13445d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
13452325bb34SPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
13465d8aea10SGeorge Rimar                 .ConstCSInfo[VC.Args]
13475d8aea10SGeorge Rimar                 .SummaryHasTypeTestAssumeUsers = true;
13485d8aea10SGeorge Rimar           }
13495d8aea10SGeorge Rimar         }
13502325bb34SPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
13515d8aea10SGeorge Rimar              FS->type_checked_load_const_vcalls()) {
13525d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
1353b406baaeSPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
1354b406baaeSPeter Collingbourne                 .ConstCSInfo[VC.Args]
1355b406baaeSPeter Collingbourne                 .SummaryTypeCheckedLoadUsers.push_back(FS);
1356b406baaeSPeter Collingbourne           }
1357b406baaeSPeter Collingbourne         }
1358b406baaeSPeter Collingbourne       }
13595d8aea10SGeorge Rimar     }
13605d8aea10SGeorge Rimar   }
1361b406baaeSPeter Collingbourne 
13627efd7506SPeter Collingbourne   // For each (type, offset) pair:
1363df49d1bbSPeter Collingbourne   bool DidVirtualConstProp = false;
1364f3403fd2SIvan Krasin   std::map<std::string, Function*> DevirtTargets;
1365df49d1bbSPeter Collingbourne   for (auto &S : CallSlots) {
13667efd7506SPeter Collingbourne     // Search each of the members of the type identifier for the virtual
13677efd7506SPeter Collingbourne     // function implementation at offset S.first.ByteOffset, and add to
13687efd7506SPeter Collingbourne     // TargetsForSlot.
1369df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> TargetsForSlot;
1370b406baaeSPeter Collingbourne     if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
1371b406baaeSPeter Collingbourne                                   S.first.ByteOffset)) {
13722325bb34SPeter Collingbourne       WholeProgramDevirtResolution *Res = nullptr;
1373f7691d8bSPeter Collingbourne       if (ExportSummary && isa<MDString>(S.first.TypeID))
1374f7691d8bSPeter Collingbourne         Res = &ExportSummary
13759a3f9797SPeter Collingbourne                    ->getOrInsertTypeIdSummary(
13769a3f9797SPeter Collingbourne                        cast<MDString>(S.first.TypeID)->getString())
13772325bb34SPeter Collingbourne                    .WPDRes[S.first.ByteOffset];
13782325bb34SPeter Collingbourne 
13792325bb34SPeter Collingbourne       if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
138059675ba0SPeter Collingbourne           tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
1381f3403fd2SIvan Krasin         DidVirtualConstProp = true;
1382f3403fd2SIvan Krasin 
1383f3403fd2SIvan Krasin       // Collect functions devirtualized at least for one call site for stats.
1384f3403fd2SIvan Krasin       if (RemarksEnabled)
1385f3403fd2SIvan Krasin         for (const auto &T : TargetsForSlot)
1386f3403fd2SIvan Krasin           if (T.WasDevirt)
1387f3403fd2SIvan Krasin             DevirtTargets[T.Fn->getName()] = T.Fn;
1388b05e06e4SIvan Krasin     }
1389df49d1bbSPeter Collingbourne 
1390b406baaeSPeter Collingbourne     // CFI-specific: if we are exporting and any llvm.type.checked.load
1391b406baaeSPeter Collingbourne     // intrinsics were *not* devirtualized, we need to add the resulting
1392b406baaeSPeter Collingbourne     // llvm.type.test intrinsics to the function summaries so that the
1393b406baaeSPeter Collingbourne     // LowerTypeTests pass will export them.
1394f7691d8bSPeter Collingbourne     if (ExportSummary && isa<MDString>(S.first.TypeID)) {
1395b406baaeSPeter Collingbourne       auto GUID =
1396b406baaeSPeter Collingbourne           GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
1397b406baaeSPeter Collingbourne       for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
1398b406baaeSPeter Collingbourne         FS->addTypeTest(GUID);
1399b406baaeSPeter Collingbourne       for (auto &CCS : S.second.ConstCSInfo)
1400b406baaeSPeter Collingbourne         for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
1401b406baaeSPeter Collingbourne           FS->addTypeTest(GUID);
1402b406baaeSPeter Collingbourne     }
1403b406baaeSPeter Collingbourne   }
1404b406baaeSPeter Collingbourne 
1405f3403fd2SIvan Krasin   if (RemarksEnabled) {
1406f3403fd2SIvan Krasin     // Generate remarks for each devirtualized function.
1407f3403fd2SIvan Krasin     for (const auto &DT : DevirtTargets) {
1408f3403fd2SIvan Krasin       Function *F = DT.second;
1409f3403fd2SIvan Krasin       DISubprogram *SP = F->getSubprogram();
14107bc978b5SJustin Bogner       emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,
1411f3403fd2SIvan Krasin                              Twine("devirtualized ") + F->getName());
1412b05e06e4SIvan Krasin     }
1413df49d1bbSPeter Collingbourne   }
1414df49d1bbSPeter Collingbourne 
14156d284fabSPeter Collingbourne   removeRedundantTypeTests();
14160312f614SPeter Collingbourne 
1417df49d1bbSPeter Collingbourne   // Rebuild each global we touched as part of virtual constant propagation to
1418df49d1bbSPeter Collingbourne   // include the before and after bytes.
1419df49d1bbSPeter Collingbourne   if (DidVirtualConstProp)
1420df49d1bbSPeter Collingbourne     for (VTableBits &B : Bits)
1421df49d1bbSPeter Collingbourne       rebuildGlobal(B);
1422df49d1bbSPeter Collingbourne 
1423df49d1bbSPeter Collingbourne   return true;
1424df49d1bbSPeter Collingbourne }
1425