1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2df49d1bbSPeter Collingbourne //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6df49d1bbSPeter Collingbourne //
7df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
8df49d1bbSPeter Collingbourne //
9df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases
107efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This
11df49d1bbSPeter Collingbourne // includes the following:
12df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single
13df49d1bbSPeter Collingbourne //   possible callee, replace all calls with a direct call to that callee.
14df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an
15df49d1bbSPeter Collingbourne //   integer <=64 bits and all possible callees are readnone, for each class and
16df49d1bbSPeter Collingbourne //   each list of constant arguments: evaluate the function, store the return
17df49d1bbSPeter Collingbourne //   value alongside the virtual table, and rewrite each virtual call as a load
18df49d1bbSPeter Collingbourne //   from the virtual table.
19df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant
20df49d1bbSPeter Collingbourne //   propagation hold and each function returns the same constant value, replace
21df49d1bbSPeter Collingbourne //   each virtual call with that constant.
22df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions
23df49d1bbSPeter Collingbourne //   for virtual constant propagation hold and a single vtable's function
24df49d1bbSPeter Collingbourne //   returns 0, or a single vtable's function returns 1, replace each virtual
25df49d1bbSPeter Collingbourne //   call with a comparison of the vptr against that vtable's address.
26df49d1bbSPeter Collingbourne //
27d2df54e6STeresa Johnson // This pass is intended to be used during the regular and thin LTO pipelines:
28d2df54e6STeresa Johnson //
29b406baaeSPeter Collingbourne // During regular LTO, the pass determines the best optimization for each
30b406baaeSPeter Collingbourne // virtual call and applies the resolutions directly to virtual calls that are
31b406baaeSPeter Collingbourne // eligible for virtual call optimization (i.e. calls that use either of the
32d2df54e6STeresa Johnson // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
33d2df54e6STeresa Johnson //
34d2df54e6STeresa Johnson // During hybrid Regular/ThinLTO, the pass operates in two phases:
35b406baaeSPeter Collingbourne // - Export phase: this is run during the thin link over a single merged module
36b406baaeSPeter Collingbourne //   that contains all vtables with !type metadata that participate in the link.
37b406baaeSPeter Collingbourne //   The pass computes a resolution for each virtual call and stores it in the
38b406baaeSPeter Collingbourne //   type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
42b406baaeSPeter Collingbourne //
43d2df54e6STeresa Johnson // During ThinLTO, the pass operates in two phases:
44d2df54e6STeresa Johnson // - Export phase: this is run during the thin link over the index which
45d2df54e6STeresa Johnson //   contains a summary of all vtables with !type metadata that participate in
46d2df54e6STeresa Johnson //   the link. It computes a resolution for each virtual call and stores it in
47d2df54e6STeresa Johnson //   the type identifier summary. Only single implementation devirtualization
48d2df54e6STeresa Johnson //   is supported.
49d2df54e6STeresa Johnson // - Import phase: (same as with hybrid case above).
50d2df54e6STeresa Johnson //
51df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
52df49d1bbSPeter Collingbourne 
53df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
54b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h"
55cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h"
56cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h"
57df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h"
58df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h"
59cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h"
606bda14b3SChandler Carruth #include "llvm/ADT/iterator_range.h"
6137317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h"
6237317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h"
630965da20SAdam Nemet #include "llvm/Analysis/OptimizationRemarkEmitter.h"
647efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
65df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h"
66df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h"
67df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h"
68cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h"
69cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h"
70f24136f1STeresa Johnson #include "llvm/IR/Dominators.h"
71cdc71612SEugene Zelenko #include "llvm/IR/Function.h"
72cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h"
73cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h"
74df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h"
75cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h"
76cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h"
77df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h"
78df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h"
79cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h"
80cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h"
81df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h"
822b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h"
83df49d1bbSPeter Collingbourne #include "llvm/Pass.h"
84cdc71612SEugene Zelenko #include "llvm/PassRegistry.h"
85cdc71612SEugene Zelenko #include "llvm/PassSupport.h"
86cdc71612SEugene Zelenko #include "llvm/Support/Casting.h"
872b33f653SPeter Collingbourne #include "llvm/Support/Error.h"
882b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h"
89cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h"
90b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h"
9137317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h"
92df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h"
93cdc71612SEugene Zelenko #include <algorithm>
94cdc71612SEugene Zelenko #include <cstddef>
95cdc71612SEugene Zelenko #include <map>
96df49d1bbSPeter Collingbourne #include <set>
97cdc71612SEugene Zelenko #include <string>
98df49d1bbSPeter Collingbourne 
99df49d1bbSPeter Collingbourne using namespace llvm;
100df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt;
101df49d1bbSPeter Collingbourne 
102df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt"
103df49d1bbSPeter Collingbourne 
1042b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction(
1052b33f653SPeter Collingbourne     "wholeprogramdevirt-summary-action",
1062b33f653SPeter Collingbourne     cl::desc("What to do with the summary when running this pass"),
1072b33f653SPeter Collingbourne     cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
1082b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Import, "import",
1092b33f653SPeter Collingbourne                           "Import typeid resolutions from summary and globals"),
1102b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Export, "export",
1112b33f653SPeter Collingbourne                           "Export typeid resolutions to summary and globals")),
1122b33f653SPeter Collingbourne     cl::Hidden);
1132b33f653SPeter Collingbourne 
1142b33f653SPeter Collingbourne static cl::opt<std::string> ClReadSummary(
1152b33f653SPeter Collingbourne     "wholeprogramdevirt-read-summary",
1162b33f653SPeter Collingbourne     cl::desc("Read summary from given YAML file before running pass"),
1172b33f653SPeter Collingbourne     cl::Hidden);
1182b33f653SPeter Collingbourne 
1192b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary(
1202b33f653SPeter Collingbourne     "wholeprogramdevirt-write-summary",
1212b33f653SPeter Collingbourne     cl::desc("Write summary to given YAML file after running pass"),
1222b33f653SPeter Collingbourne     cl::Hidden);
1232b33f653SPeter Collingbourne 
1249cb59b92SVitaly Buka static cl::opt<unsigned>
1259cb59b92SVitaly Buka     ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
1269cb59b92SVitaly Buka                 cl::init(10), cl::ZeroOrMore,
12766f53d71SVitaly Buka                 cl::desc("Maximum number of call targets per "
12866f53d71SVitaly Buka                          "call site to enable branch funnels"));
12966f53d71SVitaly Buka 
130d2df54e6STeresa Johnson static cl::opt<bool>
131d2df54e6STeresa Johnson     PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
132d2df54e6STeresa Johnson                        cl::init(false), cl::ZeroOrMore,
133d2df54e6STeresa Johnson                        cl::desc("Print index-based devirtualization messages"));
134d2df54e6STeresa Johnson 
135df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If
136df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an
137df49d1bbSPeter Collingbourne // offset after the object.
138df49d1bbSPeter Collingbourne uint64_t
139df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
140df49d1bbSPeter Collingbourne                                      bool IsAfter, uint64_t Size) {
141df49d1bbSPeter Collingbourne   // Find a minimum offset taking into account only vtable sizes.
142df49d1bbSPeter Collingbourne   uint64_t MinByte = 0;
143df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
144df49d1bbSPeter Collingbourne     if (IsAfter)
145df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minAfterBytes());
146df49d1bbSPeter Collingbourne     else
147df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minBeforeBytes());
148df49d1bbSPeter Collingbourne   }
149df49d1bbSPeter Collingbourne 
150df49d1bbSPeter Collingbourne   // Build a vector of arrays of bytes covering, for each target, a slice of the
151df49d1bbSPeter Collingbourne   // used region (see AccumBitVector::BytesUsed in
152df49d1bbSPeter Collingbourne   // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
153df49d1bbSPeter Collingbourne   // this aligns the used regions to start at MinByte.
154df49d1bbSPeter Collingbourne   //
155df49d1bbSPeter Collingbourne   // In this example, A, B and C are vtables, # is a byte already allocated for
156df49d1bbSPeter Collingbourne   // a virtual function pointer, AAAA... (etc.) are the used regions for the
157df49d1bbSPeter Collingbourne   // vtables and Offset(X) is the value computed for the Offset variable below
158df49d1bbSPeter Collingbourne   // for X.
159df49d1bbSPeter Collingbourne   //
160df49d1bbSPeter Collingbourne   //                    Offset(A)
161df49d1bbSPeter Collingbourne   //                    |       |
162df49d1bbSPeter Collingbourne   //                            |MinByte
163df49d1bbSPeter Collingbourne   // A: ################AAAAAAAA|AAAAAAAA
164df49d1bbSPeter Collingbourne   // B: ########BBBBBBBBBBBBBBBB|BBBB
165df49d1bbSPeter Collingbourne   // C: ########################|CCCCCCCCCCCCCCCC
166df49d1bbSPeter Collingbourne   //            |   Offset(B)   |
167df49d1bbSPeter Collingbourne   //
168df49d1bbSPeter Collingbourne   // This code produces the slices of A, B and C that appear after the divider
169df49d1bbSPeter Collingbourne   // at MinByte.
170df49d1bbSPeter Collingbourne   std::vector<ArrayRef<uint8_t>> Used;
171df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
1727efd7506SPeter Collingbourne     ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
1737efd7506SPeter Collingbourne                                        : Target.TM->Bits->Before.BytesUsed;
174df49d1bbSPeter Collingbourne     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
175df49d1bbSPeter Collingbourne                               : MinByte - Target.minBeforeBytes();
176df49d1bbSPeter Collingbourne 
177df49d1bbSPeter Collingbourne     // Disregard used regions that are smaller than Offset. These are
178df49d1bbSPeter Collingbourne     // effectively all-free regions that do not need to be checked.
179df49d1bbSPeter Collingbourne     if (VTUsed.size() > Offset)
180df49d1bbSPeter Collingbourne       Used.push_back(VTUsed.slice(Offset));
181df49d1bbSPeter Collingbourne   }
182df49d1bbSPeter Collingbourne 
183df49d1bbSPeter Collingbourne   if (Size == 1) {
184df49d1bbSPeter Collingbourne     // Find a free bit in each member of Used.
185df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
186df49d1bbSPeter Collingbourne       uint8_t BitsUsed = 0;
187df49d1bbSPeter Collingbourne       for (auto &&B : Used)
188df49d1bbSPeter Collingbourne         if (I < B.size())
189df49d1bbSPeter Collingbourne           BitsUsed |= B[I];
190df49d1bbSPeter Collingbourne       if (BitsUsed != 0xff)
191df49d1bbSPeter Collingbourne         return (MinByte + I) * 8 +
192df49d1bbSPeter Collingbourne                countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
193df49d1bbSPeter Collingbourne     }
194df49d1bbSPeter Collingbourne   } else {
195df49d1bbSPeter Collingbourne     // Find a free (Size/8) byte region in each member of Used.
196df49d1bbSPeter Collingbourne     // FIXME: see if alignment helps.
197df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
198df49d1bbSPeter Collingbourne       for (auto &&B : Used) {
199df49d1bbSPeter Collingbourne         unsigned Byte = 0;
200df49d1bbSPeter Collingbourne         while ((I + Byte) < B.size() && Byte < (Size / 8)) {
201df49d1bbSPeter Collingbourne           if (B[I + Byte])
202df49d1bbSPeter Collingbourne             goto NextI;
203df49d1bbSPeter Collingbourne           ++Byte;
204df49d1bbSPeter Collingbourne         }
205df49d1bbSPeter Collingbourne       }
206df49d1bbSPeter Collingbourne       return (MinByte + I) * 8;
207df49d1bbSPeter Collingbourne     NextI:;
208df49d1bbSPeter Collingbourne     }
209df49d1bbSPeter Collingbourne   }
210df49d1bbSPeter Collingbourne }
211df49d1bbSPeter Collingbourne 
212df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues(
213df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
214df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
215df49d1bbSPeter Collingbourne   if (BitWidth == 1)
216df49d1bbSPeter Collingbourne     OffsetByte = -(AllocBefore / 8 + 1);
217df49d1bbSPeter Collingbourne   else
218df49d1bbSPeter Collingbourne     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
219df49d1bbSPeter Collingbourne   OffsetBit = AllocBefore % 8;
220df49d1bbSPeter Collingbourne 
221df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
222df49d1bbSPeter Collingbourne     if (BitWidth == 1)
223df49d1bbSPeter Collingbourne       Target.setBeforeBit(AllocBefore);
224df49d1bbSPeter Collingbourne     else
225df49d1bbSPeter Collingbourne       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
226df49d1bbSPeter Collingbourne   }
227df49d1bbSPeter Collingbourne }
228df49d1bbSPeter Collingbourne 
229df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues(
230df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
231df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
232df49d1bbSPeter Collingbourne   if (BitWidth == 1)
233df49d1bbSPeter Collingbourne     OffsetByte = AllocAfter / 8;
234df49d1bbSPeter Collingbourne   else
235df49d1bbSPeter Collingbourne     OffsetByte = (AllocAfter + 7) / 8;
236df49d1bbSPeter Collingbourne   OffsetBit = AllocAfter % 8;
237df49d1bbSPeter Collingbourne 
238df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
239df49d1bbSPeter Collingbourne     if (BitWidth == 1)
240df49d1bbSPeter Collingbourne       Target.setAfterBit(AllocAfter);
241df49d1bbSPeter Collingbourne     else
242df49d1bbSPeter Collingbourne       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
243df49d1bbSPeter Collingbourne   }
244df49d1bbSPeter Collingbourne }
245df49d1bbSPeter Collingbourne 
2467efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
2477efd7506SPeter Collingbourne     : Fn(Fn), TM(TM),
24889439a79SIvan Krasin       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
249df49d1bbSPeter Collingbourne 
250df49d1bbSPeter Collingbourne namespace {
251df49d1bbSPeter Collingbourne 
2527efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual
253df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to
254df49d1bbSPeter Collingbourne // the virtual function pointer.
255df49d1bbSPeter Collingbourne struct VTableSlot {
2567efd7506SPeter Collingbourne   Metadata *TypeID;
257df49d1bbSPeter Collingbourne   uint64_t ByteOffset;
258df49d1bbSPeter Collingbourne };
259df49d1bbSPeter Collingbourne 
260cdc71612SEugene Zelenko } // end anonymous namespace
261df49d1bbSPeter Collingbourne 
2629b656527SPeter Collingbourne namespace llvm {
2639b656527SPeter Collingbourne 
264df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> {
265df49d1bbSPeter Collingbourne   static VTableSlot getEmptyKey() {
266df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getEmptyKey(),
267df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getEmptyKey()};
268df49d1bbSPeter Collingbourne   }
269df49d1bbSPeter Collingbourne   static VTableSlot getTombstoneKey() {
270df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
271df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getTombstoneKey()};
272df49d1bbSPeter Collingbourne   }
273df49d1bbSPeter Collingbourne   static unsigned getHashValue(const VTableSlot &I) {
2747efd7506SPeter Collingbourne     return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
275df49d1bbSPeter Collingbourne            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
276df49d1bbSPeter Collingbourne   }
277df49d1bbSPeter Collingbourne   static bool isEqual(const VTableSlot &LHS,
278df49d1bbSPeter Collingbourne                       const VTableSlot &RHS) {
2797efd7506SPeter Collingbourne     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
280df49d1bbSPeter Collingbourne   }
281df49d1bbSPeter Collingbourne };
282df49d1bbSPeter Collingbourne 
283d2df54e6STeresa Johnson template <> struct DenseMapInfo<VTableSlotSummary> {
284d2df54e6STeresa Johnson   static VTableSlotSummary getEmptyKey() {
285d2df54e6STeresa Johnson     return {DenseMapInfo<StringRef>::getEmptyKey(),
286d2df54e6STeresa Johnson             DenseMapInfo<uint64_t>::getEmptyKey()};
287d2df54e6STeresa Johnson   }
288d2df54e6STeresa Johnson   static VTableSlotSummary getTombstoneKey() {
289d2df54e6STeresa Johnson     return {DenseMapInfo<StringRef>::getTombstoneKey(),
290d2df54e6STeresa Johnson             DenseMapInfo<uint64_t>::getTombstoneKey()};
291d2df54e6STeresa Johnson   }
292d2df54e6STeresa Johnson   static unsigned getHashValue(const VTableSlotSummary &I) {
293d2df54e6STeresa Johnson     return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
294d2df54e6STeresa Johnson            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
295d2df54e6STeresa Johnson   }
296d2df54e6STeresa Johnson   static bool isEqual(const VTableSlotSummary &LHS,
297d2df54e6STeresa Johnson                       const VTableSlotSummary &RHS) {
298d2df54e6STeresa Johnson     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
299d2df54e6STeresa Johnson   }
300d2df54e6STeresa Johnson };
301d2df54e6STeresa Johnson 
302cdc71612SEugene Zelenko } // end namespace llvm
3039b656527SPeter Collingbourne 
304df49d1bbSPeter Collingbourne namespace {
305df49d1bbSPeter Collingbourne 
306df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is
307df49d1bbSPeter Collingbourne // the indirect virtual call.
308df49d1bbSPeter Collingbourne struct VirtualCallSite {
309df49d1bbSPeter Collingbourne   Value *VTable;
310df49d1bbSPeter Collingbourne   CallSite CS;
311df49d1bbSPeter Collingbourne 
3120312f614SPeter Collingbourne   // If non-null, this field points to the associated unsafe use count stored in
3130312f614SPeter Collingbourne   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
3140312f614SPeter Collingbourne   // of that field for details.
3150312f614SPeter Collingbourne   unsigned *NumUnsafeUses;
3160312f614SPeter Collingbourne 
317e963c89dSSam Elliott   void
318e963c89dSSam Elliott   emitRemark(const StringRef OptName, const StringRef TargetName,
319e963c89dSSam Elliott              function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
3205474645dSIvan Krasin     Function *F = CS.getCaller();
321e963c89dSSam Elliott     DebugLoc DLoc = CS->getDebugLoc();
322e963c89dSSam Elliott     BasicBlock *Block = CS.getParent();
323e963c89dSSam Elliott 
324e963c89dSSam Elliott     using namespace ore;
3259110cb45SPeter Collingbourne     OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
3269110cb45SPeter Collingbourne                       << NV("Optimization", OptName)
3279110cb45SPeter Collingbourne                       << ": devirtualized a call to "
328e963c89dSSam Elliott                       << NV("FunctionName", TargetName));
329e963c89dSSam Elliott   }
330e963c89dSSam Elliott 
331e963c89dSSam Elliott   void replaceAndErase(
332e963c89dSSam Elliott       const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
333e963c89dSSam Elliott       function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
334e963c89dSSam Elliott       Value *New) {
335f3403fd2SIvan Krasin     if (RemarksEnabled)
336e963c89dSSam Elliott       emitRemark(OptName, TargetName, OREGetter);
337df49d1bbSPeter Collingbourne     CS->replaceAllUsesWith(New);
338df49d1bbSPeter Collingbourne     if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
339df49d1bbSPeter Collingbourne       BranchInst::Create(II->getNormalDest(), CS.getInstruction());
340df49d1bbSPeter Collingbourne       II->getUnwindDest()->removePredecessor(II->getParent());
341df49d1bbSPeter Collingbourne     }
342df49d1bbSPeter Collingbourne     CS->eraseFromParent();
3430312f614SPeter Collingbourne     // This use is no longer unsafe.
3440312f614SPeter Collingbourne     if (NumUnsafeUses)
3450312f614SPeter Collingbourne       --*NumUnsafeUses;
346df49d1bbSPeter Collingbourne   }
347df49d1bbSPeter Collingbourne };
348df49d1bbSPeter Collingbourne 
34950cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list
35050cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the
35150cbd7ccSPeter Collingbourne // VTableSlotInfo class.
35250cbd7ccSPeter Collingbourne struct CallSiteInfo {
353b406baaeSPeter Collingbourne   /// The set of call sites for this slot. Used during regular LTO and the
354b406baaeSPeter Collingbourne   /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
355b406baaeSPeter Collingbourne   /// call sites that appear in the merged module itself); in each of these
356b406baaeSPeter Collingbourne   /// cases we are directly operating on the call sites at the IR level.
35750cbd7ccSPeter Collingbourne   std::vector<VirtualCallSite> CallSites;
358b406baaeSPeter Collingbourne 
3592974856aSPeter Collingbourne   /// Whether all call sites represented by this CallSiteInfo, including those
3602974856aSPeter Collingbourne   /// in summaries, have been devirtualized. This starts off as true because a
3612974856aSPeter Collingbourne   /// default constructed CallSiteInfo represents no call sites.
3622974856aSPeter Collingbourne   bool AllCallSitesDevirted = true;
3632974856aSPeter Collingbourne 
364b406baaeSPeter Collingbourne   // These fields are used during the export phase of ThinLTO and reflect
365b406baaeSPeter Collingbourne   // information collected from function summaries.
366b406baaeSPeter Collingbourne 
3672325bb34SPeter Collingbourne   /// Whether any function summary contains an llvm.assume(llvm.type.test) for
3682325bb34SPeter Collingbourne   /// this slot.
3692974856aSPeter Collingbourne   bool SummaryHasTypeTestAssumeUsers = false;
3702325bb34SPeter Collingbourne 
371b406baaeSPeter Collingbourne   /// CFI-specific: a vector containing the list of function summaries that use
372b406baaeSPeter Collingbourne   /// the llvm.type.checked.load intrinsic and therefore will require
373b406baaeSPeter Collingbourne   /// resolutions for llvm.type.test in order to implement CFI checks if
374b406baaeSPeter Collingbourne   /// devirtualization was unsuccessful. If devirtualization was successful, the
37559675ba0SPeter Collingbourne   /// pass will clear this vector by calling markDevirt(). If at the end of the
37659675ba0SPeter Collingbourne   /// pass the vector is non-empty, we will need to add a use of llvm.type.test
37759675ba0SPeter Collingbourne   /// to each of the function summaries in the vector.
378b406baaeSPeter Collingbourne   std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
379d2df54e6STeresa Johnson   std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;
3802325bb34SPeter Collingbourne 
3812325bb34SPeter Collingbourne   bool isExported() const {
3822325bb34SPeter Collingbourne     return SummaryHasTypeTestAssumeUsers ||
3832325bb34SPeter Collingbourne            !SummaryTypeCheckedLoadUsers.empty();
3842325bb34SPeter Collingbourne   }
38559675ba0SPeter Collingbourne 
3862974856aSPeter Collingbourne   void markSummaryHasTypeTestAssumeUsers() {
3872974856aSPeter Collingbourne     SummaryHasTypeTestAssumeUsers = true;
3882974856aSPeter Collingbourne     AllCallSitesDevirted = false;
3892974856aSPeter Collingbourne   }
3902974856aSPeter Collingbourne 
3912974856aSPeter Collingbourne   void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
3922974856aSPeter Collingbourne     SummaryTypeCheckedLoadUsers.push_back(FS);
3932974856aSPeter Collingbourne     AllCallSitesDevirted = false;
3942974856aSPeter Collingbourne   }
3952974856aSPeter Collingbourne 
396d2df54e6STeresa Johnson   void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
397d2df54e6STeresa Johnson     SummaryTypeTestAssumeUsers.push_back(FS);
398d2df54e6STeresa Johnson     markSummaryHasTypeTestAssumeUsers();
399d2df54e6STeresa Johnson   }
400d2df54e6STeresa Johnson 
4012974856aSPeter Collingbourne   void markDevirt() {
4022974856aSPeter Collingbourne     AllCallSitesDevirted = true;
4032974856aSPeter Collingbourne 
4042974856aSPeter Collingbourne     // As explained in the comment for SummaryTypeCheckedLoadUsers.
4052974856aSPeter Collingbourne     SummaryTypeCheckedLoadUsers.clear();
4062974856aSPeter Collingbourne   }
40750cbd7ccSPeter Collingbourne };
40850cbd7ccSPeter Collingbourne 
40950cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot.
41050cbd7ccSPeter Collingbourne struct VTableSlotInfo {
41150cbd7ccSPeter Collingbourne   // The set of call sites which do not have all constant integer arguments
41250cbd7ccSPeter Collingbourne   // (excluding "this").
41350cbd7ccSPeter Collingbourne   CallSiteInfo CSInfo;
41450cbd7ccSPeter Collingbourne 
41550cbd7ccSPeter Collingbourne   // The set of call sites with all constant integer arguments (excluding
41650cbd7ccSPeter Collingbourne   // "this"), grouped by argument list.
41750cbd7ccSPeter Collingbourne   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
41850cbd7ccSPeter Collingbourne 
41950cbd7ccSPeter Collingbourne   void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
42050cbd7ccSPeter Collingbourne 
42150cbd7ccSPeter Collingbourne private:
42250cbd7ccSPeter Collingbourne   CallSiteInfo &findCallSiteInfo(CallSite CS);
42350cbd7ccSPeter Collingbourne };
42450cbd7ccSPeter Collingbourne 
42550cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
42650cbd7ccSPeter Collingbourne   std::vector<uint64_t> Args;
42750cbd7ccSPeter Collingbourne   auto *CI = dyn_cast<IntegerType>(CS.getType());
42850cbd7ccSPeter Collingbourne   if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
42950cbd7ccSPeter Collingbourne     return CSInfo;
43050cbd7ccSPeter Collingbourne   for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
43150cbd7ccSPeter Collingbourne     auto *CI = dyn_cast<ConstantInt>(Arg);
43250cbd7ccSPeter Collingbourne     if (!CI || CI->getBitWidth() > 64)
43350cbd7ccSPeter Collingbourne       return CSInfo;
43450cbd7ccSPeter Collingbourne     Args.push_back(CI->getZExtValue());
43550cbd7ccSPeter Collingbourne   }
43650cbd7ccSPeter Collingbourne   return ConstCSInfo[Args];
43750cbd7ccSPeter Collingbourne }
43850cbd7ccSPeter Collingbourne 
43950cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
44050cbd7ccSPeter Collingbourne                                  unsigned *NumUnsafeUses) {
4412974856aSPeter Collingbourne   auto &CSI = findCallSiteInfo(CS);
4422974856aSPeter Collingbourne   CSI.AllCallSitesDevirted = false;
4432974856aSPeter Collingbourne   CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
44450cbd7ccSPeter Collingbourne }
44550cbd7ccSPeter Collingbourne 
446df49d1bbSPeter Collingbourne struct DevirtModule {
447df49d1bbSPeter Collingbourne   Module &M;
44837317f12SPeter Collingbourne   function_ref<AAResults &(Function &)> AARGetter;
449f24136f1STeresa Johnson   function_ref<DominatorTree &(Function &)> LookupDomTree;
4502b33f653SPeter Collingbourne 
451f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
452f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
4532b33f653SPeter Collingbourne 
454df49d1bbSPeter Collingbourne   IntegerType *Int8Ty;
455df49d1bbSPeter Collingbourne   PointerType *Int8PtrTy;
456df49d1bbSPeter Collingbourne   IntegerType *Int32Ty;
45750cbd7ccSPeter Collingbourne   IntegerType *Int64Ty;
45814dcf02fSPeter Collingbourne   IntegerType *IntPtrTy;
459df49d1bbSPeter Collingbourne 
460f3403fd2SIvan Krasin   bool RemarksEnabled;
461e963c89dSSam Elliott   function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
462f3403fd2SIvan Krasin 
46350cbd7ccSPeter Collingbourne   MapVector<VTableSlot, VTableSlotInfo> CallSlots;
464df49d1bbSPeter Collingbourne 
4650312f614SPeter Collingbourne   // This map keeps track of the number of "unsafe" uses of a loaded function
4660312f614SPeter Collingbourne   // pointer. The key is the associated llvm.type.test intrinsic call generated
4670312f614SPeter Collingbourne   // by this pass. An unsafe use is one that calls the loaded function pointer
4680312f614SPeter Collingbourne   // directly. Every time we eliminate an unsafe use (for example, by
4690312f614SPeter Collingbourne   // devirtualizing it or by applying virtual constant propagation), we
4700312f614SPeter Collingbourne   // decrement the value stored in this map. If a value reaches zero, we can
4710312f614SPeter Collingbourne   // eliminate the type check by RAUWing the associated llvm.type.test call with
4720312f614SPeter Collingbourne   // true.
4730312f614SPeter Collingbourne   std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
4740312f614SPeter Collingbourne 
47537317f12SPeter Collingbourne   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
476e963c89dSSam Elliott                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
477f24136f1STeresa Johnson                function_ref<DominatorTree &(Function &)> LookupDomTree,
478f7691d8bSPeter Collingbourne                ModuleSummaryIndex *ExportSummary,
479f7691d8bSPeter Collingbourne                const ModuleSummaryIndex *ImportSummary)
480f24136f1STeresa Johnson       : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
481f24136f1STeresa Johnson         ExportSummary(ExportSummary), ImportSummary(ImportSummary),
482f24136f1STeresa Johnson         Int8Ty(Type::getInt8Ty(M.getContext())),
483df49d1bbSPeter Collingbourne         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
484f3403fd2SIvan Krasin         Int32Ty(Type::getInt32Ty(M.getContext())),
48550cbd7ccSPeter Collingbourne         Int64Ty(Type::getInt64Ty(M.getContext())),
48614dcf02fSPeter Collingbourne         IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
487e963c89dSSam Elliott         RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
488f7691d8bSPeter Collingbourne     assert(!(ExportSummary && ImportSummary));
489f7691d8bSPeter Collingbourne   }
490f3403fd2SIvan Krasin 
491f3403fd2SIvan Krasin   bool areRemarksEnabled();
492df49d1bbSPeter Collingbourne 
4930312f614SPeter Collingbourne   void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
4940312f614SPeter Collingbourne   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
4950312f614SPeter Collingbourne 
4967efd7506SPeter Collingbourne   void buildTypeIdentifierMap(
4977efd7506SPeter Collingbourne       std::vector<VTableBits> &Bits,
4987efd7506SPeter Collingbourne       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
4998786754cSPeter Collingbourne   Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
5007efd7506SPeter Collingbourne   bool
5017efd7506SPeter Collingbourne   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
5027efd7506SPeter Collingbourne                             const std::set<TypeMemberInfo> &TypeMemberInfos,
503df49d1bbSPeter Collingbourne                             uint64_t ByteOffset);
50450cbd7ccSPeter Collingbourne 
5052325bb34SPeter Collingbourne   void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
5062325bb34SPeter Collingbourne                              bool &IsExported);
507f3403fd2SIvan Krasin   bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
5082325bb34SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
5092325bb34SPeter Collingbourne                            WholeProgramDevirtResolution *Res);
51050cbd7ccSPeter Collingbourne 
5112974856aSPeter Collingbourne   void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
5122974856aSPeter Collingbourne                               bool &IsExported);
5132974856aSPeter Collingbourne   void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
5142974856aSPeter Collingbourne                             VTableSlotInfo &SlotInfo,
5152974856aSPeter Collingbourne                             WholeProgramDevirtResolution *Res, VTableSlot Slot);
5162974856aSPeter Collingbourne 
517df49d1bbSPeter Collingbourne   bool tryEvaluateFunctionsWithArgs(
518df49d1bbSPeter Collingbourne       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
51950cbd7ccSPeter Collingbourne       ArrayRef<uint64_t> Args);
52050cbd7ccSPeter Collingbourne 
52150cbd7ccSPeter Collingbourne   void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
52250cbd7ccSPeter Collingbourne                              uint64_t TheRetVal);
52350cbd7ccSPeter Collingbourne   bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
52477a8d563SPeter Collingbourne                            CallSiteInfo &CSInfo,
52577a8d563SPeter Collingbourne                            WholeProgramDevirtResolution::ByArg *Res);
52650cbd7ccSPeter Collingbourne 
52759675ba0SPeter Collingbourne   // Returns the global symbol name that is used to export information about the
52859675ba0SPeter Collingbourne   // given vtable slot and list of arguments.
52959675ba0SPeter Collingbourne   std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
53059675ba0SPeter Collingbourne                             StringRef Name);
53159675ba0SPeter Collingbourne 
532b15a35e6SPeter Collingbourne   bool shouldExportConstantsAsAbsoluteSymbols();
533b15a35e6SPeter Collingbourne 
53459675ba0SPeter Collingbourne   // This function is called during the export phase to create a symbol
53559675ba0SPeter Collingbourne   // definition containing information about the given vtable slot and list of
53659675ba0SPeter Collingbourne   // arguments.
53759675ba0SPeter Collingbourne   void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
53859675ba0SPeter Collingbourne                     Constant *C);
539b15a35e6SPeter Collingbourne   void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
540b15a35e6SPeter Collingbourne                       uint32_t Const, uint32_t &Storage);
54159675ba0SPeter Collingbourne 
54259675ba0SPeter Collingbourne   // This function is called during the import phase to create a reference to
54359675ba0SPeter Collingbourne   // the symbol definition created during the export phase.
54459675ba0SPeter Collingbourne   Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
545b15a35e6SPeter Collingbourne                          StringRef Name);
546b15a35e6SPeter Collingbourne   Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
547b15a35e6SPeter Collingbourne                            StringRef Name, IntegerType *IntTy,
548b15a35e6SPeter Collingbourne                            uint32_t Storage);
54959675ba0SPeter Collingbourne 
5502974856aSPeter Collingbourne   Constant *getMemberAddr(const TypeMemberInfo *M);
5512974856aSPeter Collingbourne 
55250cbd7ccSPeter Collingbourne   void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
55350cbd7ccSPeter Collingbourne                             Constant *UniqueMemberAddr);
554df49d1bbSPeter Collingbourne   bool tryUniqueRetValOpt(unsigned BitWidth,
555f3403fd2SIvan Krasin                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
55659675ba0SPeter Collingbourne                           CallSiteInfo &CSInfo,
55759675ba0SPeter Collingbourne                           WholeProgramDevirtResolution::ByArg *Res,
55859675ba0SPeter Collingbourne                           VTableSlot Slot, ArrayRef<uint64_t> Args);
55950cbd7ccSPeter Collingbourne 
56050cbd7ccSPeter Collingbourne   void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
56150cbd7ccSPeter Collingbourne                              Constant *Byte, Constant *Bit);
562df49d1bbSPeter Collingbourne   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
56377a8d563SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
56459675ba0SPeter Collingbourne                            WholeProgramDevirtResolution *Res, VTableSlot Slot);
565df49d1bbSPeter Collingbourne 
566df49d1bbSPeter Collingbourne   void rebuildGlobal(VTableBits &B);
567df49d1bbSPeter Collingbourne 
5686d284fabSPeter Collingbourne   // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
5696d284fabSPeter Collingbourne   void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);
5706d284fabSPeter Collingbourne 
5716d284fabSPeter Collingbourne   // If we were able to eliminate all unsafe uses for a type checked load,
5726d284fabSPeter Collingbourne   // eliminate the associated type tests by replacing them with true.
5736d284fabSPeter Collingbourne   void removeRedundantTypeTests();
5746d284fabSPeter Collingbourne 
575df49d1bbSPeter Collingbourne   bool run();
5762b33f653SPeter Collingbourne 
5772b33f653SPeter Collingbourne   // Lower the module using the action and summary passed as command line
5782b33f653SPeter Collingbourne   // arguments. For testing purposes only.
579f24136f1STeresa Johnson   static bool
580f24136f1STeresa Johnson   runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
581f24136f1STeresa Johnson                 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
582f24136f1STeresa Johnson                 function_ref<DominatorTree &(Function &)> LookupDomTree);
583df49d1bbSPeter Collingbourne };
584df49d1bbSPeter Collingbourne 
585d2df54e6STeresa Johnson struct DevirtIndex {
586d2df54e6STeresa Johnson   ModuleSummaryIndex &ExportSummary;
587d2df54e6STeresa Johnson   // The set in which to record GUIDs exported from their module by
588d2df54e6STeresa Johnson   // devirtualization, used by client to ensure they are not internalized.
589d2df54e6STeresa Johnson   std::set<GlobalValue::GUID> &ExportedGUIDs;
590d2df54e6STeresa Johnson   // A map in which to record the information necessary to locate the WPD
591d2df54e6STeresa Johnson   // resolution for local targets in case they are exported by cross module
592d2df54e6STeresa Johnson   // importing.
593d2df54e6STeresa Johnson   std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;
594d2df54e6STeresa Johnson 
595d2df54e6STeresa Johnson   MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;
596d2df54e6STeresa Johnson 
597d2df54e6STeresa Johnson   DevirtIndex(
598d2df54e6STeresa Johnson       ModuleSummaryIndex &ExportSummary,
599d2df54e6STeresa Johnson       std::set<GlobalValue::GUID> &ExportedGUIDs,
600d2df54e6STeresa Johnson       std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
601d2df54e6STeresa Johnson       : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
602d2df54e6STeresa Johnson         LocalWPDTargetsMap(LocalWPDTargetsMap) {}
603d2df54e6STeresa Johnson 
604d2df54e6STeresa Johnson   bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
605d2df54e6STeresa Johnson                                  const TypeIdCompatibleVtableInfo TIdInfo,
606d2df54e6STeresa Johnson                                  uint64_t ByteOffset);
607d2df54e6STeresa Johnson 
608d2df54e6STeresa Johnson   bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
609d2df54e6STeresa Johnson                            VTableSlotSummary &SlotSummary,
610d2df54e6STeresa Johnson                            VTableSlotInfo &SlotInfo,
611d2df54e6STeresa Johnson                            WholeProgramDevirtResolution *Res,
612d2df54e6STeresa Johnson                            std::set<ValueInfo> &DevirtTargets);
613d2df54e6STeresa Johnson 
614d2df54e6STeresa Johnson   void run();
615d2df54e6STeresa Johnson };
616d2df54e6STeresa Johnson 
617df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass {
618df49d1bbSPeter Collingbourne   static char ID;
619cdc71612SEugene Zelenko 
6202b33f653SPeter Collingbourne   bool UseCommandLine = false;
6212b33f653SPeter Collingbourne 
622f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
623f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
6242b33f653SPeter Collingbourne 
6252b33f653SPeter Collingbourne   WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
6262b33f653SPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
6272b33f653SPeter Collingbourne   }
6282b33f653SPeter Collingbourne 
629f7691d8bSPeter Collingbourne   WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
630f7691d8bSPeter Collingbourne                      const ModuleSummaryIndex *ImportSummary)
631f7691d8bSPeter Collingbourne       : ModulePass(ID), ExportSummary(ExportSummary),
632f7691d8bSPeter Collingbourne         ImportSummary(ImportSummary) {
633df49d1bbSPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
634df49d1bbSPeter Collingbourne   }
635cdc71612SEugene Zelenko 
636cdc71612SEugene Zelenko   bool runOnModule(Module &M) override {
637aa641a51SAndrew Kaylor     if (skipModule(M))
638aa641a51SAndrew Kaylor       return false;
639e963c89dSSam Elliott 
6409110cb45SPeter Collingbourne     // In the new pass manager, we can request the optimization
6419110cb45SPeter Collingbourne     // remark emitter pass on a per-function-basis, which the
6429110cb45SPeter Collingbourne     // OREGetter will do for us.
6439110cb45SPeter Collingbourne     // In the old pass manager, this is harder, so we just build
6449110cb45SPeter Collingbourne     // an optimization remark emitter on the fly, when we need it.
6459110cb45SPeter Collingbourne     std::unique_ptr<OptimizationRemarkEmitter> ORE;
6469110cb45SPeter Collingbourne     auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
647*0eaee545SJonas Devlieghere       ORE = std::make_unique<OptimizationRemarkEmitter>(F);
6489110cb45SPeter Collingbourne       return *ORE;
6499110cb45SPeter Collingbourne     };
650e963c89dSSam Elliott 
651f24136f1STeresa Johnson     auto LookupDomTree = [this](Function &F) -> DominatorTree & {
652f24136f1STeresa Johnson       return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
653f24136f1STeresa Johnson     };
654e963c89dSSam Elliott 
655f24136f1STeresa Johnson     if (UseCommandLine)
656f24136f1STeresa Johnson       return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter,
657f24136f1STeresa Johnson                                          LookupDomTree);
658f24136f1STeresa Johnson 
659f24136f1STeresa Johnson     return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree,
660f24136f1STeresa Johnson                         ExportSummary, ImportSummary)
661f7691d8bSPeter Collingbourne         .run();
66237317f12SPeter Collingbourne   }
66337317f12SPeter Collingbourne 
66437317f12SPeter Collingbourne   void getAnalysisUsage(AnalysisUsage &AU) const override {
66537317f12SPeter Collingbourne     AU.addRequired<AssumptionCacheTracker>();
66637317f12SPeter Collingbourne     AU.addRequired<TargetLibraryInfoWrapperPass>();
667f24136f1STeresa Johnson     AU.addRequired<DominatorTreeWrapperPass>();
668aa641a51SAndrew Kaylor   }
669df49d1bbSPeter Collingbourne };
670df49d1bbSPeter Collingbourne 
671cdc71612SEugene Zelenko } // end anonymous namespace
672df49d1bbSPeter Collingbourne 
67337317f12SPeter Collingbourne INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
67437317f12SPeter Collingbourne                       "Whole program devirtualization", false, false)
67537317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
67637317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
677f24136f1STeresa Johnson INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
67837317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
679df49d1bbSPeter Collingbourne                     "Whole program devirtualization", false, false)
680df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0;
681df49d1bbSPeter Collingbourne 
682f7691d8bSPeter Collingbourne ModulePass *
683f7691d8bSPeter Collingbourne llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
684f7691d8bSPeter Collingbourne                                    const ModuleSummaryIndex *ImportSummary) {
685f7691d8bSPeter Collingbourne   return new WholeProgramDevirt(ExportSummary, ImportSummary);
686df49d1bbSPeter Collingbourne }
687df49d1bbSPeter Collingbourne 
688164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
68937317f12SPeter Collingbourne                                               ModuleAnalysisManager &AM) {
69037317f12SPeter Collingbourne   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
69137317f12SPeter Collingbourne   auto AARGetter = [&](Function &F) -> AAResults & {
69237317f12SPeter Collingbourne     return FAM.getResult<AAManager>(F);
69337317f12SPeter Collingbourne   };
694e963c89dSSam Elliott   auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
695e963c89dSSam Elliott     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
696e963c89dSSam Elliott   };
697f24136f1STeresa Johnson   auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
698f24136f1STeresa Johnson     return FAM.getResult<DominatorTreeAnalysis>(F);
699f24136f1STeresa Johnson   };
700f24136f1STeresa Johnson   if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
701f24136f1STeresa Johnson                     ImportSummary)
70228023dbeSTeresa Johnson            .run())
703d737dd2eSDavide Italiano     return PreservedAnalyses::all();
704d737dd2eSDavide Italiano   return PreservedAnalyses::none();
705d737dd2eSDavide Italiano }
706d737dd2eSDavide Italiano 
707d2df54e6STeresa Johnson namespace llvm {
708d2df54e6STeresa Johnson void runWholeProgramDevirtOnIndex(
709d2df54e6STeresa Johnson     ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
710d2df54e6STeresa Johnson     std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
711d2df54e6STeresa Johnson   DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
712d2df54e6STeresa Johnson }
713d2df54e6STeresa Johnson 
714d2df54e6STeresa Johnson void updateIndexWPDForExports(
715d2df54e6STeresa Johnson     ModuleSummaryIndex &Summary,
716d2df54e6STeresa Johnson     StringMap<FunctionImporter::ExportSetTy> &ExportLists,
717d2df54e6STeresa Johnson     std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
718d2df54e6STeresa Johnson   for (auto &T : LocalWPDTargetsMap) {
719d2df54e6STeresa Johnson     auto &VI = T.first;
720d2df54e6STeresa Johnson     // This was enforced earlier during trySingleImplDevirt.
721d2df54e6STeresa Johnson     assert(VI.getSummaryList().size() == 1 &&
722d2df54e6STeresa Johnson            "Devirt of local target has more than one copy");
723d2df54e6STeresa Johnson     auto &S = VI.getSummaryList()[0];
724d2df54e6STeresa Johnson     const auto &ExportList = ExportLists.find(S->modulePath());
725d2df54e6STeresa Johnson     if (ExportList == ExportLists.end() ||
726d2df54e6STeresa Johnson         !ExportList->second.count(VI.getGUID()))
727d2df54e6STeresa Johnson       continue;
728d2df54e6STeresa Johnson 
729d2df54e6STeresa Johnson     // It's been exported by a cross module import.
730d2df54e6STeresa Johnson     for (auto &SlotSummary : T.second) {
731d2df54e6STeresa Johnson       auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
732d2df54e6STeresa Johnson       assert(TIdSum);
733d2df54e6STeresa Johnson       auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
734d2df54e6STeresa Johnson       assert(WPDRes != TIdSum->WPDRes.end());
735d2df54e6STeresa Johnson       WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
736d2df54e6STeresa Johnson           WPDRes->second.SingleImplName,
737d2df54e6STeresa Johnson           Summary.getModuleHash(S->modulePath()));
738d2df54e6STeresa Johnson     }
739d2df54e6STeresa Johnson   }
740d2df54e6STeresa Johnson }
741d2df54e6STeresa Johnson 
742d2df54e6STeresa Johnson } // end namespace llvm
743d2df54e6STeresa Johnson 
74437317f12SPeter Collingbourne bool DevirtModule::runForTesting(
745e963c89dSSam Elliott     Module &M, function_ref<AAResults &(Function &)> AARGetter,
746f24136f1STeresa Johnson     function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
747f24136f1STeresa Johnson     function_ref<DominatorTree &(Function &)> LookupDomTree) {
7484ffc3e78STeresa Johnson   ModuleSummaryIndex Summary(/*HaveGVs=*/false);
7492b33f653SPeter Collingbourne 
7502b33f653SPeter Collingbourne   // Handle the command-line summary arguments. This code is for testing
7512b33f653SPeter Collingbourne   // purposes only, so we handle errors directly.
7522b33f653SPeter Collingbourne   if (!ClReadSummary.empty()) {
7532b33f653SPeter Collingbourne     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
7542b33f653SPeter Collingbourne                           ": ");
7552b33f653SPeter Collingbourne     auto ReadSummaryFile =
7562b33f653SPeter Collingbourne         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
7572b33f653SPeter Collingbourne 
7582b33f653SPeter Collingbourne     yaml::Input In(ReadSummaryFile->getBuffer());
7592b33f653SPeter Collingbourne     In >> Summary;
7602b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(In.error()));
7612b33f653SPeter Collingbourne   }
7622b33f653SPeter Collingbourne 
763f7691d8bSPeter Collingbourne   bool Changed =
764f7691d8bSPeter Collingbourne       DevirtModule(
765f24136f1STeresa Johnson           M, AARGetter, OREGetter, LookupDomTree,
766f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
767f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
768f7691d8bSPeter Collingbourne           .run();
7692b33f653SPeter Collingbourne 
7702b33f653SPeter Collingbourne   if (!ClWriteSummary.empty()) {
7712b33f653SPeter Collingbourne     ExitOnError ExitOnErr(
7722b33f653SPeter Collingbourne         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
7732b33f653SPeter Collingbourne     std::error_code EC;
774d9b948b6SFangrui Song     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text);
7752b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(EC));
7762b33f653SPeter Collingbourne 
7772b33f653SPeter Collingbourne     yaml::Output Out(OS);
7782b33f653SPeter Collingbourne     Out << Summary;
7792b33f653SPeter Collingbourne   }
7802b33f653SPeter Collingbourne 
7812b33f653SPeter Collingbourne   return Changed;
7822b33f653SPeter Collingbourne }
7832b33f653SPeter Collingbourne 
7847efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap(
785df49d1bbSPeter Collingbourne     std::vector<VTableBits> &Bits,
7867efd7506SPeter Collingbourne     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
787df49d1bbSPeter Collingbourne   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
7887efd7506SPeter Collingbourne   Bits.reserve(M.getGlobalList().size());
7897efd7506SPeter Collingbourne   SmallVector<MDNode *, 2> Types;
7907efd7506SPeter Collingbourne   for (GlobalVariable &GV : M.globals()) {
7917efd7506SPeter Collingbourne     Types.clear();
7927efd7506SPeter Collingbourne     GV.getMetadata(LLVMContext::MD_type, Types);
7932b70d616SEugene Leviant     if (GV.isDeclaration() || Types.empty())
794df49d1bbSPeter Collingbourne       continue;
795df49d1bbSPeter Collingbourne 
7967efd7506SPeter Collingbourne     VTableBits *&BitsPtr = GVToBits[&GV];
7977efd7506SPeter Collingbourne     if (!BitsPtr) {
7987efd7506SPeter Collingbourne       Bits.emplace_back();
7997efd7506SPeter Collingbourne       Bits.back().GV = &GV;
8007efd7506SPeter Collingbourne       Bits.back().ObjectSize =
8017efd7506SPeter Collingbourne           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
8027efd7506SPeter Collingbourne       BitsPtr = &Bits.back();
8037efd7506SPeter Collingbourne     }
8047efd7506SPeter Collingbourne 
8057efd7506SPeter Collingbourne     for (MDNode *Type : Types) {
8067efd7506SPeter Collingbourne       auto TypeID = Type->getOperand(1).get();
807df49d1bbSPeter Collingbourne 
808df49d1bbSPeter Collingbourne       uint64_t Offset =
809df49d1bbSPeter Collingbourne           cast<ConstantInt>(
8107efd7506SPeter Collingbourne               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
811df49d1bbSPeter Collingbourne               ->getZExtValue();
812df49d1bbSPeter Collingbourne 
8137efd7506SPeter Collingbourne       TypeIdMap[TypeID].insert({BitsPtr, Offset});
814df49d1bbSPeter Collingbourne     }
815df49d1bbSPeter Collingbourne   }
816df49d1bbSPeter Collingbourne }
817df49d1bbSPeter Collingbourne 
8188786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
8198786754cSPeter Collingbourne   if (I->getType()->isPointerTy()) {
8208786754cSPeter Collingbourne     if (Offset == 0)
8218786754cSPeter Collingbourne       return I;
8228786754cSPeter Collingbourne     return nullptr;
8238786754cSPeter Collingbourne   }
8248786754cSPeter Collingbourne 
8257a1e5bbeSPeter Collingbourne   const DataLayout &DL = M.getDataLayout();
8267a1e5bbeSPeter Collingbourne 
8277a1e5bbeSPeter Collingbourne   if (auto *C = dyn_cast<ConstantStruct>(I)) {
8287a1e5bbeSPeter Collingbourne     const StructLayout *SL = DL.getStructLayout(C->getType());
8297a1e5bbeSPeter Collingbourne     if (Offset >= SL->getSizeInBytes())
8307a1e5bbeSPeter Collingbourne       return nullptr;
8317a1e5bbeSPeter Collingbourne 
8328786754cSPeter Collingbourne     unsigned Op = SL->getElementContainingOffset(Offset);
8338786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
8348786754cSPeter Collingbourne                               Offset - SL->getElementOffset(Op));
8358786754cSPeter Collingbourne   }
8368786754cSPeter Collingbourne   if (auto *C = dyn_cast<ConstantArray>(I)) {
8377a1e5bbeSPeter Collingbourne     ArrayType *VTableTy = C->getType();
8387a1e5bbeSPeter Collingbourne     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
8397a1e5bbeSPeter Collingbourne 
8408786754cSPeter Collingbourne     unsigned Op = Offset / ElemSize;
8417a1e5bbeSPeter Collingbourne     if (Op >= C->getNumOperands())
8427a1e5bbeSPeter Collingbourne       return nullptr;
8437a1e5bbeSPeter Collingbourne 
8448786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
8458786754cSPeter Collingbourne                               Offset % ElemSize);
8468786754cSPeter Collingbourne   }
8478786754cSPeter Collingbourne   return nullptr;
8487a1e5bbeSPeter Collingbourne }
8497a1e5bbeSPeter Collingbourne 
850df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets(
851df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> &TargetsForSlot,
8527efd7506SPeter Collingbourne     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
8537efd7506SPeter Collingbourne   for (const TypeMemberInfo &TM : TypeMemberInfos) {
8547efd7506SPeter Collingbourne     if (!TM.Bits->GV->isConstant())
855df49d1bbSPeter Collingbourne       return false;
856df49d1bbSPeter Collingbourne 
8578786754cSPeter Collingbourne     Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
8588786754cSPeter Collingbourne                                        TM.Offset + ByteOffset);
8598786754cSPeter Collingbourne     if (!Ptr)
860df49d1bbSPeter Collingbourne       return false;
861df49d1bbSPeter Collingbourne 
8628786754cSPeter Collingbourne     auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
863df49d1bbSPeter Collingbourne     if (!Fn)
864df49d1bbSPeter Collingbourne       return false;
865df49d1bbSPeter Collingbourne 
866df49d1bbSPeter Collingbourne     // We can disregard __cxa_pure_virtual as a possible call target, as
867df49d1bbSPeter Collingbourne     // calls to pure virtuals are UB.
868df49d1bbSPeter Collingbourne     if (Fn->getName() == "__cxa_pure_virtual")
869df49d1bbSPeter Collingbourne       continue;
870df49d1bbSPeter Collingbourne 
8717efd7506SPeter Collingbourne     TargetsForSlot.push_back({Fn, &TM});
872df49d1bbSPeter Collingbourne   }
873df49d1bbSPeter Collingbourne 
874df49d1bbSPeter Collingbourne   // Give up if we couldn't find any targets.
875df49d1bbSPeter Collingbourne   return !TargetsForSlot.empty();
876df49d1bbSPeter Collingbourne }
877df49d1bbSPeter Collingbourne 
878d2df54e6STeresa Johnson bool DevirtIndex::tryFindVirtualCallTargets(
879d2df54e6STeresa Johnson     std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo,
880d2df54e6STeresa Johnson     uint64_t ByteOffset) {
881d2df54e6STeresa Johnson   for (const TypeIdOffsetVtableInfo P : TIdInfo) {
882d2df54e6STeresa Johnson     // VTable initializer should have only one summary, or all copies must be
883d2df54e6STeresa Johnson     // linkonce/weak ODR.
884d2df54e6STeresa Johnson     assert(P.VTableVI.getSummaryList().size() == 1 ||
885d2df54e6STeresa Johnson            llvm::all_of(
886d2df54e6STeresa Johnson                P.VTableVI.getSummaryList(),
887d2df54e6STeresa Johnson                [&](const std::unique_ptr<GlobalValueSummary> &Summary) {
888d2df54e6STeresa Johnson                  return GlobalValue::isLinkOnceODRLinkage(Summary->linkage()) ||
889d2df54e6STeresa Johnson                         GlobalValue::isWeakODRLinkage(Summary->linkage());
890d2df54e6STeresa Johnson                }));
891d2df54e6STeresa Johnson     const auto *VS = cast<GlobalVarSummary>(P.VTableVI.getSummaryList()[0].get());
892d2df54e6STeresa Johnson     if (!P.VTableVI.getSummaryList()[0]->isLive())
893d2df54e6STeresa Johnson       continue;
894d2df54e6STeresa Johnson     for (auto VTP : VS->vTableFuncs()) {
895d2df54e6STeresa Johnson       if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
896d2df54e6STeresa Johnson         continue;
897d2df54e6STeresa Johnson 
898d2df54e6STeresa Johnson       TargetsForSlot.push_back(VTP.FuncVI);
899d2df54e6STeresa Johnson     }
900d2df54e6STeresa Johnson   }
901d2df54e6STeresa Johnson 
902d2df54e6STeresa Johnson   // Give up if we couldn't find any targets.
903d2df54e6STeresa Johnson   return !TargetsForSlot.empty();
904d2df54e6STeresa Johnson }
905d2df54e6STeresa Johnson 
90650cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
9072325bb34SPeter Collingbourne                                          Constant *TheFn, bool &IsExported) {
90850cbd7ccSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
90950cbd7ccSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
910f3403fd2SIvan Krasin       if (RemarksEnabled)
911b0a1d3bdSTeresa Johnson         VCallSite.emitRemark("single-impl",
912b0a1d3bdSTeresa Johnson                              TheFn->stripPointerCasts()->getName(), OREGetter);
913df49d1bbSPeter Collingbourne       VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
914df49d1bbSPeter Collingbourne           TheFn, VCallSite.CS.getCalledValue()->getType()));
9150312f614SPeter Collingbourne       // This use is no longer unsafe.
9160312f614SPeter Collingbourne       if (VCallSite.NumUnsafeUses)
9170312f614SPeter Collingbourne         --*VCallSite.NumUnsafeUses;
918df49d1bbSPeter Collingbourne     }
9192974856aSPeter Collingbourne     if (CSInfo.isExported())
9202325bb34SPeter Collingbourne       IsExported = true;
92159675ba0SPeter Collingbourne     CSInfo.markDevirt();
92250cbd7ccSPeter Collingbourne   };
92350cbd7ccSPeter Collingbourne   Apply(SlotInfo.CSInfo);
92450cbd7ccSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
92550cbd7ccSPeter Collingbourne     Apply(P.second);
92650cbd7ccSPeter Collingbourne }
92750cbd7ccSPeter Collingbourne 
92850cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt(
92950cbd7ccSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
9302325bb34SPeter Collingbourne     VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
93150cbd7ccSPeter Collingbourne   // See if the program contains a single implementation of this virtual
93250cbd7ccSPeter Collingbourne   // function.
93350cbd7ccSPeter Collingbourne   Function *TheFn = TargetsForSlot[0].Fn;
93450cbd7ccSPeter Collingbourne   for (auto &&Target : TargetsForSlot)
93550cbd7ccSPeter Collingbourne     if (TheFn != Target.Fn)
93650cbd7ccSPeter Collingbourne       return false;
93750cbd7ccSPeter Collingbourne 
93850cbd7ccSPeter Collingbourne   // If so, update each call site to call that implementation directly.
93950cbd7ccSPeter Collingbourne   if (RemarksEnabled)
94050cbd7ccSPeter Collingbourne     TargetsForSlot[0].WasDevirt = true;
9412325bb34SPeter Collingbourne 
9422325bb34SPeter Collingbourne   bool IsExported = false;
9432325bb34SPeter Collingbourne   applySingleImplDevirt(SlotInfo, TheFn, IsExported);
9442325bb34SPeter Collingbourne   if (!IsExported)
9452325bb34SPeter Collingbourne     return false;
9462325bb34SPeter Collingbourne 
9472325bb34SPeter Collingbourne   // If the only implementation has local linkage, we must promote to external
9482325bb34SPeter Collingbourne   // to make it visible to thin LTO objects. We can only get here during the
9492325bb34SPeter Collingbourne   // ThinLTO export phase.
9502325bb34SPeter Collingbourne   if (TheFn->hasLocalLinkage()) {
95188a58cf9SPeter Collingbourne     std::string NewName = (TheFn->getName() + "$merged").str();
95288a58cf9SPeter Collingbourne 
95388a58cf9SPeter Collingbourne     // Since we are renaming the function, any comdats with the same name must
95488a58cf9SPeter Collingbourne     // also be renamed. This is required when targeting COFF, as the comdat name
95588a58cf9SPeter Collingbourne     // must match one of the names of the symbols in the comdat.
95688a58cf9SPeter Collingbourne     if (Comdat *C = TheFn->getComdat()) {
95788a58cf9SPeter Collingbourne       if (C->getName() == TheFn->getName()) {
95888a58cf9SPeter Collingbourne         Comdat *NewC = M.getOrInsertComdat(NewName);
95988a58cf9SPeter Collingbourne         NewC->setSelectionKind(C->getSelectionKind());
96088a58cf9SPeter Collingbourne         for (GlobalObject &GO : M.global_objects())
96188a58cf9SPeter Collingbourne           if (GO.getComdat() == C)
96288a58cf9SPeter Collingbourne             GO.setComdat(NewC);
96388a58cf9SPeter Collingbourne       }
96488a58cf9SPeter Collingbourne     }
96588a58cf9SPeter Collingbourne 
9662325bb34SPeter Collingbourne     TheFn->setLinkage(GlobalValue::ExternalLinkage);
9672325bb34SPeter Collingbourne     TheFn->setVisibility(GlobalValue::HiddenVisibility);
96888a58cf9SPeter Collingbourne     TheFn->setName(NewName);
9692325bb34SPeter Collingbourne   }
9702325bb34SPeter Collingbourne 
9712325bb34SPeter Collingbourne   Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
9722325bb34SPeter Collingbourne   Res->SingleImplName = TheFn->getName();
9732325bb34SPeter Collingbourne 
974df49d1bbSPeter Collingbourne   return true;
975df49d1bbSPeter Collingbourne }
976df49d1bbSPeter Collingbourne 
977d2df54e6STeresa Johnson bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
978d2df54e6STeresa Johnson                                       VTableSlotSummary &SlotSummary,
979d2df54e6STeresa Johnson                                       VTableSlotInfo &SlotInfo,
980d2df54e6STeresa Johnson                                       WholeProgramDevirtResolution *Res,
981d2df54e6STeresa Johnson                                       std::set<ValueInfo> &DevirtTargets) {
982d2df54e6STeresa Johnson   // See if the program contains a single implementation of this virtual
983d2df54e6STeresa Johnson   // function.
984d2df54e6STeresa Johnson   auto TheFn = TargetsForSlot[0];
985d2df54e6STeresa Johnson   for (auto &&Target : TargetsForSlot)
986d2df54e6STeresa Johnson     if (TheFn != Target)
987d2df54e6STeresa Johnson       return false;
988d2df54e6STeresa Johnson 
989d2df54e6STeresa Johnson   // Don't devirtualize if we don't have target definition.
990d2df54e6STeresa Johnson   auto Size = TheFn.getSummaryList().size();
991d2df54e6STeresa Johnson   if (!Size)
992d2df54e6STeresa Johnson     return false;
993d2df54e6STeresa Johnson 
994d2df54e6STeresa Johnson   // If the summary list contains multiple summaries where at least one is
995d2df54e6STeresa Johnson   // a local, give up, as we won't know which (possibly promoted) name to use.
996d2df54e6STeresa Johnson   for (auto &S : TheFn.getSummaryList())
997d2df54e6STeresa Johnson     if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
998d2df54e6STeresa Johnson       return false;
999d2df54e6STeresa Johnson 
1000d2df54e6STeresa Johnson   // Collect functions devirtualized at least for one call site for stats.
1001d2df54e6STeresa Johnson   if (PrintSummaryDevirt)
1002d2df54e6STeresa Johnson     DevirtTargets.insert(TheFn);
1003d2df54e6STeresa Johnson 
1004d2df54e6STeresa Johnson   auto &S = TheFn.getSummaryList()[0];
1005d2df54e6STeresa Johnson   bool IsExported = false;
1006d2df54e6STeresa Johnson 
1007d2df54e6STeresa Johnson   // Insert calls into the summary index so that the devirtualized targets
1008d2df54e6STeresa Johnson   // are eligible for import.
1009d2df54e6STeresa Johnson   // FIXME: Annotate type tests with hotness. For now, mark these as hot
1010d2df54e6STeresa Johnson   // to better ensure we have the opportunity to inline them.
1011d2df54e6STeresa Johnson   CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
1012d2df54e6STeresa Johnson   auto AddCalls = [&](CallSiteInfo &CSInfo) {
1013d2df54e6STeresa Johnson     for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
1014d2df54e6STeresa Johnson       FS->addCall({TheFn, CI});
1015d2df54e6STeresa Johnson       IsExported |= S->modulePath() != FS->modulePath();
1016d2df54e6STeresa Johnson     }
1017d2df54e6STeresa Johnson     for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
1018d2df54e6STeresa Johnson       FS->addCall({TheFn, CI});
1019d2df54e6STeresa Johnson       IsExported |= S->modulePath() != FS->modulePath();
1020d2df54e6STeresa Johnson     }
1021d2df54e6STeresa Johnson   };
1022d2df54e6STeresa Johnson   AddCalls(SlotInfo.CSInfo);
1023d2df54e6STeresa Johnson   for (auto &P : SlotInfo.ConstCSInfo)
1024d2df54e6STeresa Johnson     AddCalls(P.second);
1025d2df54e6STeresa Johnson 
1026d2df54e6STeresa Johnson   if (IsExported)
1027d2df54e6STeresa Johnson     ExportedGUIDs.insert(TheFn.getGUID());
1028d2df54e6STeresa Johnson 
1029d2df54e6STeresa Johnson   // Record in summary for use in devirtualization during the ThinLTO import
1030d2df54e6STeresa Johnson   // step.
1031d2df54e6STeresa Johnson   Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
1032d2df54e6STeresa Johnson   if (GlobalValue::isLocalLinkage(S->linkage())) {
1033d2df54e6STeresa Johnson     if (IsExported)
1034d2df54e6STeresa Johnson       // If target is a local function and we are exporting it by
1035d2df54e6STeresa Johnson       // devirtualizing a call in another module, we need to record the
1036d2df54e6STeresa Johnson       // promoted name.
1037d2df54e6STeresa Johnson       Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
1038d2df54e6STeresa Johnson           TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
1039d2df54e6STeresa Johnson     else {
1040d2df54e6STeresa Johnson       LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
1041d2df54e6STeresa Johnson       Res->SingleImplName = TheFn.name();
1042d2df54e6STeresa Johnson     }
1043d2df54e6STeresa Johnson   } else
1044d2df54e6STeresa Johnson     Res->SingleImplName = TheFn.name();
1045d2df54e6STeresa Johnson 
1046d2df54e6STeresa Johnson   // Name will be empty if this thin link driven off of serialized combined
1047d2df54e6STeresa Johnson   // index (e.g. llvm-lto). However, WPD is not supported/invoked for the
1048d2df54e6STeresa Johnson   // legacy LTO API anyway.
1049d2df54e6STeresa Johnson   assert(!Res->SingleImplName.empty());
1050d2df54e6STeresa Johnson 
1051d2df54e6STeresa Johnson   return true;
1052d2df54e6STeresa Johnson }
1053d2df54e6STeresa Johnson 
10542974856aSPeter Collingbourne void DevirtModule::tryICallBranchFunnel(
10552974856aSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
10562974856aSPeter Collingbourne     WholeProgramDevirtResolution *Res, VTableSlot Slot) {
10572974856aSPeter Collingbourne   Triple T(M.getTargetTriple());
10582974856aSPeter Collingbourne   if (T.getArch() != Triple::x86_64)
10592974856aSPeter Collingbourne     return;
10602974856aSPeter Collingbourne 
106166f53d71SVitaly Buka   if (TargetsForSlot.size() > ClThreshold)
10622974856aSPeter Collingbourne     return;
10632974856aSPeter Collingbourne 
10642974856aSPeter Collingbourne   bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
10652974856aSPeter Collingbourne   if (!HasNonDevirt)
10662974856aSPeter Collingbourne     for (auto &P : SlotInfo.ConstCSInfo)
10672974856aSPeter Collingbourne       if (!P.second.AllCallSitesDevirted) {
10682974856aSPeter Collingbourne         HasNonDevirt = true;
10692974856aSPeter Collingbourne         break;
10702974856aSPeter Collingbourne       }
10712974856aSPeter Collingbourne 
10722974856aSPeter Collingbourne   if (!HasNonDevirt)
10732974856aSPeter Collingbourne     return;
10742974856aSPeter Collingbourne 
10752974856aSPeter Collingbourne   FunctionType *FT =
10762974856aSPeter Collingbourne       FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
10772974856aSPeter Collingbourne   Function *JT;
10782974856aSPeter Collingbourne   if (isa<MDString>(Slot.TypeID)) {
10792974856aSPeter Collingbourne     JT = Function::Create(FT, Function::ExternalLinkage,
1080f920da00SDylan McKay                           M.getDataLayout().getProgramAddressSpace(),
10812974856aSPeter Collingbourne                           getGlobalName(Slot, {}, "branch_funnel"), &M);
10822974856aSPeter Collingbourne     JT->setVisibility(GlobalValue::HiddenVisibility);
10832974856aSPeter Collingbourne   } else {
1084f920da00SDylan McKay     JT = Function::Create(FT, Function::InternalLinkage,
1085f920da00SDylan McKay                           M.getDataLayout().getProgramAddressSpace(),
1086f920da00SDylan McKay                           "branch_funnel", &M);
10872974856aSPeter Collingbourne   }
10882974856aSPeter Collingbourne   JT->addAttribute(1, Attribute::Nest);
10892974856aSPeter Collingbourne 
10902974856aSPeter Collingbourne   std::vector<Value *> JTArgs;
10912974856aSPeter Collingbourne   JTArgs.push_back(JT->arg_begin());
10922974856aSPeter Collingbourne   for (auto &T : TargetsForSlot) {
10932974856aSPeter Collingbourne     JTArgs.push_back(getMemberAddr(T.TM));
10942974856aSPeter Collingbourne     JTArgs.push_back(T.Fn);
10952974856aSPeter Collingbourne   }
10962974856aSPeter Collingbourne 
10972974856aSPeter Collingbourne   BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
10987976eb58SJames Y Knight   Function *Intr =
10992974856aSPeter Collingbourne       Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
11002974856aSPeter Collingbourne 
11012974856aSPeter Collingbourne   auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
11022974856aSPeter Collingbourne   CI->setTailCallKind(CallInst::TCK_MustTail);
11032974856aSPeter Collingbourne   ReturnInst::Create(M.getContext(), nullptr, BB);
11042974856aSPeter Collingbourne 
11052974856aSPeter Collingbourne   bool IsExported = false;
11062974856aSPeter Collingbourne   applyICallBranchFunnel(SlotInfo, JT, IsExported);
11072974856aSPeter Collingbourne   if (IsExported)
11082974856aSPeter Collingbourne     Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
11092974856aSPeter Collingbourne }
11102974856aSPeter Collingbourne 
11112974856aSPeter Collingbourne void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
11122974856aSPeter Collingbourne                                           Constant *JT, bool &IsExported) {
11132974856aSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
11142974856aSPeter Collingbourne     if (CSInfo.isExported())
11152974856aSPeter Collingbourne       IsExported = true;
11162974856aSPeter Collingbourne     if (CSInfo.AllCallSitesDevirted)
11172974856aSPeter Collingbourne       return;
11182974856aSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
11192974856aSPeter Collingbourne       CallSite CS = VCallSite.CS;
11202974856aSPeter Collingbourne 
11212974856aSPeter Collingbourne       // Jump tables are only profitable if the retpoline mitigation is enabled.
11222974856aSPeter Collingbourne       Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
11232974856aSPeter Collingbourne       if (FSAttr.hasAttribute(Attribute::None) ||
11242974856aSPeter Collingbourne           !FSAttr.getValueAsString().contains("+retpoline"))
11252974856aSPeter Collingbourne         continue;
11262974856aSPeter Collingbourne 
11272974856aSPeter Collingbourne       if (RemarksEnabled)
1128b0a1d3bdSTeresa Johnson         VCallSite.emitRemark("branch-funnel",
1129b0a1d3bdSTeresa Johnson                              JT->stripPointerCasts()->getName(), OREGetter);
11302974856aSPeter Collingbourne 
11312974856aSPeter Collingbourne       // Pass the address of the vtable in the nest register, which is r10 on
11322974856aSPeter Collingbourne       // x86_64.
11332974856aSPeter Collingbourne       std::vector<Type *> NewArgs;
11342974856aSPeter Collingbourne       NewArgs.push_back(Int8PtrTy);
11352974856aSPeter Collingbourne       for (Type *T : CS.getFunctionType()->params())
11362974856aSPeter Collingbourne         NewArgs.push_back(T);
11377976eb58SJames Y Knight       FunctionType *NewFT =
11382974856aSPeter Collingbourne           FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
11397976eb58SJames Y Knight                             CS.getFunctionType()->isVarArg());
11407976eb58SJames Y Knight       PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
11412974856aSPeter Collingbourne 
11422974856aSPeter Collingbourne       IRBuilder<> IRB(CS.getInstruction());
11432974856aSPeter Collingbourne       std::vector<Value *> Args;
11442974856aSPeter Collingbourne       Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
11452974856aSPeter Collingbourne       for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
11462974856aSPeter Collingbourne         Args.push_back(CS.getArgOperand(I));
11472974856aSPeter Collingbourne 
11482974856aSPeter Collingbourne       CallSite NewCS;
11492974856aSPeter Collingbourne       if (CS.isCall())
11507976eb58SJames Y Knight         NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
11512974856aSPeter Collingbourne       else
11522974856aSPeter Collingbourne         NewCS = IRB.CreateInvoke(
1153d9e85a08SJames Y Knight             NewFT, IRB.CreateBitCast(JT, NewFTPtr),
11542974856aSPeter Collingbourne             cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
11552974856aSPeter Collingbourne             cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
11562974856aSPeter Collingbourne       NewCS.setCallingConv(CS.getCallingConv());
11572974856aSPeter Collingbourne 
11582974856aSPeter Collingbourne       AttributeList Attrs = CS.getAttributes();
11592974856aSPeter Collingbourne       std::vector<AttributeSet> NewArgAttrs;
11602974856aSPeter Collingbourne       NewArgAttrs.push_back(AttributeSet::get(
11612974856aSPeter Collingbourne           M.getContext(), ArrayRef<Attribute>{Attribute::get(
11622974856aSPeter Collingbourne                               M.getContext(), Attribute::Nest)}));
11632974856aSPeter Collingbourne       for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
11642974856aSPeter Collingbourne         NewArgAttrs.push_back(Attrs.getParamAttributes(I));
11652974856aSPeter Collingbourne       NewCS.setAttributes(
11662974856aSPeter Collingbourne           AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
11672974856aSPeter Collingbourne                              Attrs.getRetAttributes(), NewArgAttrs));
11682974856aSPeter Collingbourne 
11692974856aSPeter Collingbourne       CS->replaceAllUsesWith(NewCS.getInstruction());
11702974856aSPeter Collingbourne       CS->eraseFromParent();
11712974856aSPeter Collingbourne 
11722974856aSPeter Collingbourne       // This use is no longer unsafe.
11732974856aSPeter Collingbourne       if (VCallSite.NumUnsafeUses)
11742974856aSPeter Collingbourne         --*VCallSite.NumUnsafeUses;
11752974856aSPeter Collingbourne     }
11762974856aSPeter Collingbourne     // Don't mark as devirtualized because there may be callers compiled without
11772974856aSPeter Collingbourne     // retpoline mitigation, which would mean that they are lowered to
11782974856aSPeter Collingbourne     // llvm.type.test and therefore require an llvm.type.test resolution for the
11792974856aSPeter Collingbourne     // type identifier.
11802974856aSPeter Collingbourne   };
11812974856aSPeter Collingbourne   Apply(SlotInfo.CSInfo);
11822974856aSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
11832974856aSPeter Collingbourne     Apply(P.second);
11842974856aSPeter Collingbourne }
11852974856aSPeter Collingbourne 
1186df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs(
1187df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
118850cbd7ccSPeter Collingbourne     ArrayRef<uint64_t> Args) {
1189df49d1bbSPeter Collingbourne   // Evaluate each function and store the result in each target's RetVal
1190df49d1bbSPeter Collingbourne   // field.
1191df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
1192df49d1bbSPeter Collingbourne     if (Target.Fn->arg_size() != Args.size() + 1)
1193df49d1bbSPeter Collingbourne       return false;
1194df49d1bbSPeter Collingbourne 
1195df49d1bbSPeter Collingbourne     Evaluator Eval(M.getDataLayout(), nullptr);
1196df49d1bbSPeter Collingbourne     SmallVector<Constant *, 2> EvalArgs;
1197df49d1bbSPeter Collingbourne     EvalArgs.push_back(
1198df49d1bbSPeter Collingbourne         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
119950cbd7ccSPeter Collingbourne     for (unsigned I = 0; I != Args.size(); ++I) {
120050cbd7ccSPeter Collingbourne       auto *ArgTy = dyn_cast<IntegerType>(
120150cbd7ccSPeter Collingbourne           Target.Fn->getFunctionType()->getParamType(I + 1));
120250cbd7ccSPeter Collingbourne       if (!ArgTy)
120350cbd7ccSPeter Collingbourne         return false;
120450cbd7ccSPeter Collingbourne       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
120550cbd7ccSPeter Collingbourne     }
120650cbd7ccSPeter Collingbourne 
1207df49d1bbSPeter Collingbourne     Constant *RetVal;
1208df49d1bbSPeter Collingbourne     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
1209df49d1bbSPeter Collingbourne         !isa<ConstantInt>(RetVal))
1210df49d1bbSPeter Collingbourne       return false;
1211df49d1bbSPeter Collingbourne     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
1212df49d1bbSPeter Collingbourne   }
1213df49d1bbSPeter Collingbourne   return true;
1214df49d1bbSPeter Collingbourne }
1215df49d1bbSPeter Collingbourne 
121650cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
121750cbd7ccSPeter Collingbourne                                          uint64_t TheRetVal) {
121850cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites)
121950cbd7ccSPeter Collingbourne     Call.replaceAndErase(
1220e963c89dSSam Elliott         "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
122150cbd7ccSPeter Collingbourne         ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
122259675ba0SPeter Collingbourne   CSInfo.markDevirt();
122350cbd7ccSPeter Collingbourne }
122450cbd7ccSPeter Collingbourne 
1225df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt(
122677a8d563SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
122777a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *Res) {
1228df49d1bbSPeter Collingbourne   // Uniform return value optimization. If all functions return the same
1229df49d1bbSPeter Collingbourne   // constant, replace all calls with that constant.
1230df49d1bbSPeter Collingbourne   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
1231df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : TargetsForSlot)
1232df49d1bbSPeter Collingbourne     if (Target.RetVal != TheRetVal)
1233df49d1bbSPeter Collingbourne       return false;
1234df49d1bbSPeter Collingbourne 
123577a8d563SPeter Collingbourne   if (CSInfo.isExported()) {
123677a8d563SPeter Collingbourne     Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
123777a8d563SPeter Collingbourne     Res->Info = TheRetVal;
123877a8d563SPeter Collingbourne   }
123977a8d563SPeter Collingbourne 
124050cbd7ccSPeter Collingbourne   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
1241f3403fd2SIvan Krasin   if (RemarksEnabled)
1242f3403fd2SIvan Krasin     for (auto &&Target : TargetsForSlot)
1243f3403fd2SIvan Krasin       Target.WasDevirt = true;
1244df49d1bbSPeter Collingbourne   return true;
1245df49d1bbSPeter Collingbourne }
1246df49d1bbSPeter Collingbourne 
124759675ba0SPeter Collingbourne std::string DevirtModule::getGlobalName(VTableSlot Slot,
124859675ba0SPeter Collingbourne                                         ArrayRef<uint64_t> Args,
124959675ba0SPeter Collingbourne                                         StringRef Name) {
125059675ba0SPeter Collingbourne   std::string FullName = "__typeid_";
125159675ba0SPeter Collingbourne   raw_string_ostream OS(FullName);
125259675ba0SPeter Collingbourne   OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
125359675ba0SPeter Collingbourne   for (uint64_t Arg : Args)
125459675ba0SPeter Collingbourne     OS << '_' << Arg;
125559675ba0SPeter Collingbourne   OS << '_' << Name;
125659675ba0SPeter Collingbourne   return OS.str();
125759675ba0SPeter Collingbourne }
125859675ba0SPeter Collingbourne 
1259b15a35e6SPeter Collingbourne bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
1260b15a35e6SPeter Collingbourne   Triple T(M.getTargetTriple());
1261b15a35e6SPeter Collingbourne   return (T.getArch() == Triple::x86 || T.getArch() == Triple::x86_64) &&
1262b15a35e6SPeter Collingbourne          T.getObjectFormat() == Triple::ELF;
1263b15a35e6SPeter Collingbourne }
1264b15a35e6SPeter Collingbourne 
126559675ba0SPeter Collingbourne void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
126659675ba0SPeter Collingbourne                                 StringRef Name, Constant *C) {
126759675ba0SPeter Collingbourne   GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
126859675ba0SPeter Collingbourne                                         getGlobalName(Slot, Args, Name), C, &M);
126959675ba0SPeter Collingbourne   GA->setVisibility(GlobalValue::HiddenVisibility);
127059675ba0SPeter Collingbourne }
127159675ba0SPeter Collingbourne 
1272b15a35e6SPeter Collingbourne void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1273b15a35e6SPeter Collingbourne                                   StringRef Name, uint32_t Const,
1274b15a35e6SPeter Collingbourne                                   uint32_t &Storage) {
1275b15a35e6SPeter Collingbourne   if (shouldExportConstantsAsAbsoluteSymbols()) {
1276b15a35e6SPeter Collingbourne     exportGlobal(
1277b15a35e6SPeter Collingbourne         Slot, Args, Name,
1278b15a35e6SPeter Collingbourne         ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
1279b15a35e6SPeter Collingbourne     return;
1280b15a35e6SPeter Collingbourne   }
1281b15a35e6SPeter Collingbourne 
1282b15a35e6SPeter Collingbourne   Storage = Const;
1283b15a35e6SPeter Collingbourne }
1284b15a35e6SPeter Collingbourne 
128559675ba0SPeter Collingbourne Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1286b15a35e6SPeter Collingbourne                                      StringRef Name) {
128759675ba0SPeter Collingbourne   Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
128859675ba0SPeter Collingbourne   auto *GV = dyn_cast<GlobalVariable>(C);
1289b15a35e6SPeter Collingbourne   if (GV)
1290b15a35e6SPeter Collingbourne     GV->setVisibility(GlobalValue::HiddenVisibility);
1291b15a35e6SPeter Collingbourne   return C;
1292b15a35e6SPeter Collingbourne }
1293b15a35e6SPeter Collingbourne 
1294b15a35e6SPeter Collingbourne Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1295b15a35e6SPeter Collingbourne                                        StringRef Name, IntegerType *IntTy,
1296b15a35e6SPeter Collingbourne                                        uint32_t Storage) {
1297b15a35e6SPeter Collingbourne   if (!shouldExportConstantsAsAbsoluteSymbols())
1298b15a35e6SPeter Collingbourne     return ConstantInt::get(IntTy, Storage);
1299b15a35e6SPeter Collingbourne 
1300b15a35e6SPeter Collingbourne   Constant *C = importGlobal(Slot, Args, Name);
1301b15a35e6SPeter Collingbourne   auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
1302b15a35e6SPeter Collingbourne   C = ConstantExpr::getPtrToInt(C, IntTy);
1303b15a35e6SPeter Collingbourne 
130414dcf02fSPeter Collingbourne   // We only need to set metadata if the global is newly created, in which
130514dcf02fSPeter Collingbourne   // case it would not have hidden visibility.
13060deb9a9aSBenjamin Kramer   if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
130759675ba0SPeter Collingbourne     return C;
130814dcf02fSPeter Collingbourne 
130914dcf02fSPeter Collingbourne   auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
131014dcf02fSPeter Collingbourne     auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
131114dcf02fSPeter Collingbourne     auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
131214dcf02fSPeter Collingbourne     GV->setMetadata(LLVMContext::MD_absolute_symbol,
131314dcf02fSPeter Collingbourne                     MDNode::get(M.getContext(), {MinC, MaxC}));
131414dcf02fSPeter Collingbourne   };
1315b15a35e6SPeter Collingbourne   unsigned AbsWidth = IntTy->getBitWidth();
131614dcf02fSPeter Collingbourne   if (AbsWidth == IntPtrTy->getBitWidth())
131714dcf02fSPeter Collingbourne     SetAbsRange(~0ull, ~0ull); // Full set.
1318b15a35e6SPeter Collingbourne   else
131914dcf02fSPeter Collingbourne     SetAbsRange(0, 1ull << AbsWidth);
1320b15a35e6SPeter Collingbourne   return C;
132159675ba0SPeter Collingbourne }
132259675ba0SPeter Collingbourne 
132350cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
132450cbd7ccSPeter Collingbourne                                         bool IsOne,
132550cbd7ccSPeter Collingbourne                                         Constant *UniqueMemberAddr) {
132650cbd7ccSPeter Collingbourne   for (auto &&Call : CSInfo.CallSites) {
132750cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
1328001052a0SPeter Collingbourne     Value *Cmp =
1329001052a0SPeter Collingbourne         B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
1330001052a0SPeter Collingbourne                      B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr);
133150cbd7ccSPeter Collingbourne     Cmp = B.CreateZExt(Cmp, Call.CS->getType());
1332e963c89dSSam Elliott     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
1333e963c89dSSam Elliott                          Cmp);
133450cbd7ccSPeter Collingbourne   }
133559675ba0SPeter Collingbourne   CSInfo.markDevirt();
133650cbd7ccSPeter Collingbourne }
133750cbd7ccSPeter Collingbourne 
13382974856aSPeter Collingbourne Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
13392974856aSPeter Collingbourne   Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
13402974856aSPeter Collingbourne   return ConstantExpr::getGetElementPtr(Int8Ty, C,
13412974856aSPeter Collingbourne                                         ConstantInt::get(Int64Ty, M->Offset));
13422974856aSPeter Collingbourne }
13432974856aSPeter Collingbourne 
1344df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt(
1345f3403fd2SIvan Krasin     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
134659675ba0SPeter Collingbourne     CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
134759675ba0SPeter Collingbourne     VTableSlot Slot, ArrayRef<uint64_t> Args) {
1348df49d1bbSPeter Collingbourne   // IsOne controls whether we look for a 0 or a 1.
1349df49d1bbSPeter Collingbourne   auto tryUniqueRetValOptFor = [&](bool IsOne) {
1350cdc71612SEugene Zelenko     const TypeMemberInfo *UniqueMember = nullptr;
1351df49d1bbSPeter Collingbourne     for (const VirtualCallTarget &Target : TargetsForSlot) {
13523866cc5fSPeter Collingbourne       if (Target.RetVal == (IsOne ? 1 : 0)) {
13537efd7506SPeter Collingbourne         if (UniqueMember)
1354df49d1bbSPeter Collingbourne           return false;
13557efd7506SPeter Collingbourne         UniqueMember = Target.TM;
1356df49d1bbSPeter Collingbourne       }
1357df49d1bbSPeter Collingbourne     }
1358df49d1bbSPeter Collingbourne 
13597efd7506SPeter Collingbourne     // We should have found a unique member or bailed out by now. We already
1360df49d1bbSPeter Collingbourne     // checked for a uniform return value in tryUniformRetValOpt.
13617efd7506SPeter Collingbourne     assert(UniqueMember);
1362df49d1bbSPeter Collingbourne 
13632974856aSPeter Collingbourne     Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
136459675ba0SPeter Collingbourne     if (CSInfo.isExported()) {
136559675ba0SPeter Collingbourne       Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
136659675ba0SPeter Collingbourne       Res->Info = IsOne;
136759675ba0SPeter Collingbourne 
136859675ba0SPeter Collingbourne       exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
136959675ba0SPeter Collingbourne     }
137059675ba0SPeter Collingbourne 
137159675ba0SPeter Collingbourne     // Replace each call with the comparison.
137250cbd7ccSPeter Collingbourne     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
137350cbd7ccSPeter Collingbourne                          UniqueMemberAddr);
137450cbd7ccSPeter Collingbourne 
1375f3403fd2SIvan Krasin     // Update devirtualization statistics for targets.
1376f3403fd2SIvan Krasin     if (RemarksEnabled)
1377f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
1378f3403fd2SIvan Krasin         Target.WasDevirt = true;
1379f3403fd2SIvan Krasin 
1380df49d1bbSPeter Collingbourne     return true;
1381df49d1bbSPeter Collingbourne   };
1382df49d1bbSPeter Collingbourne 
1383df49d1bbSPeter Collingbourne   if (BitWidth == 1) {
1384df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(true))
1385df49d1bbSPeter Collingbourne       return true;
1386df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(false))
1387df49d1bbSPeter Collingbourne       return true;
1388df49d1bbSPeter Collingbourne   }
1389df49d1bbSPeter Collingbourne   return false;
1390df49d1bbSPeter Collingbourne }
1391df49d1bbSPeter Collingbourne 
139250cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
139350cbd7ccSPeter Collingbourne                                          Constant *Byte, Constant *Bit) {
139450cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites) {
139550cbd7ccSPeter Collingbourne     auto *RetType = cast<IntegerType>(Call.CS.getType());
139650cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
1397001052a0SPeter Collingbourne     Value *Addr =
1398001052a0SPeter Collingbourne         B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
139950cbd7ccSPeter Collingbourne     if (RetType->getBitWidth() == 1) {
140014359ef1SJames Y Knight       Value *Bits = B.CreateLoad(Int8Ty, Addr);
140150cbd7ccSPeter Collingbourne       Value *BitsAndBit = B.CreateAnd(Bits, Bit);
140250cbd7ccSPeter Collingbourne       auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
140350cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
1404e963c89dSSam Elliott                            OREGetter, IsBitSet);
140550cbd7ccSPeter Collingbourne     } else {
140650cbd7ccSPeter Collingbourne       Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
140750cbd7ccSPeter Collingbourne       Value *Val = B.CreateLoad(RetType, ValAddr);
1408e963c89dSSam Elliott       Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
1409e963c89dSSam Elliott                            OREGetter, Val);
141050cbd7ccSPeter Collingbourne     }
141150cbd7ccSPeter Collingbourne   }
141214dcf02fSPeter Collingbourne   CSInfo.markDevirt();
141350cbd7ccSPeter Collingbourne }
141450cbd7ccSPeter Collingbourne 
1415df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp(
141659675ba0SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
141759675ba0SPeter Collingbourne     WholeProgramDevirtResolution *Res, VTableSlot Slot) {
1418df49d1bbSPeter Collingbourne   // This only works if the function returns an integer.
1419df49d1bbSPeter Collingbourne   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
1420df49d1bbSPeter Collingbourne   if (!RetType)
1421df49d1bbSPeter Collingbourne     return false;
1422df49d1bbSPeter Collingbourne   unsigned BitWidth = RetType->getBitWidth();
1423df49d1bbSPeter Collingbourne   if (BitWidth > 64)
1424df49d1bbSPeter Collingbourne     return false;
1425df49d1bbSPeter Collingbourne 
142617febdbbSPeter Collingbourne   // Make sure that each function is defined, does not access memory, takes at
142717febdbbSPeter Collingbourne   // least one argument, does not use its first argument (which we assume is
142817febdbbSPeter Collingbourne   // 'this'), and has the same return type.
142937317f12SPeter Collingbourne   //
143037317f12SPeter Collingbourne   // Note that we test whether this copy of the function is readnone, rather
143137317f12SPeter Collingbourne   // than testing function attributes, which must hold for any copy of the
143237317f12SPeter Collingbourne   // function, even a less optimized version substituted at link time. This is
143337317f12SPeter Collingbourne   // sound because the virtual constant propagation optimizations effectively
143437317f12SPeter Collingbourne   // inline all implementations of the virtual function into each call site,
143537317f12SPeter Collingbourne   // rather than using function attributes to perform local optimization.
1436df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
143737317f12SPeter Collingbourne     if (Target.Fn->isDeclaration() ||
143837317f12SPeter Collingbourne         computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
143937317f12SPeter Collingbourne             MAK_ReadNone ||
144017febdbbSPeter Collingbourne         Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
1441df49d1bbSPeter Collingbourne         Target.Fn->getReturnType() != RetType)
1442df49d1bbSPeter Collingbourne       return false;
1443df49d1bbSPeter Collingbourne   }
1444df49d1bbSPeter Collingbourne 
144550cbd7ccSPeter Collingbourne   for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
1446df49d1bbSPeter Collingbourne     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
1447df49d1bbSPeter Collingbourne       continue;
1448df49d1bbSPeter Collingbourne 
144977a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
145077a8d563SPeter Collingbourne     if (Res)
145177a8d563SPeter Collingbourne       ResByArg = &Res->ResByArg[CSByConstantArg.first];
145277a8d563SPeter Collingbourne 
145377a8d563SPeter Collingbourne     if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
1454df49d1bbSPeter Collingbourne       continue;
1455df49d1bbSPeter Collingbourne 
145659675ba0SPeter Collingbourne     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
145759675ba0SPeter Collingbourne                            ResByArg, Slot, CSByConstantArg.first))
1458df49d1bbSPeter Collingbourne       continue;
1459df49d1bbSPeter Collingbourne 
14607efd7506SPeter Collingbourne     // Find an allocation offset in bits in all vtables associated with the
14617efd7506SPeter Collingbourne     // type.
1462df49d1bbSPeter Collingbourne     uint64_t AllocBefore =
1463df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
1464df49d1bbSPeter Collingbourne     uint64_t AllocAfter =
1465df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
1466df49d1bbSPeter Collingbourne 
1467df49d1bbSPeter Collingbourne     // Calculate the total amount of padding needed to store a value at both
1468df49d1bbSPeter Collingbourne     // ends of the object.
1469df49d1bbSPeter Collingbourne     uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
1470df49d1bbSPeter Collingbourne     for (auto &&Target : TargetsForSlot) {
1471df49d1bbSPeter Collingbourne       TotalPaddingBefore += std::max<int64_t>(
1472df49d1bbSPeter Collingbourne           (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
1473df49d1bbSPeter Collingbourne       TotalPaddingAfter += std::max<int64_t>(
1474df49d1bbSPeter Collingbourne           (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
1475df49d1bbSPeter Collingbourne     }
1476df49d1bbSPeter Collingbourne 
1477df49d1bbSPeter Collingbourne     // If the amount of padding is too large, give up.
1478df49d1bbSPeter Collingbourne     // FIXME: do something smarter here.
1479df49d1bbSPeter Collingbourne     if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
1480df49d1bbSPeter Collingbourne       continue;
1481df49d1bbSPeter Collingbourne 
1482df49d1bbSPeter Collingbourne     // Calculate the offset to the value as a (possibly negative) byte offset
1483df49d1bbSPeter Collingbourne     // and (if applicable) a bit offset, and store the values in the targets.
1484df49d1bbSPeter Collingbourne     int64_t OffsetByte;
1485df49d1bbSPeter Collingbourne     uint64_t OffsetBit;
1486df49d1bbSPeter Collingbourne     if (TotalPaddingBefore <= TotalPaddingAfter)
1487df49d1bbSPeter Collingbourne       setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
1488df49d1bbSPeter Collingbourne                             OffsetBit);
1489df49d1bbSPeter Collingbourne     else
1490df49d1bbSPeter Collingbourne       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
1491df49d1bbSPeter Collingbourne                            OffsetBit);
1492df49d1bbSPeter Collingbourne 
1493f3403fd2SIvan Krasin     if (RemarksEnabled)
1494f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
1495f3403fd2SIvan Krasin         Target.WasDevirt = true;
1496f3403fd2SIvan Krasin 
149714dcf02fSPeter Collingbourne 
149814dcf02fSPeter Collingbourne     if (CSByConstantArg.second.isExported()) {
149914dcf02fSPeter Collingbourne       ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
1500b15a35e6SPeter Collingbourne       exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
1501b15a35e6SPeter Collingbourne                      ResByArg->Byte);
1502b15a35e6SPeter Collingbourne       exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
1503b15a35e6SPeter Collingbourne                      ResByArg->Bit);
150414dcf02fSPeter Collingbourne     }
150514dcf02fSPeter Collingbourne 
150614dcf02fSPeter Collingbourne     // Rewrite each call to a load from OffsetByte/OffsetBit.
1507b15a35e6SPeter Collingbourne     Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
1508b15a35e6SPeter Collingbourne     Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
150950cbd7ccSPeter Collingbourne     applyVirtualConstProp(CSByConstantArg.second,
151050cbd7ccSPeter Collingbourne                           TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
1511df49d1bbSPeter Collingbourne   }
1512df49d1bbSPeter Collingbourne   return true;
1513df49d1bbSPeter Collingbourne }
1514df49d1bbSPeter Collingbourne 
1515df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) {
1516df49d1bbSPeter Collingbourne   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
1517df49d1bbSPeter Collingbourne     return;
1518df49d1bbSPeter Collingbourne 
1519ef5cfc2dSPeter Collingbourne   // Align the before byte array to the global's minimum alignment so that we
1520ef5cfc2dSPeter Collingbourne   // don't break any alignment requirements on the global.
1521ef5cfc2dSPeter Collingbourne   unsigned Align = B.GV->getAlignment();
1522ef5cfc2dSPeter Collingbourne   if (Align == 0)
1523ef5cfc2dSPeter Collingbourne     Align = M.getDataLayout().getABITypeAlignment(B.GV->getValueType());
1524ef5cfc2dSPeter Collingbourne   B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Align));
1525df49d1bbSPeter Collingbourne 
1526df49d1bbSPeter Collingbourne   // Before was stored in reverse order; flip it now.
1527df49d1bbSPeter Collingbourne   for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
1528df49d1bbSPeter Collingbourne     std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
1529df49d1bbSPeter Collingbourne 
1530df49d1bbSPeter Collingbourne   // Build an anonymous global containing the before bytes, followed by the
1531df49d1bbSPeter Collingbourne   // original initializer, followed by the after bytes.
1532df49d1bbSPeter Collingbourne   auto NewInit = ConstantStruct::getAnon(
1533df49d1bbSPeter Collingbourne       {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
1534df49d1bbSPeter Collingbourne        B.GV->getInitializer(),
1535df49d1bbSPeter Collingbourne        ConstantDataArray::get(M.getContext(), B.After.Bytes)});
1536df49d1bbSPeter Collingbourne   auto NewGV =
1537df49d1bbSPeter Collingbourne       new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
1538df49d1bbSPeter Collingbourne                          GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
1539df49d1bbSPeter Collingbourne   NewGV->setSection(B.GV->getSection());
1540df49d1bbSPeter Collingbourne   NewGV->setComdat(B.GV->getComdat());
1541ef5cfc2dSPeter Collingbourne   NewGV->setAlignment(B.GV->getAlignment());
1542df49d1bbSPeter Collingbourne 
15430312f614SPeter Collingbourne   // Copy the original vtable's metadata to the anonymous global, adjusting
15440312f614SPeter Collingbourne   // offsets as required.
15450312f614SPeter Collingbourne   NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
15460312f614SPeter Collingbourne 
1547df49d1bbSPeter Collingbourne   // Build an alias named after the original global, pointing at the second
1548df49d1bbSPeter Collingbourne   // element (the original initializer).
1549df49d1bbSPeter Collingbourne   auto Alias = GlobalAlias::create(
1550df49d1bbSPeter Collingbourne       B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
1551df49d1bbSPeter Collingbourne       ConstantExpr::getGetElementPtr(
1552df49d1bbSPeter Collingbourne           NewInit->getType(), NewGV,
1553df49d1bbSPeter Collingbourne           ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
1554df49d1bbSPeter Collingbourne                                ConstantInt::get(Int32Ty, 1)}),
1555df49d1bbSPeter Collingbourne       &M);
1556df49d1bbSPeter Collingbourne   Alias->setVisibility(B.GV->getVisibility());
1557df49d1bbSPeter Collingbourne   Alias->takeName(B.GV);
1558df49d1bbSPeter Collingbourne 
1559df49d1bbSPeter Collingbourne   B.GV->replaceAllUsesWith(Alias);
1560df49d1bbSPeter Collingbourne   B.GV->eraseFromParent();
1561df49d1bbSPeter Collingbourne }
1562df49d1bbSPeter Collingbourne 
1563f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() {
1564f3403fd2SIvan Krasin   const auto &FL = M.getFunctionList();
15655e1c0e76STeresa Johnson   for (const Function &Fn : FL) {
1566de53bfb9SAdam Nemet     const auto &BBL = Fn.getBasicBlockList();
1567de53bfb9SAdam Nemet     if (BBL.empty())
15685e1c0e76STeresa Johnson       continue;
1569de53bfb9SAdam Nemet     auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
1570f3403fd2SIvan Krasin     return DI.isEnabled();
1571f3403fd2SIvan Krasin   }
15725e1c0e76STeresa Johnson   return false;
15735e1c0e76STeresa Johnson }
1574f3403fd2SIvan Krasin 
15750312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
15760312f614SPeter Collingbourne                                      Function *AssumeFunc) {
1577df49d1bbSPeter Collingbourne   // Find all virtual calls via a virtual table pointer %p under an assumption
15787efd7506SPeter Collingbourne   // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
15797efd7506SPeter Collingbourne   // points to a member of the type identifier %md. Group calls by (type ID,
15807efd7506SPeter Collingbourne   // offset) pair (effectively the identity of the virtual function) and store
15817efd7506SPeter Collingbourne   // to CallSlots.
1582f24136f1STeresa Johnson   DenseSet<CallSite> SeenCallSites;
15837efd7506SPeter Collingbourne   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
1584df49d1bbSPeter Collingbourne        I != E;) {
1585df49d1bbSPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
1586df49d1bbSPeter Collingbourne     ++I;
1587df49d1bbSPeter Collingbourne     if (!CI)
1588df49d1bbSPeter Collingbourne       continue;
1589df49d1bbSPeter Collingbourne 
1590ccdc225cSPeter Collingbourne     // Search for virtual calls based on %p and add them to DevirtCalls.
1591ccdc225cSPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
1592df49d1bbSPeter Collingbourne     SmallVector<CallInst *, 1> Assumes;
1593f24136f1STeresa Johnson     auto &DT = LookupDomTree(*CI->getFunction());
1594f24136f1STeresa Johnson     findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
1595df49d1bbSPeter Collingbourne 
1596f24136f1STeresa Johnson     // If we found any, add them to CallSlots.
1597df49d1bbSPeter Collingbourne     if (!Assumes.empty()) {
15987efd7506SPeter Collingbourne       Metadata *TypeId =
1599df49d1bbSPeter Collingbourne           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
1600df49d1bbSPeter Collingbourne       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
1601ccdc225cSPeter Collingbourne       for (DevirtCallSite Call : DevirtCalls) {
1602f24136f1STeresa Johnson         // Only add this CallSite if we haven't seen it before. The vtable
1603f24136f1STeresa Johnson         // pointer may have been CSE'd with pointers from other call sites,
1604f24136f1STeresa Johnson         // and we don't want to process call sites multiple times. We can't
1605f24136f1STeresa Johnson         // just skip the vtable Ptr if it has been seen before, however, since
1606f24136f1STeresa Johnson         // it may be shared by type tests that dominate different calls.
1607f24136f1STeresa Johnson         if (SeenCallSites.insert(Call.CS).second)
1608001052a0SPeter Collingbourne           CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
1609ccdc225cSPeter Collingbourne       }
1610ccdc225cSPeter Collingbourne     }
1611df49d1bbSPeter Collingbourne 
16127efd7506SPeter Collingbourne     // We no longer need the assumes or the type test.
1613df49d1bbSPeter Collingbourne     for (auto Assume : Assumes)
1614df49d1bbSPeter Collingbourne       Assume->eraseFromParent();
1615df49d1bbSPeter Collingbourne     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
1616df49d1bbSPeter Collingbourne     // may use the vtable argument later.
1617df49d1bbSPeter Collingbourne     if (CI->use_empty())
1618df49d1bbSPeter Collingbourne       CI->eraseFromParent();
1619df49d1bbSPeter Collingbourne   }
16200312f614SPeter Collingbourne }
16210312f614SPeter Collingbourne 
16220312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
16230312f614SPeter Collingbourne   Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
16240312f614SPeter Collingbourne 
16250312f614SPeter Collingbourne   for (auto I = TypeCheckedLoadFunc->use_begin(),
16260312f614SPeter Collingbourne             E = TypeCheckedLoadFunc->use_end();
16270312f614SPeter Collingbourne        I != E;) {
16280312f614SPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
16290312f614SPeter Collingbourne     ++I;
16300312f614SPeter Collingbourne     if (!CI)
16310312f614SPeter Collingbourne       continue;
16320312f614SPeter Collingbourne 
16330312f614SPeter Collingbourne     Value *Ptr = CI->getArgOperand(0);
16340312f614SPeter Collingbourne     Value *Offset = CI->getArgOperand(1);
16350312f614SPeter Collingbourne     Value *TypeIdValue = CI->getArgOperand(2);
16360312f614SPeter Collingbourne     Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
16370312f614SPeter Collingbourne 
16380312f614SPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
16390312f614SPeter Collingbourne     SmallVector<Instruction *, 1> LoadedPtrs;
16400312f614SPeter Collingbourne     SmallVector<Instruction *, 1> Preds;
16410312f614SPeter Collingbourne     bool HasNonCallUses = false;
1642f24136f1STeresa Johnson     auto &DT = LookupDomTree(*CI->getFunction());
16430312f614SPeter Collingbourne     findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
1644f24136f1STeresa Johnson                                                HasNonCallUses, CI, DT);
16450312f614SPeter Collingbourne 
16460312f614SPeter Collingbourne     // Start by generating "pessimistic" code that explicitly loads the function
16470312f614SPeter Collingbourne     // pointer from the vtable and performs the type check. If possible, we will
16480312f614SPeter Collingbourne     // eliminate the load and the type check later.
16490312f614SPeter Collingbourne 
16500312f614SPeter Collingbourne     // If possible, only generate the load at the point where it is used.
16510312f614SPeter Collingbourne     // This helps avoid unnecessary spills.
16520312f614SPeter Collingbourne     IRBuilder<> LoadB(
16530312f614SPeter Collingbourne         (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
16540312f614SPeter Collingbourne     Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
16550312f614SPeter Collingbourne     Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
16560312f614SPeter Collingbourne     Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
16570312f614SPeter Collingbourne 
16580312f614SPeter Collingbourne     for (Instruction *LoadedPtr : LoadedPtrs) {
16590312f614SPeter Collingbourne       LoadedPtr->replaceAllUsesWith(LoadedValue);
16600312f614SPeter Collingbourne       LoadedPtr->eraseFromParent();
16610312f614SPeter Collingbourne     }
16620312f614SPeter Collingbourne 
16630312f614SPeter Collingbourne     // Likewise for the type test.
16640312f614SPeter Collingbourne     IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
16650312f614SPeter Collingbourne     CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
16660312f614SPeter Collingbourne 
16670312f614SPeter Collingbourne     for (Instruction *Pred : Preds) {
16680312f614SPeter Collingbourne       Pred->replaceAllUsesWith(TypeTestCall);
16690312f614SPeter Collingbourne       Pred->eraseFromParent();
16700312f614SPeter Collingbourne     }
16710312f614SPeter Collingbourne 
16720312f614SPeter Collingbourne     // We have already erased any extractvalue instructions that refer to the
16730312f614SPeter Collingbourne     // intrinsic call, but the intrinsic may have other non-extractvalue uses
16740312f614SPeter Collingbourne     // (although this is unlikely). In that case, explicitly build a pair and
16750312f614SPeter Collingbourne     // RAUW it.
16760312f614SPeter Collingbourne     if (!CI->use_empty()) {
16770312f614SPeter Collingbourne       Value *Pair = UndefValue::get(CI->getType());
16780312f614SPeter Collingbourne       IRBuilder<> B(CI);
16790312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
16800312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
16810312f614SPeter Collingbourne       CI->replaceAllUsesWith(Pair);
16820312f614SPeter Collingbourne     }
16830312f614SPeter Collingbourne 
16840312f614SPeter Collingbourne     // The number of unsafe uses is initially the number of uses.
16850312f614SPeter Collingbourne     auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
16860312f614SPeter Collingbourne     NumUnsafeUses = DevirtCalls.size();
16870312f614SPeter Collingbourne 
16880312f614SPeter Collingbourne     // If the function pointer has a non-call user, we cannot eliminate the type
16890312f614SPeter Collingbourne     // check, as one of those users may eventually call the pointer. Increment
16900312f614SPeter Collingbourne     // the unsafe use count to make sure it cannot reach zero.
16910312f614SPeter Collingbourne     if (HasNonCallUses)
16920312f614SPeter Collingbourne       ++NumUnsafeUses;
16930312f614SPeter Collingbourne     for (DevirtCallSite Call : DevirtCalls) {
169450cbd7ccSPeter Collingbourne       CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
169550cbd7ccSPeter Collingbourne                                                    &NumUnsafeUses);
16960312f614SPeter Collingbourne     }
16970312f614SPeter Collingbourne 
16980312f614SPeter Collingbourne     CI->eraseFromParent();
16990312f614SPeter Collingbourne   }
17000312f614SPeter Collingbourne }
17010312f614SPeter Collingbourne 
17026d284fabSPeter Collingbourne void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
1703d2df54e6STeresa Johnson   auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
1704d2df54e6STeresa Johnson   if (!TypeId)
1705d2df54e6STeresa Johnson     return;
17069a3f9797SPeter Collingbourne   const TypeIdSummary *TidSummary =
1707d2df54e6STeresa Johnson       ImportSummary->getTypeIdSummary(TypeId->getString());
17089a3f9797SPeter Collingbourne   if (!TidSummary)
17099a3f9797SPeter Collingbourne     return;
17109a3f9797SPeter Collingbourne   auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
17119a3f9797SPeter Collingbourne   if (ResI == TidSummary->WPDRes.end())
17129a3f9797SPeter Collingbourne     return;
17139a3f9797SPeter Collingbourne   const WholeProgramDevirtResolution &Res = ResI->second;
17146d284fabSPeter Collingbourne 
17156d284fabSPeter Collingbourne   if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
1716d2df54e6STeresa Johnson     assert(!Res.SingleImplName.empty());
17176d284fabSPeter Collingbourne     // The type of the function in the declaration is irrelevant because every
17186d284fabSPeter Collingbourne     // call site will cast it to the correct type.
171913680223SJames Y Knight     Constant *SingleImpl =
172013680223SJames Y Knight         cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
172113680223SJames Y Knight                                              Type::getVoidTy(M.getContext()))
172213680223SJames Y Knight                            .getCallee());
17236d284fabSPeter Collingbourne 
17246d284fabSPeter Collingbourne     // This is the import phase so we should not be exporting anything.
17256d284fabSPeter Collingbourne     bool IsExported = false;
17266d284fabSPeter Collingbourne     applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
17276d284fabSPeter Collingbourne     assert(!IsExported);
17286d284fabSPeter Collingbourne   }
17290152c815SPeter Collingbourne 
17300152c815SPeter Collingbourne   for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
17310152c815SPeter Collingbourne     auto I = Res.ResByArg.find(CSByConstantArg.first);
17320152c815SPeter Collingbourne     if (I == Res.ResByArg.end())
17330152c815SPeter Collingbourne       continue;
17340152c815SPeter Collingbourne     auto &ResByArg = I->second;
17350152c815SPeter Collingbourne     // FIXME: We should figure out what to do about the "function name" argument
17360152c815SPeter Collingbourne     // to the apply* functions, as the function names are unavailable during the
17370152c815SPeter Collingbourne     // importing phase. For now we just pass the empty string. This does not
17380152c815SPeter Collingbourne     // impact correctness because the function names are just used for remarks.
17390152c815SPeter Collingbourne     switch (ResByArg.TheKind) {
17400152c815SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniformRetVal:
17410152c815SPeter Collingbourne       applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
17420152c815SPeter Collingbourne       break;
174359675ba0SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
174459675ba0SPeter Collingbourne       Constant *UniqueMemberAddr =
174559675ba0SPeter Collingbourne           importGlobal(Slot, CSByConstantArg.first, "unique_member");
174659675ba0SPeter Collingbourne       applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
174759675ba0SPeter Collingbourne                            UniqueMemberAddr);
174859675ba0SPeter Collingbourne       break;
174959675ba0SPeter Collingbourne     }
175014dcf02fSPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
1751b15a35e6SPeter Collingbourne       Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
1752b15a35e6SPeter Collingbourne                                       Int32Ty, ResByArg.Byte);
1753b15a35e6SPeter Collingbourne       Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
1754b15a35e6SPeter Collingbourne                                      ResByArg.Bit);
175514dcf02fSPeter Collingbourne       applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
17560e6694d1SAdrian Prantl       break;
175714dcf02fSPeter Collingbourne     }
17580152c815SPeter Collingbourne     default:
17590152c815SPeter Collingbourne       break;
17600152c815SPeter Collingbourne     }
17610152c815SPeter Collingbourne   }
17622974856aSPeter Collingbourne 
17632974856aSPeter Collingbourne   if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
176413680223SJames Y Knight     // The type of the function is irrelevant, because it's bitcast at calls
176513680223SJames Y Knight     // anyhow.
176613680223SJames Y Knight     Constant *JT = cast<Constant>(
176713680223SJames Y Knight         M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
176813680223SJames Y Knight                               Type::getVoidTy(M.getContext()))
176913680223SJames Y Knight             .getCallee());
17702974856aSPeter Collingbourne     bool IsExported = false;
17712974856aSPeter Collingbourne     applyICallBranchFunnel(SlotInfo, JT, IsExported);
17722974856aSPeter Collingbourne     assert(!IsExported);
17732974856aSPeter Collingbourne   }
17746d284fabSPeter Collingbourne }
17756d284fabSPeter Collingbourne 
17766d284fabSPeter Collingbourne void DevirtModule::removeRedundantTypeTests() {
17776d284fabSPeter Collingbourne   auto True = ConstantInt::getTrue(M.getContext());
17786d284fabSPeter Collingbourne   for (auto &&U : NumUnsafeUsesForTypeTest) {
17796d284fabSPeter Collingbourne     if (U.second == 0) {
17806d284fabSPeter Collingbourne       U.first->replaceAllUsesWith(True);
17816d284fabSPeter Collingbourne       U.first->eraseFromParent();
17826d284fabSPeter Collingbourne     }
17836d284fabSPeter Collingbourne   }
17846d284fabSPeter Collingbourne }
17856d284fabSPeter Collingbourne 
17860312f614SPeter Collingbourne bool DevirtModule::run() {
1787d0b1f30bSTeresa Johnson   // If only some of the modules were split, we cannot correctly perform
1788d0b1f30bSTeresa Johnson   // this transformation. We already checked for the presense of type tests
1789d0b1f30bSTeresa Johnson   // with partially split modules during the thin link, and would have emitted
1790d0b1f30bSTeresa Johnson   // an error if any were found, so here we can simply return.
1791d0b1f30bSTeresa Johnson   if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
1792d0b1f30bSTeresa Johnson       (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
1793d0b1f30bSTeresa Johnson     return false;
1794d0b1f30bSTeresa Johnson 
17950312f614SPeter Collingbourne   Function *TypeTestFunc =
17960312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
17970312f614SPeter Collingbourne   Function *TypeCheckedLoadFunc =
17980312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
17990312f614SPeter Collingbourne   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
18000312f614SPeter Collingbourne 
1801b406baaeSPeter Collingbourne   // Normally if there are no users of the devirtualization intrinsics in the
1802b406baaeSPeter Collingbourne   // module, this pass has nothing to do. But if we are exporting, we also need
1803b406baaeSPeter Collingbourne   // to handle any users that appear only in the function summaries.
1804f7691d8bSPeter Collingbourne   if (!ExportSummary &&
1805b406baaeSPeter Collingbourne       (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
18060312f614SPeter Collingbourne        AssumeFunc->use_empty()) &&
18070312f614SPeter Collingbourne       (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
18080312f614SPeter Collingbourne     return false;
18090312f614SPeter Collingbourne 
18100312f614SPeter Collingbourne   if (TypeTestFunc && AssumeFunc)
18110312f614SPeter Collingbourne     scanTypeTestUsers(TypeTestFunc, AssumeFunc);
18120312f614SPeter Collingbourne 
18130312f614SPeter Collingbourne   if (TypeCheckedLoadFunc)
18140312f614SPeter Collingbourne     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
1815df49d1bbSPeter Collingbourne 
1816f7691d8bSPeter Collingbourne   if (ImportSummary) {
18176d284fabSPeter Collingbourne     for (auto &S : CallSlots)
18186d284fabSPeter Collingbourne       importResolution(S.first, S.second);
18196d284fabSPeter Collingbourne 
18206d284fabSPeter Collingbourne     removeRedundantTypeTests();
18216d284fabSPeter Collingbourne 
18226d284fabSPeter Collingbourne     // The rest of the code is only necessary when exporting or during regular
18236d284fabSPeter Collingbourne     // LTO, so we are done.
18246d284fabSPeter Collingbourne     return true;
18256d284fabSPeter Collingbourne   }
18266d284fabSPeter Collingbourne 
18277efd7506SPeter Collingbourne   // Rebuild type metadata into a map for easy lookup.
1828df49d1bbSPeter Collingbourne   std::vector<VTableBits> Bits;
18297efd7506SPeter Collingbourne   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
18307efd7506SPeter Collingbourne   buildTypeIdentifierMap(Bits, TypeIdMap);
18317efd7506SPeter Collingbourne   if (TypeIdMap.empty())
1832df49d1bbSPeter Collingbourne     return true;
1833df49d1bbSPeter Collingbourne 
1834b406baaeSPeter Collingbourne   // Collect information from summary about which calls to try to devirtualize.
1835f7691d8bSPeter Collingbourne   if (ExportSummary) {
1836b406baaeSPeter Collingbourne     DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
1837b406baaeSPeter Collingbourne     for (auto &P : TypeIdMap) {
1838b406baaeSPeter Collingbourne       if (auto *TypeId = dyn_cast<MDString>(P.first))
1839b406baaeSPeter Collingbourne         MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
1840b406baaeSPeter Collingbourne             TypeId);
1841b406baaeSPeter Collingbourne     }
1842b406baaeSPeter Collingbourne 
1843f7691d8bSPeter Collingbourne     for (auto &P : *ExportSummary) {
18449667b91bSPeter Collingbourne       for (auto &S : P.second.SummaryList) {
1845b406baaeSPeter Collingbourne         auto *FS = dyn_cast<FunctionSummary>(S.get());
1846b406baaeSPeter Collingbourne         if (!FS)
1847b406baaeSPeter Collingbourne           continue;
1848b406baaeSPeter Collingbourne         // FIXME: Only add live functions.
18495d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
18505d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
18512974856aSPeter Collingbourne             CallSlots[{MD, VF.Offset}]
18522974856aSPeter Collingbourne                 .CSInfo.markSummaryHasTypeTestAssumeUsers();
18535d8aea10SGeorge Rimar           }
18545d8aea10SGeorge Rimar         }
18555d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
18565d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
18572974856aSPeter Collingbourne             CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
18585d8aea10SGeorge Rimar           }
18595d8aea10SGeorge Rimar         }
1860b406baaeSPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
18615d8aea10SGeorge Rimar              FS->type_test_assume_const_vcalls()) {
18625d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
18632325bb34SPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
18645d8aea10SGeorge Rimar                 .ConstCSInfo[VC.Args]
18652974856aSPeter Collingbourne                 .markSummaryHasTypeTestAssumeUsers();
18665d8aea10SGeorge Rimar           }
18675d8aea10SGeorge Rimar         }
18682325bb34SPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
18695d8aea10SGeorge Rimar              FS->type_checked_load_const_vcalls()) {
18705d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
1871b406baaeSPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
1872b406baaeSPeter Collingbourne                 .ConstCSInfo[VC.Args]
18732974856aSPeter Collingbourne                 .addSummaryTypeCheckedLoadUser(FS);
1874b406baaeSPeter Collingbourne           }
1875b406baaeSPeter Collingbourne         }
1876b406baaeSPeter Collingbourne       }
18775d8aea10SGeorge Rimar     }
18785d8aea10SGeorge Rimar   }
1879b406baaeSPeter Collingbourne 
18807efd7506SPeter Collingbourne   // For each (type, offset) pair:
1881df49d1bbSPeter Collingbourne   bool DidVirtualConstProp = false;
1882f3403fd2SIvan Krasin   std::map<std::string, Function*> DevirtTargets;
1883df49d1bbSPeter Collingbourne   for (auto &S : CallSlots) {
18847efd7506SPeter Collingbourne     // Search each of the members of the type identifier for the virtual
18857efd7506SPeter Collingbourne     // function implementation at offset S.first.ByteOffset, and add to
18867efd7506SPeter Collingbourne     // TargetsForSlot.
1887df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> TargetsForSlot;
1888b406baaeSPeter Collingbourne     if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
1889b406baaeSPeter Collingbourne                                   S.first.ByteOffset)) {
18902325bb34SPeter Collingbourne       WholeProgramDevirtResolution *Res = nullptr;
1891f7691d8bSPeter Collingbourne       if (ExportSummary && isa<MDString>(S.first.TypeID))
1892f7691d8bSPeter Collingbourne         Res = &ExportSummary
18939a3f9797SPeter Collingbourne                    ->getOrInsertTypeIdSummary(
18949a3f9797SPeter Collingbourne                        cast<MDString>(S.first.TypeID)->getString())
18952325bb34SPeter Collingbourne                    .WPDRes[S.first.ByteOffset];
18962325bb34SPeter Collingbourne 
18972974856aSPeter Collingbourne       if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
18982974856aSPeter Collingbourne         DidVirtualConstProp |=
18992974856aSPeter Collingbourne             tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
19002974856aSPeter Collingbourne 
19012974856aSPeter Collingbourne         tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
19022974856aSPeter Collingbourne       }
1903f3403fd2SIvan Krasin 
1904f3403fd2SIvan Krasin       // Collect functions devirtualized at least for one call site for stats.
1905f3403fd2SIvan Krasin       if (RemarksEnabled)
1906f3403fd2SIvan Krasin         for (const auto &T : TargetsForSlot)
1907f3403fd2SIvan Krasin           if (T.WasDevirt)
1908f3403fd2SIvan Krasin             DevirtTargets[T.Fn->getName()] = T.Fn;
1909b05e06e4SIvan Krasin     }
1910df49d1bbSPeter Collingbourne 
1911b406baaeSPeter Collingbourne     // CFI-specific: if we are exporting and any llvm.type.checked.load
1912b406baaeSPeter Collingbourne     // intrinsics were *not* devirtualized, we need to add the resulting
1913b406baaeSPeter Collingbourne     // llvm.type.test intrinsics to the function summaries so that the
1914b406baaeSPeter Collingbourne     // LowerTypeTests pass will export them.
1915f7691d8bSPeter Collingbourne     if (ExportSummary && isa<MDString>(S.first.TypeID)) {
1916b406baaeSPeter Collingbourne       auto GUID =
1917b406baaeSPeter Collingbourne           GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
1918b406baaeSPeter Collingbourne       for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
1919b406baaeSPeter Collingbourne         FS->addTypeTest(GUID);
1920b406baaeSPeter Collingbourne       for (auto &CCS : S.second.ConstCSInfo)
1921b406baaeSPeter Collingbourne         for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
1922b406baaeSPeter Collingbourne           FS->addTypeTest(GUID);
1923b406baaeSPeter Collingbourne     }
1924b406baaeSPeter Collingbourne   }
1925b406baaeSPeter Collingbourne 
1926f3403fd2SIvan Krasin   if (RemarksEnabled) {
1927f3403fd2SIvan Krasin     // Generate remarks for each devirtualized function.
1928f3403fd2SIvan Krasin     for (const auto &DT : DevirtTargets) {
1929f3403fd2SIvan Krasin       Function *F = DT.second;
1930e963c89dSSam Elliott 
1931e963c89dSSam Elliott       using namespace ore;
19329110cb45SPeter Collingbourne       OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
19339110cb45SPeter Collingbourne                         << "devirtualized "
1934d2df54e6STeresa Johnson                         << NV("FunctionName", DT.first));
1935b05e06e4SIvan Krasin     }
1936df49d1bbSPeter Collingbourne   }
1937df49d1bbSPeter Collingbourne 
19386d284fabSPeter Collingbourne   removeRedundantTypeTests();
19390312f614SPeter Collingbourne 
1940df49d1bbSPeter Collingbourne   // Rebuild each global we touched as part of virtual constant propagation to
1941df49d1bbSPeter Collingbourne   // include the before and after bytes.
1942df49d1bbSPeter Collingbourne   if (DidVirtualConstProp)
1943df49d1bbSPeter Collingbourne     for (VTableBits &B : Bits)
1944df49d1bbSPeter Collingbourne       rebuildGlobal(B);
1945df49d1bbSPeter Collingbourne 
1946df49d1bbSPeter Collingbourne   return true;
1947df49d1bbSPeter Collingbourne }
1948d2df54e6STeresa Johnson 
1949d2df54e6STeresa Johnson void DevirtIndex::run() {
1950d2df54e6STeresa Johnson   if (ExportSummary.typeIdCompatibleVtableMap().empty())
1951d2df54e6STeresa Johnson     return;
1952d2df54e6STeresa Johnson 
1953d2df54e6STeresa Johnson   DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
1954d2df54e6STeresa Johnson   for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
1955d2df54e6STeresa Johnson     NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
1956d2df54e6STeresa Johnson   }
1957d2df54e6STeresa Johnson 
1958d2df54e6STeresa Johnson   // Collect information from summary about which calls to try to devirtualize.
1959d2df54e6STeresa Johnson   for (auto &P : ExportSummary) {
1960d2df54e6STeresa Johnson     for (auto &S : P.second.SummaryList) {
1961d2df54e6STeresa Johnson       auto *FS = dyn_cast<FunctionSummary>(S.get());
1962d2df54e6STeresa Johnson       if (!FS)
1963d2df54e6STeresa Johnson         continue;
1964d2df54e6STeresa Johnson       // FIXME: Only add live functions.
1965d2df54e6STeresa Johnson       for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
1966d2df54e6STeresa Johnson         for (StringRef Name : NameByGUID[VF.GUID]) {
1967d2df54e6STeresa Johnson           CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
1968d2df54e6STeresa Johnson         }
1969d2df54e6STeresa Johnson       }
1970d2df54e6STeresa Johnson       for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
1971d2df54e6STeresa Johnson         for (StringRef Name : NameByGUID[VF.GUID]) {
1972d2df54e6STeresa Johnson           CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
1973d2df54e6STeresa Johnson         }
1974d2df54e6STeresa Johnson       }
1975d2df54e6STeresa Johnson       for (const FunctionSummary::ConstVCall &VC :
1976d2df54e6STeresa Johnson            FS->type_test_assume_const_vcalls()) {
1977d2df54e6STeresa Johnson         for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
1978d2df54e6STeresa Johnson           CallSlots[{Name, VC.VFunc.Offset}]
1979d2df54e6STeresa Johnson               .ConstCSInfo[VC.Args]
1980d2df54e6STeresa Johnson               .addSummaryTypeTestAssumeUser(FS);
1981d2df54e6STeresa Johnson         }
1982d2df54e6STeresa Johnson       }
1983d2df54e6STeresa Johnson       for (const FunctionSummary::ConstVCall &VC :
1984d2df54e6STeresa Johnson            FS->type_checked_load_const_vcalls()) {
1985d2df54e6STeresa Johnson         for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
1986d2df54e6STeresa Johnson           CallSlots[{Name, VC.VFunc.Offset}]
1987d2df54e6STeresa Johnson               .ConstCSInfo[VC.Args]
1988d2df54e6STeresa Johnson               .addSummaryTypeCheckedLoadUser(FS);
1989d2df54e6STeresa Johnson         }
1990d2df54e6STeresa Johnson       }
1991d2df54e6STeresa Johnson     }
1992d2df54e6STeresa Johnson   }
1993d2df54e6STeresa Johnson 
1994d2df54e6STeresa Johnson   std::set<ValueInfo> DevirtTargets;
1995d2df54e6STeresa Johnson   // For each (type, offset) pair:
1996d2df54e6STeresa Johnson   for (auto &S : CallSlots) {
1997d2df54e6STeresa Johnson     // Search each of the members of the type identifier for the virtual
1998d2df54e6STeresa Johnson     // function implementation at offset S.first.ByteOffset, and add to
1999d2df54e6STeresa Johnson     // TargetsForSlot.
2000d2df54e6STeresa Johnson     std::vector<ValueInfo> TargetsForSlot;
2001d2df54e6STeresa Johnson     auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
2002d2df54e6STeresa Johnson     assert(TidSummary);
2003d2df54e6STeresa Johnson     if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
2004d2df54e6STeresa Johnson                                   S.first.ByteOffset)) {
2005d2df54e6STeresa Johnson       WholeProgramDevirtResolution *Res =
2006d2df54e6STeresa Johnson           &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
2007d2df54e6STeresa Johnson                .WPDRes[S.first.ByteOffset];
2008d2df54e6STeresa Johnson 
2009d2df54e6STeresa Johnson       if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
2010d2df54e6STeresa Johnson                                DevirtTargets))
2011d2df54e6STeresa Johnson         continue;
2012d2df54e6STeresa Johnson     }
2013d2df54e6STeresa Johnson   }
2014d2df54e6STeresa Johnson 
2015d2df54e6STeresa Johnson   // Optionally have the thin link print message for each devirtualized
2016d2df54e6STeresa Johnson   // function.
2017d2df54e6STeresa Johnson   if (PrintSummaryDevirt)
2018d2df54e6STeresa Johnson     for (const auto &DT : DevirtTargets)
2019d2df54e6STeresa Johnson       errs() << "Devirtualized call to " << DT << "\n";
2020d2df54e6STeresa Johnson 
2021d2df54e6STeresa Johnson   return;
2022d2df54e6STeresa Johnson }
2023