1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2df49d1bbSPeter Collingbourne //
3df49d1bbSPeter Collingbourne //                     The LLVM Compiler Infrastructure
4df49d1bbSPeter Collingbourne //
5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source
6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details.
7df49d1bbSPeter Collingbourne //
8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
9df49d1bbSPeter Collingbourne //
10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases
117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This
12df49d1bbSPeter Collingbourne // includes the following:
13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single
14df49d1bbSPeter Collingbourne //   possible callee, replace all calls with a direct call to that callee.
15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an
16df49d1bbSPeter Collingbourne //   integer <=64 bits and all possible callees are readnone, for each class and
17df49d1bbSPeter Collingbourne //   each list of constant arguments: evaluate the function, store the return
18df49d1bbSPeter Collingbourne //   value alongside the virtual table, and rewrite each virtual call as a load
19df49d1bbSPeter Collingbourne //   from the virtual table.
20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant
21df49d1bbSPeter Collingbourne //   propagation hold and each function returns the same constant value, replace
22df49d1bbSPeter Collingbourne //   each virtual call with that constant.
23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions
24df49d1bbSPeter Collingbourne //   for virtual constant propagation hold and a single vtable's function
25df49d1bbSPeter Collingbourne //   returns 0, or a single vtable's function returns 1, replace each virtual
26df49d1bbSPeter Collingbourne //   call with a comparison of the vptr against that vtable's address.
27df49d1bbSPeter Collingbourne //
28b406baaeSPeter Collingbourne // This pass is intended to be used during the regular and thin LTO pipelines.
29b406baaeSPeter Collingbourne // During regular LTO, the pass determines the best optimization for each
30b406baaeSPeter Collingbourne // virtual call and applies the resolutions directly to virtual calls that are
31b406baaeSPeter Collingbourne // eligible for virtual call optimization (i.e. calls that use either of the
32b406baaeSPeter Collingbourne // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During
33b406baaeSPeter Collingbourne // ThinLTO, the pass operates in two phases:
34b406baaeSPeter Collingbourne // - Export phase: this is run during the thin link over a single merged module
35b406baaeSPeter Collingbourne //   that contains all vtables with !type metadata that participate in the link.
36b406baaeSPeter Collingbourne //   The pass computes a resolution for each virtual call and stores it in the
37b406baaeSPeter Collingbourne //   type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
41b406baaeSPeter Collingbourne //
42df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===//
43df49d1bbSPeter Collingbourne 
44df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
45b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h"
46cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h"
47cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h"
48df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h"
49df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h"
50cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h"
516bda14b3SChandler Carruth #include "llvm/ADT/iterator_range.h"
5237317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h"
5337317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h"
540965da20SAdam Nemet #include "llvm/Analysis/OptimizationRemarkEmitter.h"
557efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
56df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h"
57df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h"
58df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h"
59cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h"
60cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h"
61cdc71612SEugene Zelenko #include "llvm/IR/Function.h"
62cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h"
63cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h"
64df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h"
65cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h"
66cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h"
67df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h"
68df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h"
69cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h"
70cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h"
71df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h"
722b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h"
73df49d1bbSPeter Collingbourne #include "llvm/Pass.h"
74cdc71612SEugene Zelenko #include "llvm/PassRegistry.h"
75cdc71612SEugene Zelenko #include "llvm/PassSupport.h"
76cdc71612SEugene Zelenko #include "llvm/Support/Casting.h"
772b33f653SPeter Collingbourne #include "llvm/Support/Error.h"
782b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h"
79cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h"
80b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h"
8137317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h"
82df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h"
83cdc71612SEugene Zelenko #include <algorithm>
84cdc71612SEugene Zelenko #include <cstddef>
85cdc71612SEugene Zelenko #include <map>
86df49d1bbSPeter Collingbourne #include <set>
87cdc71612SEugene Zelenko #include <string>
88df49d1bbSPeter Collingbourne 
89df49d1bbSPeter Collingbourne using namespace llvm;
90df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt;
91df49d1bbSPeter Collingbourne 
92df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt"
93df49d1bbSPeter Collingbourne 
942b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction(
952b33f653SPeter Collingbourne     "wholeprogramdevirt-summary-action",
962b33f653SPeter Collingbourne     cl::desc("What to do with the summary when running this pass"),
972b33f653SPeter Collingbourne     cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
982b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Import, "import",
992b33f653SPeter Collingbourne                           "Import typeid resolutions from summary and globals"),
1002b33f653SPeter Collingbourne                clEnumValN(PassSummaryAction::Export, "export",
1012b33f653SPeter Collingbourne                           "Export typeid resolutions to summary and globals")),
1022b33f653SPeter Collingbourne     cl::Hidden);
1032b33f653SPeter Collingbourne 
1042b33f653SPeter Collingbourne static cl::opt<std::string> ClReadSummary(
1052b33f653SPeter Collingbourne     "wholeprogramdevirt-read-summary",
1062b33f653SPeter Collingbourne     cl::desc("Read summary from given YAML file before running pass"),
1072b33f653SPeter Collingbourne     cl::Hidden);
1082b33f653SPeter Collingbourne 
1092b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary(
1102b33f653SPeter Collingbourne     "wholeprogramdevirt-write-summary",
1112b33f653SPeter Collingbourne     cl::desc("Write summary to given YAML file after running pass"),
1122b33f653SPeter Collingbourne     cl::Hidden);
1132b33f653SPeter Collingbourne 
1149cb59b92SVitaly Buka static cl::opt<unsigned>
1159cb59b92SVitaly Buka     ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
1169cb59b92SVitaly Buka                 cl::init(10), cl::ZeroOrMore,
11766f53d71SVitaly Buka                 cl::desc("Maximum number of call targets per "
11866f53d71SVitaly Buka                          "call site to enable branch funnels"));
11966f53d71SVitaly Buka 
120df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If
121df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an
122df49d1bbSPeter Collingbourne // offset after the object.
123df49d1bbSPeter Collingbourne uint64_t
124df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
125df49d1bbSPeter Collingbourne                                      bool IsAfter, uint64_t Size) {
126df49d1bbSPeter Collingbourne   // Find a minimum offset taking into account only vtable sizes.
127df49d1bbSPeter Collingbourne   uint64_t MinByte = 0;
128df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
129df49d1bbSPeter Collingbourne     if (IsAfter)
130df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minAfterBytes());
131df49d1bbSPeter Collingbourne     else
132df49d1bbSPeter Collingbourne       MinByte = std::max(MinByte, Target.minBeforeBytes());
133df49d1bbSPeter Collingbourne   }
134df49d1bbSPeter Collingbourne 
135df49d1bbSPeter Collingbourne   // Build a vector of arrays of bytes covering, for each target, a slice of the
136df49d1bbSPeter Collingbourne   // used region (see AccumBitVector::BytesUsed in
137df49d1bbSPeter Collingbourne   // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
138df49d1bbSPeter Collingbourne   // this aligns the used regions to start at MinByte.
139df49d1bbSPeter Collingbourne   //
140df49d1bbSPeter Collingbourne   // In this example, A, B and C are vtables, # is a byte already allocated for
141df49d1bbSPeter Collingbourne   // a virtual function pointer, AAAA... (etc.) are the used regions for the
142df49d1bbSPeter Collingbourne   // vtables and Offset(X) is the value computed for the Offset variable below
143df49d1bbSPeter Collingbourne   // for X.
144df49d1bbSPeter Collingbourne   //
145df49d1bbSPeter Collingbourne   //                    Offset(A)
146df49d1bbSPeter Collingbourne   //                    |       |
147df49d1bbSPeter Collingbourne   //                            |MinByte
148df49d1bbSPeter Collingbourne   // A: ################AAAAAAAA|AAAAAAAA
149df49d1bbSPeter Collingbourne   // B: ########BBBBBBBBBBBBBBBB|BBBB
150df49d1bbSPeter Collingbourne   // C: ########################|CCCCCCCCCCCCCCCC
151df49d1bbSPeter Collingbourne   //            |   Offset(B)   |
152df49d1bbSPeter Collingbourne   //
153df49d1bbSPeter Collingbourne   // This code produces the slices of A, B and C that appear after the divider
154df49d1bbSPeter Collingbourne   // at MinByte.
155df49d1bbSPeter Collingbourne   std::vector<ArrayRef<uint8_t>> Used;
156df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : Targets) {
1577efd7506SPeter Collingbourne     ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
1587efd7506SPeter Collingbourne                                        : Target.TM->Bits->Before.BytesUsed;
159df49d1bbSPeter Collingbourne     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
160df49d1bbSPeter Collingbourne                               : MinByte - Target.minBeforeBytes();
161df49d1bbSPeter Collingbourne 
162df49d1bbSPeter Collingbourne     // Disregard used regions that are smaller than Offset. These are
163df49d1bbSPeter Collingbourne     // effectively all-free regions that do not need to be checked.
164df49d1bbSPeter Collingbourne     if (VTUsed.size() > Offset)
165df49d1bbSPeter Collingbourne       Used.push_back(VTUsed.slice(Offset));
166df49d1bbSPeter Collingbourne   }
167df49d1bbSPeter Collingbourne 
168df49d1bbSPeter Collingbourne   if (Size == 1) {
169df49d1bbSPeter Collingbourne     // Find a free bit in each member of Used.
170df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
171df49d1bbSPeter Collingbourne       uint8_t BitsUsed = 0;
172df49d1bbSPeter Collingbourne       for (auto &&B : Used)
173df49d1bbSPeter Collingbourne         if (I < B.size())
174df49d1bbSPeter Collingbourne           BitsUsed |= B[I];
175df49d1bbSPeter Collingbourne       if (BitsUsed != 0xff)
176df49d1bbSPeter Collingbourne         return (MinByte + I) * 8 +
177df49d1bbSPeter Collingbourne                countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
178df49d1bbSPeter Collingbourne     }
179df49d1bbSPeter Collingbourne   } else {
180df49d1bbSPeter Collingbourne     // Find a free (Size/8) byte region in each member of Used.
181df49d1bbSPeter Collingbourne     // FIXME: see if alignment helps.
182df49d1bbSPeter Collingbourne     for (unsigned I = 0;; ++I) {
183df49d1bbSPeter Collingbourne       for (auto &&B : Used) {
184df49d1bbSPeter Collingbourne         unsigned Byte = 0;
185df49d1bbSPeter Collingbourne         while ((I + Byte) < B.size() && Byte < (Size / 8)) {
186df49d1bbSPeter Collingbourne           if (B[I + Byte])
187df49d1bbSPeter Collingbourne             goto NextI;
188df49d1bbSPeter Collingbourne           ++Byte;
189df49d1bbSPeter Collingbourne         }
190df49d1bbSPeter Collingbourne       }
191df49d1bbSPeter Collingbourne       return (MinByte + I) * 8;
192df49d1bbSPeter Collingbourne     NextI:;
193df49d1bbSPeter Collingbourne     }
194df49d1bbSPeter Collingbourne   }
195df49d1bbSPeter Collingbourne }
196df49d1bbSPeter Collingbourne 
197df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues(
198df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
199df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
200df49d1bbSPeter Collingbourne   if (BitWidth == 1)
201df49d1bbSPeter Collingbourne     OffsetByte = -(AllocBefore / 8 + 1);
202df49d1bbSPeter Collingbourne   else
203df49d1bbSPeter Collingbourne     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
204df49d1bbSPeter Collingbourne   OffsetBit = AllocBefore % 8;
205df49d1bbSPeter Collingbourne 
206df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
207df49d1bbSPeter Collingbourne     if (BitWidth == 1)
208df49d1bbSPeter Collingbourne       Target.setBeforeBit(AllocBefore);
209df49d1bbSPeter Collingbourne     else
210df49d1bbSPeter Collingbourne       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
211df49d1bbSPeter Collingbourne   }
212df49d1bbSPeter Collingbourne }
213df49d1bbSPeter Collingbourne 
214df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues(
215df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
216df49d1bbSPeter Collingbourne     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
217df49d1bbSPeter Collingbourne   if (BitWidth == 1)
218df49d1bbSPeter Collingbourne     OffsetByte = AllocAfter / 8;
219df49d1bbSPeter Collingbourne   else
220df49d1bbSPeter Collingbourne     OffsetByte = (AllocAfter + 7) / 8;
221df49d1bbSPeter Collingbourne   OffsetBit = AllocAfter % 8;
222df49d1bbSPeter Collingbourne 
223df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : Targets) {
224df49d1bbSPeter Collingbourne     if (BitWidth == 1)
225df49d1bbSPeter Collingbourne       Target.setAfterBit(AllocAfter);
226df49d1bbSPeter Collingbourne     else
227df49d1bbSPeter Collingbourne       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
228df49d1bbSPeter Collingbourne   }
229df49d1bbSPeter Collingbourne }
230df49d1bbSPeter Collingbourne 
2317efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
2327efd7506SPeter Collingbourne     : Fn(Fn), TM(TM),
23389439a79SIvan Krasin       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
234df49d1bbSPeter Collingbourne 
235df49d1bbSPeter Collingbourne namespace {
236df49d1bbSPeter Collingbourne 
2377efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual
238df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to
239df49d1bbSPeter Collingbourne // the virtual function pointer.
240df49d1bbSPeter Collingbourne struct VTableSlot {
2417efd7506SPeter Collingbourne   Metadata *TypeID;
242df49d1bbSPeter Collingbourne   uint64_t ByteOffset;
243df49d1bbSPeter Collingbourne };
244df49d1bbSPeter Collingbourne 
245cdc71612SEugene Zelenko } // end anonymous namespace
246df49d1bbSPeter Collingbourne 
2479b656527SPeter Collingbourne namespace llvm {
2489b656527SPeter Collingbourne 
249df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> {
250df49d1bbSPeter Collingbourne   static VTableSlot getEmptyKey() {
251df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getEmptyKey(),
252df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getEmptyKey()};
253df49d1bbSPeter Collingbourne   }
254df49d1bbSPeter Collingbourne   static VTableSlot getTombstoneKey() {
255df49d1bbSPeter Collingbourne     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
256df49d1bbSPeter Collingbourne             DenseMapInfo<uint64_t>::getTombstoneKey()};
257df49d1bbSPeter Collingbourne   }
258df49d1bbSPeter Collingbourne   static unsigned getHashValue(const VTableSlot &I) {
2597efd7506SPeter Collingbourne     return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
260df49d1bbSPeter Collingbourne            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
261df49d1bbSPeter Collingbourne   }
262df49d1bbSPeter Collingbourne   static bool isEqual(const VTableSlot &LHS,
263df49d1bbSPeter Collingbourne                       const VTableSlot &RHS) {
2647efd7506SPeter Collingbourne     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
265df49d1bbSPeter Collingbourne   }
266df49d1bbSPeter Collingbourne };
267df49d1bbSPeter Collingbourne 
268cdc71612SEugene Zelenko } // end namespace llvm
2699b656527SPeter Collingbourne 
270df49d1bbSPeter Collingbourne namespace {
271df49d1bbSPeter Collingbourne 
272df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is
273df49d1bbSPeter Collingbourne // the indirect virtual call.
274df49d1bbSPeter Collingbourne struct VirtualCallSite {
275df49d1bbSPeter Collingbourne   Value *VTable;
276df49d1bbSPeter Collingbourne   CallSite CS;
277df49d1bbSPeter Collingbourne 
2780312f614SPeter Collingbourne   // If non-null, this field points to the associated unsafe use count stored in
2790312f614SPeter Collingbourne   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
2800312f614SPeter Collingbourne   // of that field for details.
2810312f614SPeter Collingbourne   unsigned *NumUnsafeUses;
2820312f614SPeter Collingbourne 
283e963c89dSSam Elliott   void
284e963c89dSSam Elliott   emitRemark(const StringRef OptName, const StringRef TargetName,
285e963c89dSSam Elliott              function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
2865474645dSIvan Krasin     Function *F = CS.getCaller();
287e963c89dSSam Elliott     DebugLoc DLoc = CS->getDebugLoc();
288e963c89dSSam Elliott     BasicBlock *Block = CS.getParent();
289e963c89dSSam Elliott 
290e963c89dSSam Elliott     using namespace ore;
2919110cb45SPeter Collingbourne     OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
2929110cb45SPeter Collingbourne                       << NV("Optimization", OptName)
2939110cb45SPeter Collingbourne                       << ": devirtualized a call to "
294e963c89dSSam Elliott                       << NV("FunctionName", TargetName));
295e963c89dSSam Elliott   }
296e963c89dSSam Elliott 
297e963c89dSSam Elliott   void replaceAndErase(
298e963c89dSSam Elliott       const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
299e963c89dSSam Elliott       function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
300e963c89dSSam Elliott       Value *New) {
301f3403fd2SIvan Krasin     if (RemarksEnabled)
302e963c89dSSam Elliott       emitRemark(OptName, TargetName, OREGetter);
303df49d1bbSPeter Collingbourne     CS->replaceAllUsesWith(New);
304df49d1bbSPeter Collingbourne     if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
305df49d1bbSPeter Collingbourne       BranchInst::Create(II->getNormalDest(), CS.getInstruction());
306df49d1bbSPeter Collingbourne       II->getUnwindDest()->removePredecessor(II->getParent());
307df49d1bbSPeter Collingbourne     }
308df49d1bbSPeter Collingbourne     CS->eraseFromParent();
3090312f614SPeter Collingbourne     // This use is no longer unsafe.
3100312f614SPeter Collingbourne     if (NumUnsafeUses)
3110312f614SPeter Collingbourne       --*NumUnsafeUses;
312df49d1bbSPeter Collingbourne   }
313df49d1bbSPeter Collingbourne };
314df49d1bbSPeter Collingbourne 
31550cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list
31650cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the
31750cbd7ccSPeter Collingbourne // VTableSlotInfo class.
31850cbd7ccSPeter Collingbourne struct CallSiteInfo {
319b406baaeSPeter Collingbourne   /// The set of call sites for this slot. Used during regular LTO and the
320b406baaeSPeter Collingbourne   /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
321b406baaeSPeter Collingbourne   /// call sites that appear in the merged module itself); in each of these
322b406baaeSPeter Collingbourne   /// cases we are directly operating on the call sites at the IR level.
32350cbd7ccSPeter Collingbourne   std::vector<VirtualCallSite> CallSites;
324b406baaeSPeter Collingbourne 
3252974856aSPeter Collingbourne   /// Whether all call sites represented by this CallSiteInfo, including those
3262974856aSPeter Collingbourne   /// in summaries, have been devirtualized. This starts off as true because a
3272974856aSPeter Collingbourne   /// default constructed CallSiteInfo represents no call sites.
3282974856aSPeter Collingbourne   bool AllCallSitesDevirted = true;
3292974856aSPeter Collingbourne 
330b406baaeSPeter Collingbourne   // These fields are used during the export phase of ThinLTO and reflect
331b406baaeSPeter Collingbourne   // information collected from function summaries.
332b406baaeSPeter Collingbourne 
3332325bb34SPeter Collingbourne   /// Whether any function summary contains an llvm.assume(llvm.type.test) for
3342325bb34SPeter Collingbourne   /// this slot.
3352974856aSPeter Collingbourne   bool SummaryHasTypeTestAssumeUsers = false;
3362325bb34SPeter Collingbourne 
337b406baaeSPeter Collingbourne   /// CFI-specific: a vector containing the list of function summaries that use
338b406baaeSPeter Collingbourne   /// the llvm.type.checked.load intrinsic and therefore will require
339b406baaeSPeter Collingbourne   /// resolutions for llvm.type.test in order to implement CFI checks if
340b406baaeSPeter Collingbourne   /// devirtualization was unsuccessful. If devirtualization was successful, the
34159675ba0SPeter Collingbourne   /// pass will clear this vector by calling markDevirt(). If at the end of the
34259675ba0SPeter Collingbourne   /// pass the vector is non-empty, we will need to add a use of llvm.type.test
34359675ba0SPeter Collingbourne   /// to each of the function summaries in the vector.
344b406baaeSPeter Collingbourne   std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
3452325bb34SPeter Collingbourne 
3462325bb34SPeter Collingbourne   bool isExported() const {
3472325bb34SPeter Collingbourne     return SummaryHasTypeTestAssumeUsers ||
3482325bb34SPeter Collingbourne            !SummaryTypeCheckedLoadUsers.empty();
3492325bb34SPeter Collingbourne   }
35059675ba0SPeter Collingbourne 
3512974856aSPeter Collingbourne   void markSummaryHasTypeTestAssumeUsers() {
3522974856aSPeter Collingbourne     SummaryHasTypeTestAssumeUsers = true;
3532974856aSPeter Collingbourne     AllCallSitesDevirted = false;
3542974856aSPeter Collingbourne   }
3552974856aSPeter Collingbourne 
3562974856aSPeter Collingbourne   void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
3572974856aSPeter Collingbourne     SummaryTypeCheckedLoadUsers.push_back(FS);
3582974856aSPeter Collingbourne     AllCallSitesDevirted = false;
3592974856aSPeter Collingbourne   }
3602974856aSPeter Collingbourne 
3612974856aSPeter Collingbourne   void markDevirt() {
3622974856aSPeter Collingbourne     AllCallSitesDevirted = true;
3632974856aSPeter Collingbourne 
3642974856aSPeter Collingbourne     // As explained in the comment for SummaryTypeCheckedLoadUsers.
3652974856aSPeter Collingbourne     SummaryTypeCheckedLoadUsers.clear();
3662974856aSPeter Collingbourne   }
36750cbd7ccSPeter Collingbourne };
36850cbd7ccSPeter Collingbourne 
36950cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot.
37050cbd7ccSPeter Collingbourne struct VTableSlotInfo {
37150cbd7ccSPeter Collingbourne   // The set of call sites which do not have all constant integer arguments
37250cbd7ccSPeter Collingbourne   // (excluding "this").
37350cbd7ccSPeter Collingbourne   CallSiteInfo CSInfo;
37450cbd7ccSPeter Collingbourne 
37550cbd7ccSPeter Collingbourne   // The set of call sites with all constant integer arguments (excluding
37650cbd7ccSPeter Collingbourne   // "this"), grouped by argument list.
37750cbd7ccSPeter Collingbourne   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
37850cbd7ccSPeter Collingbourne 
37950cbd7ccSPeter Collingbourne   void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
38050cbd7ccSPeter Collingbourne 
38150cbd7ccSPeter Collingbourne private:
38250cbd7ccSPeter Collingbourne   CallSiteInfo &findCallSiteInfo(CallSite CS);
38350cbd7ccSPeter Collingbourne };
38450cbd7ccSPeter Collingbourne 
38550cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
38650cbd7ccSPeter Collingbourne   std::vector<uint64_t> Args;
38750cbd7ccSPeter Collingbourne   auto *CI = dyn_cast<IntegerType>(CS.getType());
38850cbd7ccSPeter Collingbourne   if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
38950cbd7ccSPeter Collingbourne     return CSInfo;
39050cbd7ccSPeter Collingbourne   for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
39150cbd7ccSPeter Collingbourne     auto *CI = dyn_cast<ConstantInt>(Arg);
39250cbd7ccSPeter Collingbourne     if (!CI || CI->getBitWidth() > 64)
39350cbd7ccSPeter Collingbourne       return CSInfo;
39450cbd7ccSPeter Collingbourne     Args.push_back(CI->getZExtValue());
39550cbd7ccSPeter Collingbourne   }
39650cbd7ccSPeter Collingbourne   return ConstCSInfo[Args];
39750cbd7ccSPeter Collingbourne }
39850cbd7ccSPeter Collingbourne 
39950cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
40050cbd7ccSPeter Collingbourne                                  unsigned *NumUnsafeUses) {
4012974856aSPeter Collingbourne   auto &CSI = findCallSiteInfo(CS);
4022974856aSPeter Collingbourne   CSI.AllCallSitesDevirted = false;
4032974856aSPeter Collingbourne   CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
40450cbd7ccSPeter Collingbourne }
40550cbd7ccSPeter Collingbourne 
406df49d1bbSPeter Collingbourne struct DevirtModule {
407df49d1bbSPeter Collingbourne   Module &M;
40837317f12SPeter Collingbourne   function_ref<AAResults &(Function &)> AARGetter;
4092b33f653SPeter Collingbourne 
410f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
411f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
4122b33f653SPeter Collingbourne 
413df49d1bbSPeter Collingbourne   IntegerType *Int8Ty;
414df49d1bbSPeter Collingbourne   PointerType *Int8PtrTy;
415df49d1bbSPeter Collingbourne   IntegerType *Int32Ty;
41650cbd7ccSPeter Collingbourne   IntegerType *Int64Ty;
41714dcf02fSPeter Collingbourne   IntegerType *IntPtrTy;
418df49d1bbSPeter Collingbourne 
419f3403fd2SIvan Krasin   bool RemarksEnabled;
420e963c89dSSam Elliott   function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
421f3403fd2SIvan Krasin 
42250cbd7ccSPeter Collingbourne   MapVector<VTableSlot, VTableSlotInfo> CallSlots;
423df49d1bbSPeter Collingbourne 
4240312f614SPeter Collingbourne   // This map keeps track of the number of "unsafe" uses of a loaded function
4250312f614SPeter Collingbourne   // pointer. The key is the associated llvm.type.test intrinsic call generated
4260312f614SPeter Collingbourne   // by this pass. An unsafe use is one that calls the loaded function pointer
4270312f614SPeter Collingbourne   // directly. Every time we eliminate an unsafe use (for example, by
4280312f614SPeter Collingbourne   // devirtualizing it or by applying virtual constant propagation), we
4290312f614SPeter Collingbourne   // decrement the value stored in this map. If a value reaches zero, we can
4300312f614SPeter Collingbourne   // eliminate the type check by RAUWing the associated llvm.type.test call with
4310312f614SPeter Collingbourne   // true.
4320312f614SPeter Collingbourne   std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
4330312f614SPeter Collingbourne 
43437317f12SPeter Collingbourne   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
435e963c89dSSam Elliott                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
436f7691d8bSPeter Collingbourne                ModuleSummaryIndex *ExportSummary,
437f7691d8bSPeter Collingbourne                const ModuleSummaryIndex *ImportSummary)
438f7691d8bSPeter Collingbourne       : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
439f7691d8bSPeter Collingbourne         ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
440df49d1bbSPeter Collingbourne         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
441f3403fd2SIvan Krasin         Int32Ty(Type::getInt32Ty(M.getContext())),
44250cbd7ccSPeter Collingbourne         Int64Ty(Type::getInt64Ty(M.getContext())),
44314dcf02fSPeter Collingbourne         IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
444e963c89dSSam Elliott         RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
445f7691d8bSPeter Collingbourne     assert(!(ExportSummary && ImportSummary));
446f7691d8bSPeter Collingbourne   }
447f3403fd2SIvan Krasin 
448f3403fd2SIvan Krasin   bool areRemarksEnabled();
449df49d1bbSPeter Collingbourne 
4500312f614SPeter Collingbourne   void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
4510312f614SPeter Collingbourne   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
4520312f614SPeter Collingbourne 
4537efd7506SPeter Collingbourne   void buildTypeIdentifierMap(
4547efd7506SPeter Collingbourne       std::vector<VTableBits> &Bits,
4557efd7506SPeter Collingbourne       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
4568786754cSPeter Collingbourne   Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
4577efd7506SPeter Collingbourne   bool
4587efd7506SPeter Collingbourne   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
4597efd7506SPeter Collingbourne                             const std::set<TypeMemberInfo> &TypeMemberInfos,
460df49d1bbSPeter Collingbourne                             uint64_t ByteOffset);
46150cbd7ccSPeter Collingbourne 
4622325bb34SPeter Collingbourne   void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
4632325bb34SPeter Collingbourne                              bool &IsExported);
464f3403fd2SIvan Krasin   bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
4652325bb34SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
4662325bb34SPeter Collingbourne                            WholeProgramDevirtResolution *Res);
46750cbd7ccSPeter Collingbourne 
4682974856aSPeter Collingbourne   void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
4692974856aSPeter Collingbourne                               bool &IsExported);
4702974856aSPeter Collingbourne   void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
4712974856aSPeter Collingbourne                             VTableSlotInfo &SlotInfo,
4722974856aSPeter Collingbourne                             WholeProgramDevirtResolution *Res, VTableSlot Slot);
4732974856aSPeter Collingbourne 
474df49d1bbSPeter Collingbourne   bool tryEvaluateFunctionsWithArgs(
475df49d1bbSPeter Collingbourne       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
47650cbd7ccSPeter Collingbourne       ArrayRef<uint64_t> Args);
47750cbd7ccSPeter Collingbourne 
47850cbd7ccSPeter Collingbourne   void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
47950cbd7ccSPeter Collingbourne                              uint64_t TheRetVal);
48050cbd7ccSPeter Collingbourne   bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
48177a8d563SPeter Collingbourne                            CallSiteInfo &CSInfo,
48277a8d563SPeter Collingbourne                            WholeProgramDevirtResolution::ByArg *Res);
48350cbd7ccSPeter Collingbourne 
48459675ba0SPeter Collingbourne   // Returns the global symbol name that is used to export information about the
48559675ba0SPeter Collingbourne   // given vtable slot and list of arguments.
48659675ba0SPeter Collingbourne   std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
48759675ba0SPeter Collingbourne                             StringRef Name);
48859675ba0SPeter Collingbourne 
489b15a35e6SPeter Collingbourne   bool shouldExportConstantsAsAbsoluteSymbols();
490b15a35e6SPeter Collingbourne 
49159675ba0SPeter Collingbourne   // This function is called during the export phase to create a symbol
49259675ba0SPeter Collingbourne   // definition containing information about the given vtable slot and list of
49359675ba0SPeter Collingbourne   // arguments.
49459675ba0SPeter Collingbourne   void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
49559675ba0SPeter Collingbourne                     Constant *C);
496b15a35e6SPeter Collingbourne   void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
497b15a35e6SPeter Collingbourne                       uint32_t Const, uint32_t &Storage);
49859675ba0SPeter Collingbourne 
49959675ba0SPeter Collingbourne   // This function is called during the import phase to create a reference to
50059675ba0SPeter Collingbourne   // the symbol definition created during the export phase.
50159675ba0SPeter Collingbourne   Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
502b15a35e6SPeter Collingbourne                          StringRef Name);
503b15a35e6SPeter Collingbourne   Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
504b15a35e6SPeter Collingbourne                            StringRef Name, IntegerType *IntTy,
505b15a35e6SPeter Collingbourne                            uint32_t Storage);
50659675ba0SPeter Collingbourne 
5072974856aSPeter Collingbourne   Constant *getMemberAddr(const TypeMemberInfo *M);
5082974856aSPeter Collingbourne 
50950cbd7ccSPeter Collingbourne   void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
51050cbd7ccSPeter Collingbourne                             Constant *UniqueMemberAddr);
511df49d1bbSPeter Collingbourne   bool tryUniqueRetValOpt(unsigned BitWidth,
512f3403fd2SIvan Krasin                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
51359675ba0SPeter Collingbourne                           CallSiteInfo &CSInfo,
51459675ba0SPeter Collingbourne                           WholeProgramDevirtResolution::ByArg *Res,
51559675ba0SPeter Collingbourne                           VTableSlot Slot, ArrayRef<uint64_t> Args);
51650cbd7ccSPeter Collingbourne 
51750cbd7ccSPeter Collingbourne   void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
51850cbd7ccSPeter Collingbourne                              Constant *Byte, Constant *Bit);
519df49d1bbSPeter Collingbourne   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
52077a8d563SPeter Collingbourne                            VTableSlotInfo &SlotInfo,
52159675ba0SPeter Collingbourne                            WholeProgramDevirtResolution *Res, VTableSlot Slot);
522df49d1bbSPeter Collingbourne 
523df49d1bbSPeter Collingbourne   void rebuildGlobal(VTableBits &B);
524df49d1bbSPeter Collingbourne 
5256d284fabSPeter Collingbourne   // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
5266d284fabSPeter Collingbourne   void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);
5276d284fabSPeter Collingbourne 
5286d284fabSPeter Collingbourne   // If we were able to eliminate all unsafe uses for a type checked load,
5296d284fabSPeter Collingbourne   // eliminate the associated type tests by replacing them with true.
5306d284fabSPeter Collingbourne   void removeRedundantTypeTests();
5316d284fabSPeter Collingbourne 
532df49d1bbSPeter Collingbourne   bool run();
5332b33f653SPeter Collingbourne 
5342b33f653SPeter Collingbourne   // Lower the module using the action and summary passed as command line
5352b33f653SPeter Collingbourne   // arguments. For testing purposes only.
536e963c89dSSam Elliott   static bool runForTesting(
537e963c89dSSam Elliott       Module &M, function_ref<AAResults &(Function &)> AARGetter,
538e963c89dSSam Elliott       function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter);
539df49d1bbSPeter Collingbourne };
540df49d1bbSPeter Collingbourne 
541df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass {
542df49d1bbSPeter Collingbourne   static char ID;
543cdc71612SEugene Zelenko 
5442b33f653SPeter Collingbourne   bool UseCommandLine = false;
5452b33f653SPeter Collingbourne 
546f7691d8bSPeter Collingbourne   ModuleSummaryIndex *ExportSummary;
547f7691d8bSPeter Collingbourne   const ModuleSummaryIndex *ImportSummary;
5482b33f653SPeter Collingbourne 
5492b33f653SPeter Collingbourne   WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
5502b33f653SPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
5512b33f653SPeter Collingbourne   }
5522b33f653SPeter Collingbourne 
553f7691d8bSPeter Collingbourne   WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
554f7691d8bSPeter Collingbourne                      const ModuleSummaryIndex *ImportSummary)
555f7691d8bSPeter Collingbourne       : ModulePass(ID), ExportSummary(ExportSummary),
556f7691d8bSPeter Collingbourne         ImportSummary(ImportSummary) {
557df49d1bbSPeter Collingbourne     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
558df49d1bbSPeter Collingbourne   }
559cdc71612SEugene Zelenko 
560cdc71612SEugene Zelenko   bool runOnModule(Module &M) override {
561aa641a51SAndrew Kaylor     if (skipModule(M))
562aa641a51SAndrew Kaylor       return false;
563e963c89dSSam Elliott 
5649110cb45SPeter Collingbourne     // In the new pass manager, we can request the optimization
5659110cb45SPeter Collingbourne     // remark emitter pass on a per-function-basis, which the
5669110cb45SPeter Collingbourne     // OREGetter will do for us.
5679110cb45SPeter Collingbourne     // In the old pass manager, this is harder, so we just build
5689110cb45SPeter Collingbourne     // an optimization remark emitter on the fly, when we need it.
5699110cb45SPeter Collingbourne     std::unique_ptr<OptimizationRemarkEmitter> ORE;
5709110cb45SPeter Collingbourne     auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
5719110cb45SPeter Collingbourne       ORE = make_unique<OptimizationRemarkEmitter>(F);
5729110cb45SPeter Collingbourne       return *ORE;
5739110cb45SPeter Collingbourne     };
574e963c89dSSam Elliott 
5752b33f653SPeter Collingbourne     if (UseCommandLine)
576e963c89dSSam Elliott       return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter);
577e963c89dSSam Elliott 
578e963c89dSSam Elliott     return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary,
579e963c89dSSam Elliott                         ImportSummary)
580f7691d8bSPeter Collingbourne         .run();
58137317f12SPeter Collingbourne   }
58237317f12SPeter Collingbourne 
58337317f12SPeter Collingbourne   void getAnalysisUsage(AnalysisUsage &AU) const override {
58437317f12SPeter Collingbourne     AU.addRequired<AssumptionCacheTracker>();
58537317f12SPeter Collingbourne     AU.addRequired<TargetLibraryInfoWrapperPass>();
586aa641a51SAndrew Kaylor   }
587df49d1bbSPeter Collingbourne };
588df49d1bbSPeter Collingbourne 
589cdc71612SEugene Zelenko } // end anonymous namespace
590df49d1bbSPeter Collingbourne 
// Register the legacy pass under the command-line name "wholeprogramdevirt"
// and list the same two passes that WholeProgramDevirt::getAnalysisUsage
// marks as required.
59137317f12SPeter Collingbourne INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
59237317f12SPeter Collingbourne                       "Whole program devirtualization", false, false)
59337317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
59437317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
59537317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
596df49d1bbSPeter Collingbourne                     "Whole program devirtualization", false, false)
// Out-of-class definition of the static pass identifier declared in
// WholeProgramDevirt.
597df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0;
598df49d1bbSPeter Collingbourne 
599f7691d8bSPeter Collingbourne ModulePass *
600f7691d8bSPeter Collingbourne llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
601f7691d8bSPeter Collingbourne                                    const ModuleSummaryIndex *ImportSummary) {
602f7691d8bSPeter Collingbourne   return new WholeProgramDevirt(ExportSummary, ImportSummary);
603df49d1bbSPeter Collingbourne }
604df49d1bbSPeter Collingbourne 
605164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
60637317f12SPeter Collingbourne                                               ModuleAnalysisManager &AM) {
60737317f12SPeter Collingbourne   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
60837317f12SPeter Collingbourne   auto AARGetter = [&](Function &F) -> AAResults & {
60937317f12SPeter Collingbourne     return FAM.getResult<AAManager>(F);
61037317f12SPeter Collingbourne   };
611e963c89dSSam Elliott   auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
612e963c89dSSam Elliott     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
613e963c89dSSam Elliott   };
614e963c89dSSam Elliott   if (!DevirtModule(M, AARGetter, OREGetter, nullptr, nullptr).run())
615d737dd2eSDavide Italiano     return PreservedAnalyses::all();
616d737dd2eSDavide Italiano   return PreservedAnalyses::none();
617d737dd2eSDavide Italiano }
618d737dd2eSDavide Italiano 
61937317f12SPeter Collingbourne bool DevirtModule::runForTesting(
620e963c89dSSam Elliott     Module &M, function_ref<AAResults &(Function &)> AARGetter,
621e963c89dSSam Elliott     function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
622*4ffc3e78STeresa Johnson   ModuleSummaryIndex Summary(/*HaveGVs=*/false);
6232b33f653SPeter Collingbourne 
6242b33f653SPeter Collingbourne   // Handle the command-line summary arguments. This code is for testing
6252b33f653SPeter Collingbourne   // purposes only, so we handle errors directly.
6262b33f653SPeter Collingbourne   if (!ClReadSummary.empty()) {
6272b33f653SPeter Collingbourne     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
6282b33f653SPeter Collingbourne                           ": ");
6292b33f653SPeter Collingbourne     auto ReadSummaryFile =
6302b33f653SPeter Collingbourne         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
6312b33f653SPeter Collingbourne 
6322b33f653SPeter Collingbourne     yaml::Input In(ReadSummaryFile->getBuffer());
6332b33f653SPeter Collingbourne     In >> Summary;
6342b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(In.error()));
6352b33f653SPeter Collingbourne   }
6362b33f653SPeter Collingbourne 
637f7691d8bSPeter Collingbourne   bool Changed =
638f7691d8bSPeter Collingbourne       DevirtModule(
639e963c89dSSam Elliott           M, AARGetter, OREGetter,
640f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
641f7691d8bSPeter Collingbourne           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
642f7691d8bSPeter Collingbourne           .run();
6432b33f653SPeter Collingbourne 
6442b33f653SPeter Collingbourne   if (!ClWriteSummary.empty()) {
6452b33f653SPeter Collingbourne     ExitOnError ExitOnErr(
6462b33f653SPeter Collingbourne         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
6472b33f653SPeter Collingbourne     std::error_code EC;
6482b33f653SPeter Collingbourne     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
6492b33f653SPeter Collingbourne     ExitOnErr(errorCodeToError(EC));
6502b33f653SPeter Collingbourne 
6512b33f653SPeter Collingbourne     yaml::Output Out(OS);
6522b33f653SPeter Collingbourne     Out << Summary;
6532b33f653SPeter Collingbourne   }
6542b33f653SPeter Collingbourne 
6552b33f653SPeter Collingbourne   return Changed;
6562b33f653SPeter Collingbourne }
6572b33f653SPeter Collingbourne 
6587efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap(
659df49d1bbSPeter Collingbourne     std::vector<VTableBits> &Bits,
6607efd7506SPeter Collingbourne     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
661df49d1bbSPeter Collingbourne   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
6627efd7506SPeter Collingbourne   Bits.reserve(M.getGlobalList().size());
6637efd7506SPeter Collingbourne   SmallVector<MDNode *, 2> Types;
6647efd7506SPeter Collingbourne   for (GlobalVariable &GV : M.globals()) {
6657efd7506SPeter Collingbourne     Types.clear();
6667efd7506SPeter Collingbourne     GV.getMetadata(LLVMContext::MD_type, Types);
6677efd7506SPeter Collingbourne     if (Types.empty())
668df49d1bbSPeter Collingbourne       continue;
669df49d1bbSPeter Collingbourne 
6707efd7506SPeter Collingbourne     VTableBits *&BitsPtr = GVToBits[&GV];
6717efd7506SPeter Collingbourne     if (!BitsPtr) {
6727efd7506SPeter Collingbourne       Bits.emplace_back();
6737efd7506SPeter Collingbourne       Bits.back().GV = &GV;
6747efd7506SPeter Collingbourne       Bits.back().ObjectSize =
6757efd7506SPeter Collingbourne           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
6767efd7506SPeter Collingbourne       BitsPtr = &Bits.back();
6777efd7506SPeter Collingbourne     }
6787efd7506SPeter Collingbourne 
6797efd7506SPeter Collingbourne     for (MDNode *Type : Types) {
6807efd7506SPeter Collingbourne       auto TypeID = Type->getOperand(1).get();
681df49d1bbSPeter Collingbourne 
682df49d1bbSPeter Collingbourne       uint64_t Offset =
683df49d1bbSPeter Collingbourne           cast<ConstantInt>(
6847efd7506SPeter Collingbourne               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
685df49d1bbSPeter Collingbourne               ->getZExtValue();
686df49d1bbSPeter Collingbourne 
6877efd7506SPeter Collingbourne       TypeIdMap[TypeID].insert({BitsPtr, Offset});
688df49d1bbSPeter Collingbourne     }
689df49d1bbSPeter Collingbourne   }
690df49d1bbSPeter Collingbourne }
691df49d1bbSPeter Collingbourne 
6928786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
6938786754cSPeter Collingbourne   if (I->getType()->isPointerTy()) {
6948786754cSPeter Collingbourne     if (Offset == 0)
6958786754cSPeter Collingbourne       return I;
6968786754cSPeter Collingbourne     return nullptr;
6978786754cSPeter Collingbourne   }
6988786754cSPeter Collingbourne 
6997a1e5bbeSPeter Collingbourne   const DataLayout &DL = M.getDataLayout();
7007a1e5bbeSPeter Collingbourne 
7017a1e5bbeSPeter Collingbourne   if (auto *C = dyn_cast<ConstantStruct>(I)) {
7027a1e5bbeSPeter Collingbourne     const StructLayout *SL = DL.getStructLayout(C->getType());
7037a1e5bbeSPeter Collingbourne     if (Offset >= SL->getSizeInBytes())
7047a1e5bbeSPeter Collingbourne       return nullptr;
7057a1e5bbeSPeter Collingbourne 
7068786754cSPeter Collingbourne     unsigned Op = SL->getElementContainingOffset(Offset);
7078786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
7088786754cSPeter Collingbourne                               Offset - SL->getElementOffset(Op));
7098786754cSPeter Collingbourne   }
7108786754cSPeter Collingbourne   if (auto *C = dyn_cast<ConstantArray>(I)) {
7117a1e5bbeSPeter Collingbourne     ArrayType *VTableTy = C->getType();
7127a1e5bbeSPeter Collingbourne     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
7137a1e5bbeSPeter Collingbourne 
7148786754cSPeter Collingbourne     unsigned Op = Offset / ElemSize;
7157a1e5bbeSPeter Collingbourne     if (Op >= C->getNumOperands())
7167a1e5bbeSPeter Collingbourne       return nullptr;
7177a1e5bbeSPeter Collingbourne 
7188786754cSPeter Collingbourne     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
7198786754cSPeter Collingbourne                               Offset % ElemSize);
7208786754cSPeter Collingbourne   }
7218786754cSPeter Collingbourne   return nullptr;
7227a1e5bbeSPeter Collingbourne }
7237a1e5bbeSPeter Collingbourne 
724df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets(
725df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> &TargetsForSlot,
7267efd7506SPeter Collingbourne     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
7277efd7506SPeter Collingbourne   for (const TypeMemberInfo &TM : TypeMemberInfos) {
7287efd7506SPeter Collingbourne     if (!TM.Bits->GV->isConstant())
729df49d1bbSPeter Collingbourne       return false;
730df49d1bbSPeter Collingbourne 
7318786754cSPeter Collingbourne     Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
7328786754cSPeter Collingbourne                                        TM.Offset + ByteOffset);
7338786754cSPeter Collingbourne     if (!Ptr)
734df49d1bbSPeter Collingbourne       return false;
735df49d1bbSPeter Collingbourne 
7368786754cSPeter Collingbourne     auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
737df49d1bbSPeter Collingbourne     if (!Fn)
738df49d1bbSPeter Collingbourne       return false;
739df49d1bbSPeter Collingbourne 
740df49d1bbSPeter Collingbourne     // We can disregard __cxa_pure_virtual as a possible call target, as
741df49d1bbSPeter Collingbourne     // calls to pure virtuals are UB.
742df49d1bbSPeter Collingbourne     if (Fn->getName() == "__cxa_pure_virtual")
743df49d1bbSPeter Collingbourne       continue;
744df49d1bbSPeter Collingbourne 
7457efd7506SPeter Collingbourne     TargetsForSlot.push_back({Fn, &TM});
746df49d1bbSPeter Collingbourne   }
747df49d1bbSPeter Collingbourne 
748df49d1bbSPeter Collingbourne   // Give up if we couldn't find any targets.
749df49d1bbSPeter Collingbourne   return !TargetsForSlot.empty();
750df49d1bbSPeter Collingbourne }
751df49d1bbSPeter Collingbourne 
75250cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
7532325bb34SPeter Collingbourne                                          Constant *TheFn, bool &IsExported) {
75450cbd7ccSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
75550cbd7ccSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
756f3403fd2SIvan Krasin       if (RemarksEnabled)
757e963c89dSSam Elliott         VCallSite.emitRemark("single-impl", TheFn->getName(), OREGetter);
758df49d1bbSPeter Collingbourne       VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
759df49d1bbSPeter Collingbourne           TheFn, VCallSite.CS.getCalledValue()->getType()));
7600312f614SPeter Collingbourne       // This use is no longer unsafe.
7610312f614SPeter Collingbourne       if (VCallSite.NumUnsafeUses)
7620312f614SPeter Collingbourne         --*VCallSite.NumUnsafeUses;
763df49d1bbSPeter Collingbourne     }
7642974856aSPeter Collingbourne     if (CSInfo.isExported())
7652325bb34SPeter Collingbourne       IsExported = true;
76659675ba0SPeter Collingbourne     CSInfo.markDevirt();
76750cbd7ccSPeter Collingbourne   };
76850cbd7ccSPeter Collingbourne   Apply(SlotInfo.CSInfo);
76950cbd7ccSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
77050cbd7ccSPeter Collingbourne     Apply(P.second);
77150cbd7ccSPeter Collingbourne }
77250cbd7ccSPeter Collingbourne 
77350cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt(
77450cbd7ccSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
7752325bb34SPeter Collingbourne     VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
77650cbd7ccSPeter Collingbourne   // See if the program contains a single implementation of this virtual
77750cbd7ccSPeter Collingbourne   // function.
77850cbd7ccSPeter Collingbourne   Function *TheFn = TargetsForSlot[0].Fn;
77950cbd7ccSPeter Collingbourne   for (auto &&Target : TargetsForSlot)
78050cbd7ccSPeter Collingbourne     if (TheFn != Target.Fn)
78150cbd7ccSPeter Collingbourne       return false;
78250cbd7ccSPeter Collingbourne 
78350cbd7ccSPeter Collingbourne   // If so, update each call site to call that implementation directly.
78450cbd7ccSPeter Collingbourne   if (RemarksEnabled)
78550cbd7ccSPeter Collingbourne     TargetsForSlot[0].WasDevirt = true;
7862325bb34SPeter Collingbourne 
7872325bb34SPeter Collingbourne   bool IsExported = false;
7882325bb34SPeter Collingbourne   applySingleImplDevirt(SlotInfo, TheFn, IsExported);
7892325bb34SPeter Collingbourne   if (!IsExported)
7902325bb34SPeter Collingbourne     return false;
7912325bb34SPeter Collingbourne 
7922325bb34SPeter Collingbourne   // If the only implementation has local linkage, we must promote to external
7932325bb34SPeter Collingbourne   // to make it visible to thin LTO objects. We can only get here during the
7942325bb34SPeter Collingbourne   // ThinLTO export phase.
7952325bb34SPeter Collingbourne   if (TheFn->hasLocalLinkage()) {
79688a58cf9SPeter Collingbourne     std::string NewName = (TheFn->getName() + "$merged").str();
79788a58cf9SPeter Collingbourne 
79888a58cf9SPeter Collingbourne     // Since we are renaming the function, any comdats with the same name must
79988a58cf9SPeter Collingbourne     // also be renamed. This is required when targeting COFF, as the comdat name
80088a58cf9SPeter Collingbourne     // must match one of the names of the symbols in the comdat.
80188a58cf9SPeter Collingbourne     if (Comdat *C = TheFn->getComdat()) {
80288a58cf9SPeter Collingbourne       if (C->getName() == TheFn->getName()) {
80388a58cf9SPeter Collingbourne         Comdat *NewC = M.getOrInsertComdat(NewName);
80488a58cf9SPeter Collingbourne         NewC->setSelectionKind(C->getSelectionKind());
80588a58cf9SPeter Collingbourne         for (GlobalObject &GO : M.global_objects())
80688a58cf9SPeter Collingbourne           if (GO.getComdat() == C)
80788a58cf9SPeter Collingbourne             GO.setComdat(NewC);
80888a58cf9SPeter Collingbourne       }
80988a58cf9SPeter Collingbourne     }
81088a58cf9SPeter Collingbourne 
8112325bb34SPeter Collingbourne     TheFn->setLinkage(GlobalValue::ExternalLinkage);
8122325bb34SPeter Collingbourne     TheFn->setVisibility(GlobalValue::HiddenVisibility);
81388a58cf9SPeter Collingbourne     TheFn->setName(NewName);
8142325bb34SPeter Collingbourne   }
8152325bb34SPeter Collingbourne 
8162325bb34SPeter Collingbourne   Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
8172325bb34SPeter Collingbourne   Res->SingleImplName = TheFn->getName();
8182325bb34SPeter Collingbourne 
819df49d1bbSPeter Collingbourne   return true;
820df49d1bbSPeter Collingbourne }
821df49d1bbSPeter Collingbourne 
8222974856aSPeter Collingbourne void DevirtModule::tryICallBranchFunnel(
8232974856aSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
8242974856aSPeter Collingbourne     WholeProgramDevirtResolution *Res, VTableSlot Slot) {
8252974856aSPeter Collingbourne   Triple T(M.getTargetTriple());
8262974856aSPeter Collingbourne   if (T.getArch() != Triple::x86_64)
8272974856aSPeter Collingbourne     return;
8282974856aSPeter Collingbourne 
82966f53d71SVitaly Buka   if (TargetsForSlot.size() > ClThreshold)
8302974856aSPeter Collingbourne     return;
8312974856aSPeter Collingbourne 
8322974856aSPeter Collingbourne   bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
8332974856aSPeter Collingbourne   if (!HasNonDevirt)
8342974856aSPeter Collingbourne     for (auto &P : SlotInfo.ConstCSInfo)
8352974856aSPeter Collingbourne       if (!P.second.AllCallSitesDevirted) {
8362974856aSPeter Collingbourne         HasNonDevirt = true;
8372974856aSPeter Collingbourne         break;
8382974856aSPeter Collingbourne       }
8392974856aSPeter Collingbourne 
8402974856aSPeter Collingbourne   if (!HasNonDevirt)
8412974856aSPeter Collingbourne     return;
8422974856aSPeter Collingbourne 
8432974856aSPeter Collingbourne   FunctionType *FT =
8442974856aSPeter Collingbourne       FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
8452974856aSPeter Collingbourne   Function *JT;
8462974856aSPeter Collingbourne   if (isa<MDString>(Slot.TypeID)) {
8472974856aSPeter Collingbourne     JT = Function::Create(FT, Function::ExternalLinkage,
8482974856aSPeter Collingbourne                           getGlobalName(Slot, {}, "branch_funnel"), &M);
8492974856aSPeter Collingbourne     JT->setVisibility(GlobalValue::HiddenVisibility);
8502974856aSPeter Collingbourne   } else {
8512974856aSPeter Collingbourne     JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M);
8522974856aSPeter Collingbourne   }
8532974856aSPeter Collingbourne   JT->addAttribute(1, Attribute::Nest);
8542974856aSPeter Collingbourne 
8552974856aSPeter Collingbourne   std::vector<Value *> JTArgs;
8562974856aSPeter Collingbourne   JTArgs.push_back(JT->arg_begin());
8572974856aSPeter Collingbourne   for (auto &T : TargetsForSlot) {
8582974856aSPeter Collingbourne     JTArgs.push_back(getMemberAddr(T.TM));
8592974856aSPeter Collingbourne     JTArgs.push_back(T.Fn);
8602974856aSPeter Collingbourne   }
8612974856aSPeter Collingbourne 
8622974856aSPeter Collingbourne   BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
8632974856aSPeter Collingbourne   Constant *Intr =
8642974856aSPeter Collingbourne       Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
8652974856aSPeter Collingbourne 
8662974856aSPeter Collingbourne   auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
8672974856aSPeter Collingbourne   CI->setTailCallKind(CallInst::TCK_MustTail);
8682974856aSPeter Collingbourne   ReturnInst::Create(M.getContext(), nullptr, BB);
8692974856aSPeter Collingbourne 
8702974856aSPeter Collingbourne   bool IsExported = false;
8712974856aSPeter Collingbourne   applyICallBranchFunnel(SlotInfo, JT, IsExported);
8722974856aSPeter Collingbourne   if (IsExported)
8732974856aSPeter Collingbourne     Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
8742974856aSPeter Collingbourne }
8752974856aSPeter Collingbourne 
8762974856aSPeter Collingbourne void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
8772974856aSPeter Collingbourne                                           Constant *JT, bool &IsExported) {
8782974856aSPeter Collingbourne   auto Apply = [&](CallSiteInfo &CSInfo) {
8792974856aSPeter Collingbourne     if (CSInfo.isExported())
8802974856aSPeter Collingbourne       IsExported = true;
8812974856aSPeter Collingbourne     if (CSInfo.AllCallSitesDevirted)
8822974856aSPeter Collingbourne       return;
8832974856aSPeter Collingbourne     for (auto &&VCallSite : CSInfo.CallSites) {
8842974856aSPeter Collingbourne       CallSite CS = VCallSite.CS;
8852974856aSPeter Collingbourne 
8862974856aSPeter Collingbourne       // Jump tables are only profitable if the retpoline mitigation is enabled.
8872974856aSPeter Collingbourne       Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
8882974856aSPeter Collingbourne       if (FSAttr.hasAttribute(Attribute::None) ||
8892974856aSPeter Collingbourne           !FSAttr.getValueAsString().contains("+retpoline"))
8902974856aSPeter Collingbourne         continue;
8912974856aSPeter Collingbourne 
8922974856aSPeter Collingbourne       if (RemarksEnabled)
8932974856aSPeter Collingbourne         VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter);
8942974856aSPeter Collingbourne 
8952974856aSPeter Collingbourne       // Pass the address of the vtable in the nest register, which is r10 on
8962974856aSPeter Collingbourne       // x86_64.
8972974856aSPeter Collingbourne       std::vector<Type *> NewArgs;
8982974856aSPeter Collingbourne       NewArgs.push_back(Int8PtrTy);
8992974856aSPeter Collingbourne       for (Type *T : CS.getFunctionType()->params())
9002974856aSPeter Collingbourne         NewArgs.push_back(T);
9012974856aSPeter Collingbourne       PointerType *NewFT = PointerType::getUnqual(
9022974856aSPeter Collingbourne           FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
9032974856aSPeter Collingbourne                             CS.getFunctionType()->isVarArg()));
9042974856aSPeter Collingbourne 
9052974856aSPeter Collingbourne       IRBuilder<> IRB(CS.getInstruction());
9062974856aSPeter Collingbourne       std::vector<Value *> Args;
9072974856aSPeter Collingbourne       Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
9082974856aSPeter Collingbourne       for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
9092974856aSPeter Collingbourne         Args.push_back(CS.getArgOperand(I));
9102974856aSPeter Collingbourne 
9112974856aSPeter Collingbourne       CallSite NewCS;
9122974856aSPeter Collingbourne       if (CS.isCall())
9132974856aSPeter Collingbourne         NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args);
9142974856aSPeter Collingbourne       else
9152974856aSPeter Collingbourne         NewCS = IRB.CreateInvoke(
9162974856aSPeter Collingbourne             IRB.CreateBitCast(JT, NewFT),
9172974856aSPeter Collingbourne             cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
9182974856aSPeter Collingbourne             cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
9192974856aSPeter Collingbourne       NewCS.setCallingConv(CS.getCallingConv());
9202974856aSPeter Collingbourne 
9212974856aSPeter Collingbourne       AttributeList Attrs = CS.getAttributes();
9222974856aSPeter Collingbourne       std::vector<AttributeSet> NewArgAttrs;
9232974856aSPeter Collingbourne       NewArgAttrs.push_back(AttributeSet::get(
9242974856aSPeter Collingbourne           M.getContext(), ArrayRef<Attribute>{Attribute::get(
9252974856aSPeter Collingbourne                               M.getContext(), Attribute::Nest)}));
9262974856aSPeter Collingbourne       for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
9272974856aSPeter Collingbourne         NewArgAttrs.push_back(Attrs.getParamAttributes(I));
9282974856aSPeter Collingbourne       NewCS.setAttributes(
9292974856aSPeter Collingbourne           AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
9302974856aSPeter Collingbourne                              Attrs.getRetAttributes(), NewArgAttrs));
9312974856aSPeter Collingbourne 
9322974856aSPeter Collingbourne       CS->replaceAllUsesWith(NewCS.getInstruction());
9332974856aSPeter Collingbourne       CS->eraseFromParent();
9342974856aSPeter Collingbourne 
9352974856aSPeter Collingbourne       // This use is no longer unsafe.
9362974856aSPeter Collingbourne       if (VCallSite.NumUnsafeUses)
9372974856aSPeter Collingbourne         --*VCallSite.NumUnsafeUses;
9382974856aSPeter Collingbourne     }
9392974856aSPeter Collingbourne     // Don't mark as devirtualized because there may be callers compiled without
9402974856aSPeter Collingbourne     // retpoline mitigation, which would mean that they are lowered to
9412974856aSPeter Collingbourne     // llvm.type.test and therefore require an llvm.type.test resolution for the
9422974856aSPeter Collingbourne     // type identifier.
9432974856aSPeter Collingbourne   };
9442974856aSPeter Collingbourne   Apply(SlotInfo.CSInfo);
9452974856aSPeter Collingbourne   for (auto &P : SlotInfo.ConstCSInfo)
9462974856aSPeter Collingbourne     Apply(P.second);
9472974856aSPeter Collingbourne }
9482974856aSPeter Collingbourne 
949df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs(
950df49d1bbSPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
95150cbd7ccSPeter Collingbourne     ArrayRef<uint64_t> Args) {
952df49d1bbSPeter Collingbourne   // Evaluate each function and store the result in each target's RetVal
953df49d1bbSPeter Collingbourne   // field.
954df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
955df49d1bbSPeter Collingbourne     if (Target.Fn->arg_size() != Args.size() + 1)
956df49d1bbSPeter Collingbourne       return false;
957df49d1bbSPeter Collingbourne 
958df49d1bbSPeter Collingbourne     Evaluator Eval(M.getDataLayout(), nullptr);
959df49d1bbSPeter Collingbourne     SmallVector<Constant *, 2> EvalArgs;
960df49d1bbSPeter Collingbourne     EvalArgs.push_back(
961df49d1bbSPeter Collingbourne         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
96250cbd7ccSPeter Collingbourne     for (unsigned I = 0; I != Args.size(); ++I) {
96350cbd7ccSPeter Collingbourne       auto *ArgTy = dyn_cast<IntegerType>(
96450cbd7ccSPeter Collingbourne           Target.Fn->getFunctionType()->getParamType(I + 1));
96550cbd7ccSPeter Collingbourne       if (!ArgTy)
96650cbd7ccSPeter Collingbourne         return false;
96750cbd7ccSPeter Collingbourne       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
96850cbd7ccSPeter Collingbourne     }
96950cbd7ccSPeter Collingbourne 
970df49d1bbSPeter Collingbourne     Constant *RetVal;
971df49d1bbSPeter Collingbourne     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
972df49d1bbSPeter Collingbourne         !isa<ConstantInt>(RetVal))
973df49d1bbSPeter Collingbourne       return false;
974df49d1bbSPeter Collingbourne     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
975df49d1bbSPeter Collingbourne   }
976df49d1bbSPeter Collingbourne   return true;
977df49d1bbSPeter Collingbourne }
978df49d1bbSPeter Collingbourne 
97950cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
98050cbd7ccSPeter Collingbourne                                          uint64_t TheRetVal) {
98150cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites)
98250cbd7ccSPeter Collingbourne     Call.replaceAndErase(
983e963c89dSSam Elliott         "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
98450cbd7ccSPeter Collingbourne         ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
98559675ba0SPeter Collingbourne   CSInfo.markDevirt();
98650cbd7ccSPeter Collingbourne }
98750cbd7ccSPeter Collingbourne 
988df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt(
98977a8d563SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
99077a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *Res) {
991df49d1bbSPeter Collingbourne   // Uniform return value optimization. If all functions return the same
992df49d1bbSPeter Collingbourne   // constant, replace all calls with that constant.
993df49d1bbSPeter Collingbourne   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
994df49d1bbSPeter Collingbourne   for (const VirtualCallTarget &Target : TargetsForSlot)
995df49d1bbSPeter Collingbourne     if (Target.RetVal != TheRetVal)
996df49d1bbSPeter Collingbourne       return false;
997df49d1bbSPeter Collingbourne 
99877a8d563SPeter Collingbourne   if (CSInfo.isExported()) {
99977a8d563SPeter Collingbourne     Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
100077a8d563SPeter Collingbourne     Res->Info = TheRetVal;
100177a8d563SPeter Collingbourne   }
100277a8d563SPeter Collingbourne 
100350cbd7ccSPeter Collingbourne   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
1004f3403fd2SIvan Krasin   if (RemarksEnabled)
1005f3403fd2SIvan Krasin     for (auto &&Target : TargetsForSlot)
1006f3403fd2SIvan Krasin       Target.WasDevirt = true;
1007df49d1bbSPeter Collingbourne   return true;
1008df49d1bbSPeter Collingbourne }
1009df49d1bbSPeter Collingbourne 
101059675ba0SPeter Collingbourne std::string DevirtModule::getGlobalName(VTableSlot Slot,
101159675ba0SPeter Collingbourne                                         ArrayRef<uint64_t> Args,
101259675ba0SPeter Collingbourne                                         StringRef Name) {
101359675ba0SPeter Collingbourne   std::string FullName = "__typeid_";
101459675ba0SPeter Collingbourne   raw_string_ostream OS(FullName);
101559675ba0SPeter Collingbourne   OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
101659675ba0SPeter Collingbourne   for (uint64_t Arg : Args)
101759675ba0SPeter Collingbourne     OS << '_' << Arg;
101859675ba0SPeter Collingbourne   OS << '_' << Name;
101959675ba0SPeter Collingbourne   return OS.str();
102059675ba0SPeter Collingbourne }
102159675ba0SPeter Collingbourne 
1022b15a35e6SPeter Collingbourne bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
1023b15a35e6SPeter Collingbourne   Triple T(M.getTargetTriple());
1024b15a35e6SPeter Collingbourne   return (T.getArch() == Triple::x86 || T.getArch() == Triple::x86_64) &&
1025b15a35e6SPeter Collingbourne          T.getObjectFormat() == Triple::ELF;
1026b15a35e6SPeter Collingbourne }
1027b15a35e6SPeter Collingbourne 
102859675ba0SPeter Collingbourne void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
102959675ba0SPeter Collingbourne                                 StringRef Name, Constant *C) {
103059675ba0SPeter Collingbourne   GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
103159675ba0SPeter Collingbourne                                         getGlobalName(Slot, Args, Name), C, &M);
103259675ba0SPeter Collingbourne   GA->setVisibility(GlobalValue::HiddenVisibility);
103359675ba0SPeter Collingbourne }
103459675ba0SPeter Collingbourne 
1035b15a35e6SPeter Collingbourne void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1036b15a35e6SPeter Collingbourne                                   StringRef Name, uint32_t Const,
1037b15a35e6SPeter Collingbourne                                   uint32_t &Storage) {
1038b15a35e6SPeter Collingbourne   if (shouldExportConstantsAsAbsoluteSymbols()) {
1039b15a35e6SPeter Collingbourne     exportGlobal(
1040b15a35e6SPeter Collingbourne         Slot, Args, Name,
1041b15a35e6SPeter Collingbourne         ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
1042b15a35e6SPeter Collingbourne     return;
1043b15a35e6SPeter Collingbourne   }
1044b15a35e6SPeter Collingbourne 
1045b15a35e6SPeter Collingbourne   Storage = Const;
1046b15a35e6SPeter Collingbourne }
1047b15a35e6SPeter Collingbourne 
104859675ba0SPeter Collingbourne Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1049b15a35e6SPeter Collingbourne                                      StringRef Name) {
105059675ba0SPeter Collingbourne   Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
105159675ba0SPeter Collingbourne   auto *GV = dyn_cast<GlobalVariable>(C);
1052b15a35e6SPeter Collingbourne   if (GV)
1053b15a35e6SPeter Collingbourne     GV->setVisibility(GlobalValue::HiddenVisibility);
1054b15a35e6SPeter Collingbourne   return C;
1055b15a35e6SPeter Collingbourne }
1056b15a35e6SPeter Collingbourne 
1057b15a35e6SPeter Collingbourne Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1058b15a35e6SPeter Collingbourne                                        StringRef Name, IntegerType *IntTy,
1059b15a35e6SPeter Collingbourne                                        uint32_t Storage) {
1060b15a35e6SPeter Collingbourne   if (!shouldExportConstantsAsAbsoluteSymbols())
1061b15a35e6SPeter Collingbourne     return ConstantInt::get(IntTy, Storage);
1062b15a35e6SPeter Collingbourne 
1063b15a35e6SPeter Collingbourne   Constant *C = importGlobal(Slot, Args, Name);
1064b15a35e6SPeter Collingbourne   auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
1065b15a35e6SPeter Collingbourne   C = ConstantExpr::getPtrToInt(C, IntTy);
1066b15a35e6SPeter Collingbourne 
106714dcf02fSPeter Collingbourne   // We only need to set metadata if the global is newly created, in which
106814dcf02fSPeter Collingbourne   // case it would not have hidden visibility.
10690deb9a9aSBenjamin Kramer   if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
107059675ba0SPeter Collingbourne     return C;
107114dcf02fSPeter Collingbourne 
107214dcf02fSPeter Collingbourne   auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
107314dcf02fSPeter Collingbourne     auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
107414dcf02fSPeter Collingbourne     auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
107514dcf02fSPeter Collingbourne     GV->setMetadata(LLVMContext::MD_absolute_symbol,
107614dcf02fSPeter Collingbourne                     MDNode::get(M.getContext(), {MinC, MaxC}));
107714dcf02fSPeter Collingbourne   };
1078b15a35e6SPeter Collingbourne   unsigned AbsWidth = IntTy->getBitWidth();
107914dcf02fSPeter Collingbourne   if (AbsWidth == IntPtrTy->getBitWidth())
108014dcf02fSPeter Collingbourne     SetAbsRange(~0ull, ~0ull); // Full set.
1081b15a35e6SPeter Collingbourne   else
108214dcf02fSPeter Collingbourne     SetAbsRange(0, 1ull << AbsWidth);
1083b15a35e6SPeter Collingbourne   return C;
108459675ba0SPeter Collingbourne }
108559675ba0SPeter Collingbourne 
108650cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
108750cbd7ccSPeter Collingbourne                                         bool IsOne,
108850cbd7ccSPeter Collingbourne                                         Constant *UniqueMemberAddr) {
108950cbd7ccSPeter Collingbourne   for (auto &&Call : CSInfo.CallSites) {
109050cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
1091001052a0SPeter Collingbourne     Value *Cmp =
1092001052a0SPeter Collingbourne         B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
1093001052a0SPeter Collingbourne                      B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr);
109450cbd7ccSPeter Collingbourne     Cmp = B.CreateZExt(Cmp, Call.CS->getType());
1095e963c89dSSam Elliott     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
1096e963c89dSSam Elliott                          Cmp);
109750cbd7ccSPeter Collingbourne   }
109859675ba0SPeter Collingbourne   CSInfo.markDevirt();
109950cbd7ccSPeter Collingbourne }
110050cbd7ccSPeter Collingbourne 
11012974856aSPeter Collingbourne Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
11022974856aSPeter Collingbourne   Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
11032974856aSPeter Collingbourne   return ConstantExpr::getGetElementPtr(Int8Ty, C,
11042974856aSPeter Collingbourne                                         ConstantInt::get(Int64Ty, M->Offset));
11052974856aSPeter Collingbourne }
11062974856aSPeter Collingbourne 
1107df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt(
1108f3403fd2SIvan Krasin     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
110959675ba0SPeter Collingbourne     CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
111059675ba0SPeter Collingbourne     VTableSlot Slot, ArrayRef<uint64_t> Args) {
1111df49d1bbSPeter Collingbourne   // IsOne controls whether we look for a 0 or a 1.
1112df49d1bbSPeter Collingbourne   auto tryUniqueRetValOptFor = [&](bool IsOne) {
1113cdc71612SEugene Zelenko     const TypeMemberInfo *UniqueMember = nullptr;
1114df49d1bbSPeter Collingbourne     for (const VirtualCallTarget &Target : TargetsForSlot) {
11153866cc5fSPeter Collingbourne       if (Target.RetVal == (IsOne ? 1 : 0)) {
11167efd7506SPeter Collingbourne         if (UniqueMember)
1117df49d1bbSPeter Collingbourne           return false;
11187efd7506SPeter Collingbourne         UniqueMember = Target.TM;
1119df49d1bbSPeter Collingbourne       }
1120df49d1bbSPeter Collingbourne     }
1121df49d1bbSPeter Collingbourne 
11227efd7506SPeter Collingbourne     // We should have found a unique member or bailed out by now. We already
1123df49d1bbSPeter Collingbourne     // checked for a uniform return value in tryUniformRetValOpt.
11247efd7506SPeter Collingbourne     assert(UniqueMember);
1125df49d1bbSPeter Collingbourne 
11262974856aSPeter Collingbourne     Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
112759675ba0SPeter Collingbourne     if (CSInfo.isExported()) {
112859675ba0SPeter Collingbourne       Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
112959675ba0SPeter Collingbourne       Res->Info = IsOne;
113059675ba0SPeter Collingbourne 
113159675ba0SPeter Collingbourne       exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
113259675ba0SPeter Collingbourne     }
113359675ba0SPeter Collingbourne 
113459675ba0SPeter Collingbourne     // Replace each call with the comparison.
113550cbd7ccSPeter Collingbourne     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
113650cbd7ccSPeter Collingbourne                          UniqueMemberAddr);
113750cbd7ccSPeter Collingbourne 
1138f3403fd2SIvan Krasin     // Update devirtualization statistics for targets.
1139f3403fd2SIvan Krasin     if (RemarksEnabled)
1140f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
1141f3403fd2SIvan Krasin         Target.WasDevirt = true;
1142f3403fd2SIvan Krasin 
1143df49d1bbSPeter Collingbourne     return true;
1144df49d1bbSPeter Collingbourne   };
1145df49d1bbSPeter Collingbourne 
1146df49d1bbSPeter Collingbourne   if (BitWidth == 1) {
1147df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(true))
1148df49d1bbSPeter Collingbourne       return true;
1149df49d1bbSPeter Collingbourne     if (tryUniqueRetValOptFor(false))
1150df49d1bbSPeter Collingbourne       return true;
1151df49d1bbSPeter Collingbourne   }
1152df49d1bbSPeter Collingbourne   return false;
1153df49d1bbSPeter Collingbourne }
1154df49d1bbSPeter Collingbourne 
115550cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
115650cbd7ccSPeter Collingbourne                                          Constant *Byte, Constant *Bit) {
115750cbd7ccSPeter Collingbourne   for (auto Call : CSInfo.CallSites) {
115850cbd7ccSPeter Collingbourne     auto *RetType = cast<IntegerType>(Call.CS.getType());
115950cbd7ccSPeter Collingbourne     IRBuilder<> B(Call.CS.getInstruction());
1160001052a0SPeter Collingbourne     Value *Addr =
1161001052a0SPeter Collingbourne         B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
116250cbd7ccSPeter Collingbourne     if (RetType->getBitWidth() == 1) {
116350cbd7ccSPeter Collingbourne       Value *Bits = B.CreateLoad(Addr);
116450cbd7ccSPeter Collingbourne       Value *BitsAndBit = B.CreateAnd(Bits, Bit);
116550cbd7ccSPeter Collingbourne       auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
116650cbd7ccSPeter Collingbourne       Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
1167e963c89dSSam Elliott                            OREGetter, IsBitSet);
116850cbd7ccSPeter Collingbourne     } else {
116950cbd7ccSPeter Collingbourne       Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
117050cbd7ccSPeter Collingbourne       Value *Val = B.CreateLoad(RetType, ValAddr);
1171e963c89dSSam Elliott       Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
1172e963c89dSSam Elliott                            OREGetter, Val);
117350cbd7ccSPeter Collingbourne     }
117450cbd7ccSPeter Collingbourne   }
117514dcf02fSPeter Collingbourne   CSInfo.markDevirt();
117650cbd7ccSPeter Collingbourne }
117750cbd7ccSPeter Collingbourne 
1178df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp(
117959675ba0SPeter Collingbourne     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
118059675ba0SPeter Collingbourne     WholeProgramDevirtResolution *Res, VTableSlot Slot) {
1181df49d1bbSPeter Collingbourne   // This only works if the function returns an integer.
1182df49d1bbSPeter Collingbourne   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
1183df49d1bbSPeter Collingbourne   if (!RetType)
1184df49d1bbSPeter Collingbourne     return false;
1185df49d1bbSPeter Collingbourne   unsigned BitWidth = RetType->getBitWidth();
1186df49d1bbSPeter Collingbourne   if (BitWidth > 64)
1187df49d1bbSPeter Collingbourne     return false;
1188df49d1bbSPeter Collingbourne 
118917febdbbSPeter Collingbourne   // Make sure that each function is defined, does not access memory, takes at
119017febdbbSPeter Collingbourne   // least one argument, does not use its first argument (which we assume is
119117febdbbSPeter Collingbourne   // 'this'), and has the same return type.
119237317f12SPeter Collingbourne   //
119337317f12SPeter Collingbourne   // Note that we test whether this copy of the function is readnone, rather
119437317f12SPeter Collingbourne   // than testing function attributes, which must hold for any copy of the
119537317f12SPeter Collingbourne   // function, even a less optimized version substituted at link time. This is
119637317f12SPeter Collingbourne   // sound because the virtual constant propagation optimizations effectively
119737317f12SPeter Collingbourne   // inline all implementations of the virtual function into each call site,
119837317f12SPeter Collingbourne   // rather than using function attributes to perform local optimization.
1199df49d1bbSPeter Collingbourne   for (VirtualCallTarget &Target : TargetsForSlot) {
120037317f12SPeter Collingbourne     if (Target.Fn->isDeclaration() ||
120137317f12SPeter Collingbourne         computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
120237317f12SPeter Collingbourne             MAK_ReadNone ||
120317febdbbSPeter Collingbourne         Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
1204df49d1bbSPeter Collingbourne         Target.Fn->getReturnType() != RetType)
1205df49d1bbSPeter Collingbourne       return false;
1206df49d1bbSPeter Collingbourne   }
1207df49d1bbSPeter Collingbourne 
120850cbd7ccSPeter Collingbourne   for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
1209df49d1bbSPeter Collingbourne     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
1210df49d1bbSPeter Collingbourne       continue;
1211df49d1bbSPeter Collingbourne 
121277a8d563SPeter Collingbourne     WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
121377a8d563SPeter Collingbourne     if (Res)
121477a8d563SPeter Collingbourne       ResByArg = &Res->ResByArg[CSByConstantArg.first];
121577a8d563SPeter Collingbourne 
121677a8d563SPeter Collingbourne     if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
1217df49d1bbSPeter Collingbourne       continue;
1218df49d1bbSPeter Collingbourne 
121959675ba0SPeter Collingbourne     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
122059675ba0SPeter Collingbourne                            ResByArg, Slot, CSByConstantArg.first))
1221df49d1bbSPeter Collingbourne       continue;
1222df49d1bbSPeter Collingbourne 
12237efd7506SPeter Collingbourne     // Find an allocation offset in bits in all vtables associated with the
12247efd7506SPeter Collingbourne     // type.
1225df49d1bbSPeter Collingbourne     uint64_t AllocBefore =
1226df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
1227df49d1bbSPeter Collingbourne     uint64_t AllocAfter =
1228df49d1bbSPeter Collingbourne         findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
1229df49d1bbSPeter Collingbourne 
1230df49d1bbSPeter Collingbourne     // Calculate the total amount of padding needed to store a value at both
1231df49d1bbSPeter Collingbourne     // ends of the object.
1232df49d1bbSPeter Collingbourne     uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
1233df49d1bbSPeter Collingbourne     for (auto &&Target : TargetsForSlot) {
1234df49d1bbSPeter Collingbourne       TotalPaddingBefore += std::max<int64_t>(
1235df49d1bbSPeter Collingbourne           (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
1236df49d1bbSPeter Collingbourne       TotalPaddingAfter += std::max<int64_t>(
1237df49d1bbSPeter Collingbourne           (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
1238df49d1bbSPeter Collingbourne     }
1239df49d1bbSPeter Collingbourne 
1240df49d1bbSPeter Collingbourne     // If the amount of padding is too large, give up.
1241df49d1bbSPeter Collingbourne     // FIXME: do something smarter here.
1242df49d1bbSPeter Collingbourne     if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
1243df49d1bbSPeter Collingbourne       continue;
1244df49d1bbSPeter Collingbourne 
1245df49d1bbSPeter Collingbourne     // Calculate the offset to the value as a (possibly negative) byte offset
1246df49d1bbSPeter Collingbourne     // and (if applicable) a bit offset, and store the values in the targets.
1247df49d1bbSPeter Collingbourne     int64_t OffsetByte;
1248df49d1bbSPeter Collingbourne     uint64_t OffsetBit;
1249df49d1bbSPeter Collingbourne     if (TotalPaddingBefore <= TotalPaddingAfter)
1250df49d1bbSPeter Collingbourne       setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
1251df49d1bbSPeter Collingbourne                             OffsetBit);
1252df49d1bbSPeter Collingbourne     else
1253df49d1bbSPeter Collingbourne       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
1254df49d1bbSPeter Collingbourne                            OffsetBit);
1255df49d1bbSPeter Collingbourne 
1256f3403fd2SIvan Krasin     if (RemarksEnabled)
1257f3403fd2SIvan Krasin       for (auto &&Target : TargetsForSlot)
1258f3403fd2SIvan Krasin         Target.WasDevirt = true;
1259f3403fd2SIvan Krasin 
126014dcf02fSPeter Collingbourne 
126114dcf02fSPeter Collingbourne     if (CSByConstantArg.second.isExported()) {
126214dcf02fSPeter Collingbourne       ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
1263b15a35e6SPeter Collingbourne       exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
1264b15a35e6SPeter Collingbourne                      ResByArg->Byte);
1265b15a35e6SPeter Collingbourne       exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
1266b15a35e6SPeter Collingbourne                      ResByArg->Bit);
126714dcf02fSPeter Collingbourne     }
126814dcf02fSPeter Collingbourne 
126914dcf02fSPeter Collingbourne     // Rewrite each call to a load from OffsetByte/OffsetBit.
1270b15a35e6SPeter Collingbourne     Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
1271b15a35e6SPeter Collingbourne     Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
127250cbd7ccSPeter Collingbourne     applyVirtualConstProp(CSByConstantArg.second,
127350cbd7ccSPeter Collingbourne                           TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
1274df49d1bbSPeter Collingbourne   }
1275df49d1bbSPeter Collingbourne   return true;
1276df49d1bbSPeter Collingbourne }
1277df49d1bbSPeter Collingbourne 
1278df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) {
1279df49d1bbSPeter Collingbourne   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
1280df49d1bbSPeter Collingbourne     return;
1281df49d1bbSPeter Collingbourne 
1282df49d1bbSPeter Collingbourne   // Align each byte array to pointer width.
1283df49d1bbSPeter Collingbourne   unsigned PointerSize = M.getDataLayout().getPointerSize();
1284df49d1bbSPeter Collingbourne   B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
1285df49d1bbSPeter Collingbourne   B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));
1286df49d1bbSPeter Collingbourne 
1287df49d1bbSPeter Collingbourne   // Before was stored in reverse order; flip it now.
1288df49d1bbSPeter Collingbourne   for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
1289df49d1bbSPeter Collingbourne     std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
1290df49d1bbSPeter Collingbourne 
1291df49d1bbSPeter Collingbourne   // Build an anonymous global containing the before bytes, followed by the
1292df49d1bbSPeter Collingbourne   // original initializer, followed by the after bytes.
1293df49d1bbSPeter Collingbourne   auto NewInit = ConstantStruct::getAnon(
1294df49d1bbSPeter Collingbourne       {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
1295df49d1bbSPeter Collingbourne        B.GV->getInitializer(),
1296df49d1bbSPeter Collingbourne        ConstantDataArray::get(M.getContext(), B.After.Bytes)});
1297df49d1bbSPeter Collingbourne   auto NewGV =
1298df49d1bbSPeter Collingbourne       new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
1299df49d1bbSPeter Collingbourne                          GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
1300df49d1bbSPeter Collingbourne   NewGV->setSection(B.GV->getSection());
1301df49d1bbSPeter Collingbourne   NewGV->setComdat(B.GV->getComdat());
1302df49d1bbSPeter Collingbourne 
13030312f614SPeter Collingbourne   // Copy the original vtable's metadata to the anonymous global, adjusting
13040312f614SPeter Collingbourne   // offsets as required.
13050312f614SPeter Collingbourne   NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
13060312f614SPeter Collingbourne 
1307df49d1bbSPeter Collingbourne   // Build an alias named after the original global, pointing at the second
1308df49d1bbSPeter Collingbourne   // element (the original initializer).
1309df49d1bbSPeter Collingbourne   auto Alias = GlobalAlias::create(
1310df49d1bbSPeter Collingbourne       B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
1311df49d1bbSPeter Collingbourne       ConstantExpr::getGetElementPtr(
1312df49d1bbSPeter Collingbourne           NewInit->getType(), NewGV,
1313df49d1bbSPeter Collingbourne           ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
1314df49d1bbSPeter Collingbourne                                ConstantInt::get(Int32Ty, 1)}),
1315df49d1bbSPeter Collingbourne       &M);
1316df49d1bbSPeter Collingbourne   Alias->setVisibility(B.GV->getVisibility());
1317df49d1bbSPeter Collingbourne   Alias->takeName(B.GV);
1318df49d1bbSPeter Collingbourne 
1319df49d1bbSPeter Collingbourne   B.GV->replaceAllUsesWith(Alias);
1320df49d1bbSPeter Collingbourne   B.GV->eraseFromParent();
1321df49d1bbSPeter Collingbourne }
1322df49d1bbSPeter Collingbourne 
1323f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() {
1324f3403fd2SIvan Krasin   const auto &FL = M.getFunctionList();
1325f3403fd2SIvan Krasin   if (FL.empty())
1326f3403fd2SIvan Krasin     return false;
1327f3403fd2SIvan Krasin   const Function &Fn = FL.front();
1328de53bfb9SAdam Nemet 
1329de53bfb9SAdam Nemet   const auto &BBL = Fn.getBasicBlockList();
1330de53bfb9SAdam Nemet   if (BBL.empty())
1331de53bfb9SAdam Nemet     return false;
1332de53bfb9SAdam Nemet   auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
1333f3403fd2SIvan Krasin   return DI.isEnabled();
1334f3403fd2SIvan Krasin }
1335f3403fd2SIvan Krasin 
13360312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
13370312f614SPeter Collingbourne                                      Function *AssumeFunc) {
1338df49d1bbSPeter Collingbourne   // Find all virtual calls via a virtual table pointer %p under an assumption
13397efd7506SPeter Collingbourne   // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
13407efd7506SPeter Collingbourne   // points to a member of the type identifier %md. Group calls by (type ID,
13417efd7506SPeter Collingbourne   // offset) pair (effectively the identity of the virtual function) and store
13427efd7506SPeter Collingbourne   // to CallSlots.
1343df49d1bbSPeter Collingbourne   DenseSet<Value *> SeenPtrs;
13447efd7506SPeter Collingbourne   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
1345df49d1bbSPeter Collingbourne        I != E;) {
1346df49d1bbSPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
1347df49d1bbSPeter Collingbourne     ++I;
1348df49d1bbSPeter Collingbourne     if (!CI)
1349df49d1bbSPeter Collingbourne       continue;
1350df49d1bbSPeter Collingbourne 
1351ccdc225cSPeter Collingbourne     // Search for virtual calls based on %p and add them to DevirtCalls.
1352ccdc225cSPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
1353df49d1bbSPeter Collingbourne     SmallVector<CallInst *, 1> Assumes;
13540312f614SPeter Collingbourne     findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
1355df49d1bbSPeter Collingbourne 
1356ccdc225cSPeter Collingbourne     // If we found any, add them to CallSlots. Only do this if we haven't seen
1357ccdc225cSPeter Collingbourne     // the vtable pointer before, as it may have been CSE'd with pointers from
1358ccdc225cSPeter Collingbourne     // other call sites, and we don't want to process call sites multiple times.
1359df49d1bbSPeter Collingbourne     if (!Assumes.empty()) {
13607efd7506SPeter Collingbourne       Metadata *TypeId =
1361df49d1bbSPeter Collingbourne           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
1362df49d1bbSPeter Collingbourne       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
1363ccdc225cSPeter Collingbourne       if (SeenPtrs.insert(Ptr).second) {
1364ccdc225cSPeter Collingbourne         for (DevirtCallSite Call : DevirtCalls) {
1365001052a0SPeter Collingbourne           CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
1366ccdc225cSPeter Collingbourne         }
1367ccdc225cSPeter Collingbourne       }
1368df49d1bbSPeter Collingbourne     }
1369df49d1bbSPeter Collingbourne 
13707efd7506SPeter Collingbourne     // We no longer need the assumes or the type test.
1371df49d1bbSPeter Collingbourne     for (auto Assume : Assumes)
1372df49d1bbSPeter Collingbourne       Assume->eraseFromParent();
1373df49d1bbSPeter Collingbourne     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
1374df49d1bbSPeter Collingbourne     // may use the vtable argument later.
1375df49d1bbSPeter Collingbourne     if (CI->use_empty())
1376df49d1bbSPeter Collingbourne       CI->eraseFromParent();
1377df49d1bbSPeter Collingbourne   }
13780312f614SPeter Collingbourne }
13790312f614SPeter Collingbourne 
13800312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
13810312f614SPeter Collingbourne   Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);
13820312f614SPeter Collingbourne 
13830312f614SPeter Collingbourne   for (auto I = TypeCheckedLoadFunc->use_begin(),
13840312f614SPeter Collingbourne             E = TypeCheckedLoadFunc->use_end();
13850312f614SPeter Collingbourne        I != E;) {
13860312f614SPeter Collingbourne     auto CI = dyn_cast<CallInst>(I->getUser());
13870312f614SPeter Collingbourne     ++I;
13880312f614SPeter Collingbourne     if (!CI)
13890312f614SPeter Collingbourne       continue;
13900312f614SPeter Collingbourne 
13910312f614SPeter Collingbourne     Value *Ptr = CI->getArgOperand(0);
13920312f614SPeter Collingbourne     Value *Offset = CI->getArgOperand(1);
13930312f614SPeter Collingbourne     Value *TypeIdValue = CI->getArgOperand(2);
13940312f614SPeter Collingbourne     Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();
13950312f614SPeter Collingbourne 
13960312f614SPeter Collingbourne     SmallVector<DevirtCallSite, 1> DevirtCalls;
13970312f614SPeter Collingbourne     SmallVector<Instruction *, 1> LoadedPtrs;
13980312f614SPeter Collingbourne     SmallVector<Instruction *, 1> Preds;
13990312f614SPeter Collingbourne     bool HasNonCallUses = false;
14000312f614SPeter Collingbourne     findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
14010312f614SPeter Collingbourne                                                HasNonCallUses, CI);
14020312f614SPeter Collingbourne 
14030312f614SPeter Collingbourne     // Start by generating "pessimistic" code that explicitly loads the function
14040312f614SPeter Collingbourne     // pointer from the vtable and performs the type check. If possible, we will
14050312f614SPeter Collingbourne     // eliminate the load and the type check later.
14060312f614SPeter Collingbourne 
14070312f614SPeter Collingbourne     // If possible, only generate the load at the point where it is used.
14080312f614SPeter Collingbourne     // This helps avoid unnecessary spills.
14090312f614SPeter Collingbourne     IRBuilder<> LoadB(
14100312f614SPeter Collingbourne         (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
14110312f614SPeter Collingbourne     Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
14120312f614SPeter Collingbourne     Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
14130312f614SPeter Collingbourne     Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
14140312f614SPeter Collingbourne 
14150312f614SPeter Collingbourne     for (Instruction *LoadedPtr : LoadedPtrs) {
14160312f614SPeter Collingbourne       LoadedPtr->replaceAllUsesWith(LoadedValue);
14170312f614SPeter Collingbourne       LoadedPtr->eraseFromParent();
14180312f614SPeter Collingbourne     }
14190312f614SPeter Collingbourne 
14200312f614SPeter Collingbourne     // Likewise for the type test.
14210312f614SPeter Collingbourne     IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
14220312f614SPeter Collingbourne     CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});
14230312f614SPeter Collingbourne 
14240312f614SPeter Collingbourne     for (Instruction *Pred : Preds) {
14250312f614SPeter Collingbourne       Pred->replaceAllUsesWith(TypeTestCall);
14260312f614SPeter Collingbourne       Pred->eraseFromParent();
14270312f614SPeter Collingbourne     }
14280312f614SPeter Collingbourne 
14290312f614SPeter Collingbourne     // We have already erased any extractvalue instructions that refer to the
14300312f614SPeter Collingbourne     // intrinsic call, but the intrinsic may have other non-extractvalue uses
14310312f614SPeter Collingbourne     // (although this is unlikely). In that case, explicitly build a pair and
14320312f614SPeter Collingbourne     // RAUW it.
14330312f614SPeter Collingbourne     if (!CI->use_empty()) {
14340312f614SPeter Collingbourne       Value *Pair = UndefValue::get(CI->getType());
14350312f614SPeter Collingbourne       IRBuilder<> B(CI);
14360312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
14370312f614SPeter Collingbourne       Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
14380312f614SPeter Collingbourne       CI->replaceAllUsesWith(Pair);
14390312f614SPeter Collingbourne     }
14400312f614SPeter Collingbourne 
14410312f614SPeter Collingbourne     // The number of unsafe uses is initially the number of uses.
14420312f614SPeter Collingbourne     auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
14430312f614SPeter Collingbourne     NumUnsafeUses = DevirtCalls.size();
14440312f614SPeter Collingbourne 
14450312f614SPeter Collingbourne     // If the function pointer has a non-call user, we cannot eliminate the type
14460312f614SPeter Collingbourne     // check, as one of those users may eventually call the pointer. Increment
14470312f614SPeter Collingbourne     // the unsafe use count to make sure it cannot reach zero.
14480312f614SPeter Collingbourne     if (HasNonCallUses)
14490312f614SPeter Collingbourne       ++NumUnsafeUses;
14500312f614SPeter Collingbourne     for (DevirtCallSite Call : DevirtCalls) {
145150cbd7ccSPeter Collingbourne       CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
145250cbd7ccSPeter Collingbourne                                                    &NumUnsafeUses);
14530312f614SPeter Collingbourne     }
14540312f614SPeter Collingbourne 
14550312f614SPeter Collingbourne     CI->eraseFromParent();
14560312f614SPeter Collingbourne   }
14570312f614SPeter Collingbourne }
14580312f614SPeter Collingbourne 
14596d284fabSPeter Collingbourne void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
14609a3f9797SPeter Collingbourne   const TypeIdSummary *TidSummary =
1461f7691d8bSPeter Collingbourne       ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString());
14629a3f9797SPeter Collingbourne   if (!TidSummary)
14639a3f9797SPeter Collingbourne     return;
14649a3f9797SPeter Collingbourne   auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
14659a3f9797SPeter Collingbourne   if (ResI == TidSummary->WPDRes.end())
14669a3f9797SPeter Collingbourne     return;
14679a3f9797SPeter Collingbourne   const WholeProgramDevirtResolution &Res = ResI->second;
14686d284fabSPeter Collingbourne 
14696d284fabSPeter Collingbourne   if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
14706d284fabSPeter Collingbourne     // The type of the function in the declaration is irrelevant because every
14716d284fabSPeter Collingbourne     // call site will cast it to the correct type.
1472db11fdfdSMehdi Amini     auto *SingleImpl = M.getOrInsertFunction(
147359a2d7b9SSerge Guelton         Res.SingleImplName, Type::getVoidTy(M.getContext()));
14746d284fabSPeter Collingbourne 
14756d284fabSPeter Collingbourne     // This is the import phase so we should not be exporting anything.
14766d284fabSPeter Collingbourne     bool IsExported = false;
14776d284fabSPeter Collingbourne     applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
14786d284fabSPeter Collingbourne     assert(!IsExported);
14796d284fabSPeter Collingbourne   }
14800152c815SPeter Collingbourne 
14810152c815SPeter Collingbourne   for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
14820152c815SPeter Collingbourne     auto I = Res.ResByArg.find(CSByConstantArg.first);
14830152c815SPeter Collingbourne     if (I == Res.ResByArg.end())
14840152c815SPeter Collingbourne       continue;
14850152c815SPeter Collingbourne     auto &ResByArg = I->second;
14860152c815SPeter Collingbourne     // FIXME: We should figure out what to do about the "function name" argument
14870152c815SPeter Collingbourne     // to the apply* functions, as the function names are unavailable during the
14880152c815SPeter Collingbourne     // importing phase. For now we just pass the empty string. This does not
14890152c815SPeter Collingbourne     // impact correctness because the function names are just used for remarks.
14900152c815SPeter Collingbourne     switch (ResByArg.TheKind) {
14910152c815SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniformRetVal:
14920152c815SPeter Collingbourne       applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
14930152c815SPeter Collingbourne       break;
149459675ba0SPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
149559675ba0SPeter Collingbourne       Constant *UniqueMemberAddr =
149659675ba0SPeter Collingbourne           importGlobal(Slot, CSByConstantArg.first, "unique_member");
149759675ba0SPeter Collingbourne       applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
149859675ba0SPeter Collingbourne                            UniqueMemberAddr);
149959675ba0SPeter Collingbourne       break;
150059675ba0SPeter Collingbourne     }
150114dcf02fSPeter Collingbourne     case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
1502b15a35e6SPeter Collingbourne       Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
1503b15a35e6SPeter Collingbourne                                       Int32Ty, ResByArg.Byte);
1504b15a35e6SPeter Collingbourne       Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
1505b15a35e6SPeter Collingbourne                                      ResByArg.Bit);
150614dcf02fSPeter Collingbourne       applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
15070e6694d1SAdrian Prantl       break;
150814dcf02fSPeter Collingbourne     }
15090152c815SPeter Collingbourne     default:
15100152c815SPeter Collingbourne       break;
15110152c815SPeter Collingbourne     }
15120152c815SPeter Collingbourne   }
15132974856aSPeter Collingbourne 
15142974856aSPeter Collingbourne   if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
15152974856aSPeter Collingbourne     auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
15162974856aSPeter Collingbourne                                      Type::getVoidTy(M.getContext()));
15172974856aSPeter Collingbourne     bool IsExported = false;
15182974856aSPeter Collingbourne     applyICallBranchFunnel(SlotInfo, JT, IsExported);
15192974856aSPeter Collingbourne     assert(!IsExported);
15202974856aSPeter Collingbourne   }
15216d284fabSPeter Collingbourne }
15226d284fabSPeter Collingbourne 
15236d284fabSPeter Collingbourne void DevirtModule::removeRedundantTypeTests() {
15246d284fabSPeter Collingbourne   auto True = ConstantInt::getTrue(M.getContext());
15256d284fabSPeter Collingbourne   for (auto &&U : NumUnsafeUsesForTypeTest) {
15266d284fabSPeter Collingbourne     if (U.second == 0) {
15276d284fabSPeter Collingbourne       U.first->replaceAllUsesWith(True);
15286d284fabSPeter Collingbourne       U.first->eraseFromParent();
15296d284fabSPeter Collingbourne     }
15306d284fabSPeter Collingbourne   }
15316d284fabSPeter Collingbourne }
15326d284fabSPeter Collingbourne 
15330312f614SPeter Collingbourne bool DevirtModule::run() {
15340312f614SPeter Collingbourne   Function *TypeTestFunc =
15350312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
15360312f614SPeter Collingbourne   Function *TypeCheckedLoadFunc =
15370312f614SPeter Collingbourne       M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
15380312f614SPeter Collingbourne   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
15390312f614SPeter Collingbourne 
1540b406baaeSPeter Collingbourne   // Normally if there are no users of the devirtualization intrinsics in the
1541b406baaeSPeter Collingbourne   // module, this pass has nothing to do. But if we are exporting, we also need
1542b406baaeSPeter Collingbourne   // to handle any users that appear only in the function summaries.
1543f7691d8bSPeter Collingbourne   if (!ExportSummary &&
1544b406baaeSPeter Collingbourne       (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
15450312f614SPeter Collingbourne        AssumeFunc->use_empty()) &&
15460312f614SPeter Collingbourne       (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
15470312f614SPeter Collingbourne     return false;
15480312f614SPeter Collingbourne 
15490312f614SPeter Collingbourne   if (TypeTestFunc && AssumeFunc)
15500312f614SPeter Collingbourne     scanTypeTestUsers(TypeTestFunc, AssumeFunc);
15510312f614SPeter Collingbourne 
15520312f614SPeter Collingbourne   if (TypeCheckedLoadFunc)
15530312f614SPeter Collingbourne     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
1554df49d1bbSPeter Collingbourne 
1555f7691d8bSPeter Collingbourne   if (ImportSummary) {
15566d284fabSPeter Collingbourne     for (auto &S : CallSlots)
15576d284fabSPeter Collingbourne       importResolution(S.first, S.second);
15586d284fabSPeter Collingbourne 
15596d284fabSPeter Collingbourne     removeRedundantTypeTests();
15606d284fabSPeter Collingbourne 
15616d284fabSPeter Collingbourne     // The rest of the code is only necessary when exporting or during regular
15626d284fabSPeter Collingbourne     // LTO, so we are done.
15636d284fabSPeter Collingbourne     return true;
15646d284fabSPeter Collingbourne   }
15656d284fabSPeter Collingbourne 
15667efd7506SPeter Collingbourne   // Rebuild type metadata into a map for easy lookup.
1567df49d1bbSPeter Collingbourne   std::vector<VTableBits> Bits;
15687efd7506SPeter Collingbourne   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
15697efd7506SPeter Collingbourne   buildTypeIdentifierMap(Bits, TypeIdMap);
15707efd7506SPeter Collingbourne   if (TypeIdMap.empty())
1571df49d1bbSPeter Collingbourne     return true;
1572df49d1bbSPeter Collingbourne 
1573b406baaeSPeter Collingbourne   // Collect information from summary about which calls to try to devirtualize.
1574f7691d8bSPeter Collingbourne   if (ExportSummary) {
1575b406baaeSPeter Collingbourne     DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
1576b406baaeSPeter Collingbourne     for (auto &P : TypeIdMap) {
1577b406baaeSPeter Collingbourne       if (auto *TypeId = dyn_cast<MDString>(P.first))
1578b406baaeSPeter Collingbourne         MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
1579b406baaeSPeter Collingbourne             TypeId);
1580b406baaeSPeter Collingbourne     }
1581b406baaeSPeter Collingbourne 
1582f7691d8bSPeter Collingbourne     for (auto &P : *ExportSummary) {
15839667b91bSPeter Collingbourne       for (auto &S : P.second.SummaryList) {
1584b406baaeSPeter Collingbourne         auto *FS = dyn_cast<FunctionSummary>(S.get());
1585b406baaeSPeter Collingbourne         if (!FS)
1586b406baaeSPeter Collingbourne           continue;
1587b406baaeSPeter Collingbourne         // FIXME: Only add live functions.
15885d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
15895d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
15902974856aSPeter Collingbourne             CallSlots[{MD, VF.Offset}]
15912974856aSPeter Collingbourne                 .CSInfo.markSummaryHasTypeTestAssumeUsers();
15925d8aea10SGeorge Rimar           }
15935d8aea10SGeorge Rimar         }
15945d8aea10SGeorge Rimar         for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
15955d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
15962974856aSPeter Collingbourne             CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
15975d8aea10SGeorge Rimar           }
15985d8aea10SGeorge Rimar         }
1599b406baaeSPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
16005d8aea10SGeorge Rimar              FS->type_test_assume_const_vcalls()) {
16015d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
16022325bb34SPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
16035d8aea10SGeorge Rimar                 .ConstCSInfo[VC.Args]
16042974856aSPeter Collingbourne                 .markSummaryHasTypeTestAssumeUsers();
16055d8aea10SGeorge Rimar           }
16065d8aea10SGeorge Rimar         }
16072325bb34SPeter Collingbourne         for (const FunctionSummary::ConstVCall &VC :
16085d8aea10SGeorge Rimar              FS->type_checked_load_const_vcalls()) {
16095d8aea10SGeorge Rimar           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
1610b406baaeSPeter Collingbourne             CallSlots[{MD, VC.VFunc.Offset}]
1611b406baaeSPeter Collingbourne                 .ConstCSInfo[VC.Args]
16122974856aSPeter Collingbourne                 .addSummaryTypeCheckedLoadUser(FS);
1613b406baaeSPeter Collingbourne           }
1614b406baaeSPeter Collingbourne         }
1615b406baaeSPeter Collingbourne       }
16165d8aea10SGeorge Rimar     }
16175d8aea10SGeorge Rimar   }
1618b406baaeSPeter Collingbourne 
16197efd7506SPeter Collingbourne   // For each (type, offset) pair:
1620df49d1bbSPeter Collingbourne   bool DidVirtualConstProp = false;
1621f3403fd2SIvan Krasin   std::map<std::string, Function*> DevirtTargets;
1622df49d1bbSPeter Collingbourne   for (auto &S : CallSlots) {
16237efd7506SPeter Collingbourne     // Search each of the members of the type identifier for the virtual
16247efd7506SPeter Collingbourne     // function implementation at offset S.first.ByteOffset, and add to
16257efd7506SPeter Collingbourne     // TargetsForSlot.
1626df49d1bbSPeter Collingbourne     std::vector<VirtualCallTarget> TargetsForSlot;
1627b406baaeSPeter Collingbourne     if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
1628b406baaeSPeter Collingbourne                                   S.first.ByteOffset)) {
16292325bb34SPeter Collingbourne       WholeProgramDevirtResolution *Res = nullptr;
1630f7691d8bSPeter Collingbourne       if (ExportSummary && isa<MDString>(S.first.TypeID))
1631f7691d8bSPeter Collingbourne         Res = &ExportSummary
16329a3f9797SPeter Collingbourne                    ->getOrInsertTypeIdSummary(
16339a3f9797SPeter Collingbourne                        cast<MDString>(S.first.TypeID)->getString())
16342325bb34SPeter Collingbourne                    .WPDRes[S.first.ByteOffset];
16352325bb34SPeter Collingbourne 
16362974856aSPeter Collingbourne       if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
16372974856aSPeter Collingbourne         DidVirtualConstProp |=
16382974856aSPeter Collingbourne             tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
16392974856aSPeter Collingbourne 
16402974856aSPeter Collingbourne         tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
16412974856aSPeter Collingbourne       }
1642f3403fd2SIvan Krasin 
1643f3403fd2SIvan Krasin       // Collect functions devirtualized at least for one call site for stats.
1644f3403fd2SIvan Krasin       if (RemarksEnabled)
1645f3403fd2SIvan Krasin         for (const auto &T : TargetsForSlot)
1646f3403fd2SIvan Krasin           if (T.WasDevirt)
1647f3403fd2SIvan Krasin             DevirtTargets[T.Fn->getName()] = T.Fn;
1648b05e06e4SIvan Krasin     }
1649df49d1bbSPeter Collingbourne 
1650b406baaeSPeter Collingbourne     // CFI-specific: if we are exporting and any llvm.type.checked.load
1651b406baaeSPeter Collingbourne     // intrinsics were *not* devirtualized, we need to add the resulting
1652b406baaeSPeter Collingbourne     // llvm.type.test intrinsics to the function summaries so that the
1653b406baaeSPeter Collingbourne     // LowerTypeTests pass will export them.
1654f7691d8bSPeter Collingbourne     if (ExportSummary && isa<MDString>(S.first.TypeID)) {
1655b406baaeSPeter Collingbourne       auto GUID =
1656b406baaeSPeter Collingbourne           GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
1657b406baaeSPeter Collingbourne       for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
1658b406baaeSPeter Collingbourne         FS->addTypeTest(GUID);
1659b406baaeSPeter Collingbourne       for (auto &CCS : S.second.ConstCSInfo)
1660b406baaeSPeter Collingbourne         for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
1661b406baaeSPeter Collingbourne           FS->addTypeTest(GUID);
1662b406baaeSPeter Collingbourne     }
1663b406baaeSPeter Collingbourne   }
1664b406baaeSPeter Collingbourne 
1665f3403fd2SIvan Krasin   if (RemarksEnabled) {
1666f3403fd2SIvan Krasin     // Generate remarks for each devirtualized function.
1667f3403fd2SIvan Krasin     for (const auto &DT : DevirtTargets) {
1668f3403fd2SIvan Krasin       Function *F = DT.second;
1669e963c89dSSam Elliott 
1670e963c89dSSam Elliott       using namespace ore;
16719110cb45SPeter Collingbourne       OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
16729110cb45SPeter Collingbourne                         << "devirtualized "
16739110cb45SPeter Collingbourne                         << NV("FunctionName", F->getName()));
1674b05e06e4SIvan Krasin     }
1675df49d1bbSPeter Collingbourne   }
1676df49d1bbSPeter Collingbourne 
16776d284fabSPeter Collingbourne   removeRedundantTypeTests();
16780312f614SPeter Collingbourne 
1679df49d1bbSPeter Collingbourne   // Rebuild each global we touched as part of virtual constant propagation to
1680df49d1bbSPeter Collingbourne   // include the before and after bytes.
1681df49d1bbSPeter Collingbourne   if (DidVirtualConstProp)
1682df49d1bbSPeter Collingbourne     for (VTableBits &B : Bits)
1683df49d1bbSPeter Collingbourne       rebuildGlobal(B);
1684df49d1bbSPeter Collingbourne 
1685df49d1bbSPeter Collingbourne   return true;
1686df49d1bbSPeter Collingbourne }
1687