1 //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass implements whole program optimization of virtual calls in cases
11 // where we know (via !type metadata) that the list of callees is fixed. This
12 // includes the following:
13 // - Single implementation devirtualization: if a virtual call has a single
14 //   possible callee, replace all calls with a direct call to that callee.
15 // - Virtual constant propagation: if the virtual function's return type is an
16 //   integer <=64 bits and all possible callees are readnone, for each class and
17 //   each list of constant arguments: evaluate the function, store the return
18 //   value alongside the virtual table, and rewrite each virtual call as a load
19 //   from the virtual table.
20 // - Uniform return value optimization: if the conditions for virtual constant
21 //   propagation hold and each function returns the same constant value, replace
22 //   each virtual call with that constant.
23 // - Unique return value optimization for i1 return values: if the conditions
24 //   for virtual constant propagation hold and a single vtable's function
25 //   returns 0, or a single vtable's function returns 1, replace each virtual
26 //   call with a comparison of the vptr against that vtable's address.
27 //
28 //===----------------------------------------------------------------------===//
29 
30 #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/DenseMap.h"
33 #include "llvm/ADT/DenseMapInfo.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/iterator_range.h"
36 #include "llvm/ADT/MapVector.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/Analysis/AliasAnalysis.h"
39 #include "llvm/Analysis/BasicAliasAnalysis.h"
40 #include "llvm/Analysis/TypeMetadataUtils.h"
41 #include "llvm/IR/CallSite.h"
42 #include "llvm/IR/Constants.h"
43 #include "llvm/IR/DataLayout.h"
44 #include "llvm/IR/DebugInfoMetadata.h"
45 #include "llvm/IR/DebugLoc.h"
46 #include "llvm/IR/DerivedTypes.h"
47 #include "llvm/IR/DiagnosticInfo.h"
48 #include "llvm/IR/Function.h"
49 #include "llvm/IR/GlobalAlias.h"
50 #include "llvm/IR/GlobalVariable.h"
51 #include "llvm/IR/IRBuilder.h"
52 #include "llvm/IR/InstrTypes.h"
53 #include "llvm/IR/Instruction.h"
54 #include "llvm/IR/Instructions.h"
55 #include "llvm/IR/Intrinsics.h"
56 #include "llvm/IR/LLVMContext.h"
57 #include "llvm/IR/Metadata.h"
58 #include "llvm/IR/Module.h"
59 #include "llvm/IR/ModuleSummaryIndexYAML.h"
60 #include "llvm/Pass.h"
61 #include "llvm/PassRegistry.h"
62 #include "llvm/PassSupport.h"
63 #include "llvm/Support/Casting.h"
64 #include "llvm/Support/Error.h"
65 #include "llvm/Support/FileSystem.h"
66 #include "llvm/Support/MathExtras.h"
67 #include "llvm/Transforms/IPO.h"
68 #include "llvm/Transforms/IPO/FunctionAttrs.h"
69 #include "llvm/Transforms/Utils/Evaluator.h"
70 #include <algorithm>
71 #include <cstddef>
72 #include <map>
73 #include <set>
74 #include <string>
75 
76 using namespace llvm;
77 using namespace wholeprogramdevirt;
78 
79 #define DEBUG_TYPE "wholeprogramdevirt"
80 
// Summary action (none/import/export) used when the pass is created with the
// default (command-line) constructor; consumed by DevirtModule::runForTesting.
static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

// Testing only: YAML file that runForTesting reads into the summary index
// before running the pass. Left empty, no summary is read.
static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc("Read summary from given YAML file before running pass"),
    cl::Hidden);

// Testing only: YAML file that runForTesting writes the summary index to after
// running the pass. Left empty, nothing is written.
static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given YAML file after running pass"),
    cl::Hidden);
100 
101 // Find the minimum offset that we may store a value of size Size bits at. If
102 // IsAfter is set, look for an offset before the object, otherwise look for an
103 // offset after the object.
104 uint64_t
105 wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
106                                      bool IsAfter, uint64_t Size) {
107   // Find a minimum offset taking into account only vtable sizes.
108   uint64_t MinByte = 0;
109   for (const VirtualCallTarget &Target : Targets) {
110     if (IsAfter)
111       MinByte = std::max(MinByte, Target.minAfterBytes());
112     else
113       MinByte = std::max(MinByte, Target.minBeforeBytes());
114   }
115 
116   // Build a vector of arrays of bytes covering, for each target, a slice of the
117   // used region (see AccumBitVector::BytesUsed in
118   // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
119   // this aligns the used regions to start at MinByte.
120   //
121   // In this example, A, B and C are vtables, # is a byte already allocated for
122   // a virtual function pointer, AAAA... (etc.) are the used regions for the
123   // vtables and Offset(X) is the value computed for the Offset variable below
124   // for X.
125   //
126   //                    Offset(A)
127   //                    |       |
128   //                            |MinByte
129   // A: ################AAAAAAAA|AAAAAAAA
130   // B: ########BBBBBBBBBBBBBBBB|BBBB
131   // C: ########################|CCCCCCCCCCCCCCCC
132   //            |   Offset(B)   |
133   //
134   // This code produces the slices of A, B and C that appear after the divider
135   // at MinByte.
136   std::vector<ArrayRef<uint8_t>> Used;
137   for (const VirtualCallTarget &Target : Targets) {
138     ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
139                                        : Target.TM->Bits->Before.BytesUsed;
140     uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
141                               : MinByte - Target.minBeforeBytes();
142 
143     // Disregard used regions that are smaller than Offset. These are
144     // effectively all-free regions that do not need to be checked.
145     if (VTUsed.size() > Offset)
146       Used.push_back(VTUsed.slice(Offset));
147   }
148 
149   if (Size == 1) {
150     // Find a free bit in each member of Used.
151     for (unsigned I = 0;; ++I) {
152       uint8_t BitsUsed = 0;
153       for (auto &&B : Used)
154         if (I < B.size())
155           BitsUsed |= B[I];
156       if (BitsUsed != 0xff)
157         return (MinByte + I) * 8 +
158                countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
159     }
160   } else {
161     // Find a free (Size/8) byte region in each member of Used.
162     // FIXME: see if alignment helps.
163     for (unsigned I = 0;; ++I) {
164       for (auto &&B : Used) {
165         unsigned Byte = 0;
166         while ((I + Byte) < B.size() && Byte < (Size / 8)) {
167           if (B[I + Byte])
168             goto NextI;
169           ++Byte;
170         }
171       }
172       return (MinByte + I) * 8;
173     NextI:;
174     }
175   }
176 }
177 
178 void wholeprogramdevirt::setBeforeReturnValues(
179     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
180     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
181   if (BitWidth == 1)
182     OffsetByte = -(AllocBefore / 8 + 1);
183   else
184     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
185   OffsetBit = AllocBefore % 8;
186 
187   for (VirtualCallTarget &Target : Targets) {
188     if (BitWidth == 1)
189       Target.setBeforeBit(AllocBefore);
190     else
191       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
192   }
193 }
194 
195 void wholeprogramdevirt::setAfterReturnValues(
196     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
197     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
198   if (BitWidth == 1)
199     OffsetByte = AllocAfter / 8;
200   else
201     OffsetByte = (AllocAfter + 7) / 8;
202   OffsetBit = AllocAfter % 8;
203 
204   for (VirtualCallTarget &Target : Targets) {
205     if (BitWidth == 1)
206       Target.setAfterBit(AllocAfter);
207     else
208       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
209   }
210 }
211 
212 VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
213     : Fn(Fn), TM(TM),
214       IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {}
215 
namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
//
// Used as the key of DevirtModule::CallSlots; hashing and equality come from
// the DenseMapInfo<VTableSlot> specialization that follows.
struct VTableSlot {
  Metadata *TypeID;    // Type identifier metadata naming the set of vtables.
  uint64_t ByteOffset; // Byte offset of the slot from the address point.
};

} // end anonymous namespace
227 
228 namespace llvm {
229 
230 template <> struct DenseMapInfo<VTableSlot> {
231   static VTableSlot getEmptyKey() {
232     return {DenseMapInfo<Metadata *>::getEmptyKey(),
233             DenseMapInfo<uint64_t>::getEmptyKey()};
234   }
235   static VTableSlot getTombstoneKey() {
236     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
237             DenseMapInfo<uint64_t>::getTombstoneKey()};
238   }
239   static unsigned getHashValue(const VTableSlot &I) {
240     return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
241            DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
242   }
243   static bool isEqual(const VTableSlot &LHS,
244                       const VTableSlot &RHS) {
245     return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
246   }
247 };
248 
249 } // end namespace llvm
250 
251 namespace {
252 
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable;
  CallSite CS;

  // If non-null, this field points to the associated unsafe use count stored in
  // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
  // of that field for details.
  unsigned *NumUnsafeUses;

  // Emit an optimization remark, attributed to the calling function and
  // located at this call's debug location, reading
  // "<OptName>: devirtualized a call to <TargetName>".
  void emitRemark(const Twine &OptName, const Twine &TargetName) {
    Function *F = CS.getCaller();
    emitOptimizationRemark(
        F->getContext(), DEBUG_TYPE, *F,
        CS.getInstruction()->getDebugLoc(),
        OptName + ": devirtualized a call to " + TargetName);
  }

  // Replace all uses of the call's result with New and delete the call,
  // emitting a remark first if RemarksEnabled.
  void replaceAndErase(const Twine &OptName, const Twine &TargetName,
                       bool RemarksEnabled, Value *New) {
    if (RemarksEnabled)
      emitRemark(OptName, TargetName);
    CS->replaceAllUsesWith(New);
    // An invoke is a terminator. Put an unconditional branch to its normal
    // destination in its place, and drop this block as a predecessor of the
    // unwind destination so the CFG stays consistent once the invoke is gone.
    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CS->eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};
287 
// Call site information collected for a specific VTableSlot and possibly a list
// of constant integer arguments. The grouping by arguments is handled by the
// VTableSlotInfo class.
struct CallSiteInfo {
  // Call sites in this group, in the order VTableSlotInfo::addCallSite
  // appended them.
  std::vector<VirtualCallSite> CallSites;
};
294 
// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this"). Calls with no arguments, or whose return type is not an
  // integer of at most 64 bits, also land here.
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  // Record a call site in the group chosen by findCallSiteInfo.
  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);

private:
  // Return the group (CSInfo or an entry of ConstCSInfo) that CS belongs in.
  CallSiteInfo &findCallSiteInfo(CallSite CS);
};
310 
311 CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
312   std::vector<uint64_t> Args;
313   auto *CI = dyn_cast<IntegerType>(CS.getType());
314   if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
315     return CSInfo;
316   for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
317     auto *CI = dyn_cast<ConstantInt>(Arg);
318     if (!CI || CI->getBitWidth() > 64)
319       return CSInfo;
320     Args.push_back(CI->getZExtValue());
321   }
322   return ConstCSInfo[Args];
323 }
324 
325 void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
326                                  unsigned *NumUnsafeUses) {
327   findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
328 }
329 
// State and logic for one run of whole program devirtualization over module M.
struct DevirtModule {
  Module &M;
  // Supplies alias analysis results per function; tryVirtualConstProp uses it
  // to check that each candidate target is readnone.
  function_ref<AAResults &(Function &)> AARGetter;

  // Summary action and index for this run. Summary may be null: the new pass
  // manager entry point passes nullptr together with PassSummaryAction::None.
  PassSummaryAction Action;
  ModuleSummaryIndex *Summary;

  // Integer and pointer types from M's context, cached for building IR.
  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;
  IntegerType *Int64Ty;

  // Result of areRemarksEnabled(), computed once in the constructor. When
  // false, no remarks are emitted and VirtualCallTarget::WasDevirt is not set.
  bool RemarksEnabled;

  // Call sites collected per vtable slot.
  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call with
  // true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;

  // RemarksEnabled is initialized by calling a member function. NOTE(review):
  // this relies on every member that areRemarksEnabled() reads already being
  // initialized (M is, since it is declared first); confirm against its
  // definition.
  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
               PassSummaryAction Action, ModuleSummaryIndex *Summary)
      : M(M), AARGetter(AARGetter), Action(Action), Summary(Summary),
        Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())),
        Int64Ty(Type::getInt64Ty(M.getContext())),
        RemarksEnabled(areRemarksEnabled()) {}

  // Definition not shown here; presumably reports whether optimization remarks
  // were requested for this pass.
  bool areRemarksEnabled();

  // Definitions not shown here; presumably collect virtual call sites from the
  // users of the llvm.type.test and llvm.type.checked.load intrinsics.
  void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  // Record each global with !type metadata in Bits, and map each type
  // identifier to the (global, offset) members of that type.
  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  // Return the pointer constant stored Offset bytes into initializer I, or
  // null.
  Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
  // Collect the functions stored at ByteOffset in every member of a type.
  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset);

  // Single implementation devirtualization.
  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn);
  bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo);

  // Evaluate every target with constant arguments, filling in RetVal.
  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  // Uniform return value optimization.
  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo);

  // Unique return value optimization (i1 results only).
  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo);

  // Virtual constant propagation.
  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo);

  // Definition not shown here; presumably rewrites B's global to include the
  // bytes allocated before/after it by virtual constant propagation.
  void rebuildGlobal(VTableBits &B);

  // Run the pass; returns true if the module was changed.
  bool run();

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool runForTesting(Module &M,
                            function_ref<AAResults &(Function &)> AARGetter);
};
412 
413 struct WholeProgramDevirt : public ModulePass {
414   static char ID;
415 
416   bool UseCommandLine = false;
417 
418   PassSummaryAction Action;
419   ModuleSummaryIndex *Summary;
420 
421   WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
422     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
423   }
424 
425   WholeProgramDevirt(PassSummaryAction Action, ModuleSummaryIndex *Summary)
426       : ModulePass(ID), Action(Action), Summary(Summary) {
427     initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
428   }
429 
430   bool runOnModule(Module &M) override {
431     if (skipModule(M))
432       return false;
433     if (UseCommandLine)
434       return DevirtModule::runForTesting(M, LegacyAARGetter(*this));
435     return DevirtModule(M, LegacyAARGetter(*this), Action, Summary).run();
436   }
437 
438   void getAnalysisUsage(AnalysisUsage &AU) const override {
439     AU.addRequired<AssumptionCacheTracker>();
440     AU.addRequired<TargetLibraryInfoWrapperPass>();
441   }
442 };
443 
444 } // end anonymous namespace
445 
// Register the legacy pass as "wholeprogramdevirt", declaring the same analysis
// dependencies that getAnalysisUsage requires.
INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
                      "Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                    "Whole program devirtualization", false, false)
// Pass ID; LLVM identifies the pass by the address of this variable.
char WholeProgramDevirt::ID = 0;
453 
454 ModulePass *llvm::createWholeProgramDevirtPass(PassSummaryAction Action,
455                                                ModuleSummaryIndex *Summary) {
456   return new WholeProgramDevirt(Action, Summary);
457 }
458 
459 PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
460                                               ModuleAnalysisManager &AM) {
461   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
462   auto AARGetter = [&](Function &F) -> AAResults & {
463     return FAM.getResult<AAManager>(F);
464   };
465   if (!DevirtModule(M, AARGetter, PassSummaryAction::None, nullptr).run())
466     return PreservedAnalyses::all();
467   return PreservedAnalyses::none();
468 }
469 
470 bool DevirtModule::runForTesting(
471     Module &M, function_ref<AAResults &(Function &)> AARGetter) {
472   ModuleSummaryIndex Summary;
473 
474   // Handle the command-line summary arguments. This code is for testing
475   // purposes only, so we handle errors directly.
476   if (!ClReadSummary.empty()) {
477     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
478                           ": ");
479     auto ReadSummaryFile =
480         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
481 
482     yaml::Input In(ReadSummaryFile->getBuffer());
483     In >> Summary;
484     ExitOnErr(errorCodeToError(In.error()));
485   }
486 
487   bool Changed = DevirtModule(M, AARGetter, ClSummaryAction, &Summary).run();
488 
489   if (!ClWriteSummary.empty()) {
490     ExitOnError ExitOnErr(
491         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
492     std::error_code EC;
493     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
494     ExitOnErr(errorCodeToError(EC));
495 
496     yaml::Output Out(OS);
497     Out << Summary;
498   }
499 
500   return Changed;
501 }
502 
503 void DevirtModule::buildTypeIdentifierMap(
504     std::vector<VTableBits> &Bits,
505     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
506   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
507   Bits.reserve(M.getGlobalList().size());
508   SmallVector<MDNode *, 2> Types;
509   for (GlobalVariable &GV : M.globals()) {
510     Types.clear();
511     GV.getMetadata(LLVMContext::MD_type, Types);
512     if (Types.empty())
513       continue;
514 
515     VTableBits *&BitsPtr = GVToBits[&GV];
516     if (!BitsPtr) {
517       Bits.emplace_back();
518       Bits.back().GV = &GV;
519       Bits.back().ObjectSize =
520           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
521       BitsPtr = &Bits.back();
522     }
523 
524     for (MDNode *Type : Types) {
525       auto TypeID = Type->getOperand(1).get();
526 
527       uint64_t Offset =
528           cast<ConstantInt>(
529               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
530               ->getZExtValue();
531 
532       TypeIdMap[TypeID].insert({BitsPtr, Offset});
533     }
534   }
535 }
536 
537 Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
538   if (I->getType()->isPointerTy()) {
539     if (Offset == 0)
540       return I;
541     return nullptr;
542   }
543 
544   const DataLayout &DL = M.getDataLayout();
545 
546   if (auto *C = dyn_cast<ConstantStruct>(I)) {
547     const StructLayout *SL = DL.getStructLayout(C->getType());
548     if (Offset >= SL->getSizeInBytes())
549       return nullptr;
550 
551     unsigned Op = SL->getElementContainingOffset(Offset);
552     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
553                               Offset - SL->getElementOffset(Op));
554   }
555   if (auto *C = dyn_cast<ConstantArray>(I)) {
556     ArrayType *VTableTy = C->getType();
557     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
558 
559     unsigned Op = Offset / ElemSize;
560     if (Op >= C->getNumOperands())
561       return nullptr;
562 
563     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
564                               Offset % ElemSize);
565   }
566   return nullptr;
567 }
568 
569 bool DevirtModule::tryFindVirtualCallTargets(
570     std::vector<VirtualCallTarget> &TargetsForSlot,
571     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
572   for (const TypeMemberInfo &TM : TypeMemberInfos) {
573     if (!TM.Bits->GV->isConstant())
574       return false;
575 
576     Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
577                                        TM.Offset + ByteOffset);
578     if (!Ptr)
579       return false;
580 
581     auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
582     if (!Fn)
583       return false;
584 
585     // We can disregard __cxa_pure_virtual as a possible call target, as
586     // calls to pure virtuals are UB.
587     if (Fn->getName() == "__cxa_pure_virtual")
588       continue;
589 
590     TargetsForSlot.push_back({Fn, &TM});
591   }
592 
593   // Give up if we couldn't find any targets.
594   return !TargetsForSlot.empty();
595 }
596 
597 void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
598                                          Constant *TheFn) {
599   auto Apply = [&](CallSiteInfo &CSInfo) {
600     for (auto &&VCallSite : CSInfo.CallSites) {
601       if (RemarksEnabled)
602         VCallSite.emitRemark("single-impl", TheFn->getName());
603       VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
604           TheFn, VCallSite.CS.getCalledValue()->getType()));
605       // This use is no longer unsafe.
606       if (VCallSite.NumUnsafeUses)
607         --*VCallSite.NumUnsafeUses;
608     }
609   };
610   Apply(SlotInfo.CSInfo);
611   for (auto &P : SlotInfo.ConstCSInfo)
612     Apply(P.second);
613 }
614 
615 bool DevirtModule::trySingleImplDevirt(
616     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
617     VTableSlotInfo &SlotInfo) {
618   // See if the program contains a single implementation of this virtual
619   // function.
620   Function *TheFn = TargetsForSlot[0].Fn;
621   for (auto &&Target : TargetsForSlot)
622     if (TheFn != Target.Fn)
623       return false;
624 
625   // If so, update each call site to call that implementation directly.
626   if (RemarksEnabled)
627     TargetsForSlot[0].WasDevirt = true;
628   applySingleImplDevirt(SlotInfo, TheFn);
629   return true;
630 }
631 
632 bool DevirtModule::tryEvaluateFunctionsWithArgs(
633     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
634     ArrayRef<uint64_t> Args) {
635   // Evaluate each function and store the result in each target's RetVal
636   // field.
637   for (VirtualCallTarget &Target : TargetsForSlot) {
638     if (Target.Fn->arg_size() != Args.size() + 1)
639       return false;
640 
641     Evaluator Eval(M.getDataLayout(), nullptr);
642     SmallVector<Constant *, 2> EvalArgs;
643     EvalArgs.push_back(
644         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
645     for (unsigned I = 0; I != Args.size(); ++I) {
646       auto *ArgTy = dyn_cast<IntegerType>(
647           Target.Fn->getFunctionType()->getParamType(I + 1));
648       if (!ArgTy)
649         return false;
650       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
651     }
652 
653     Constant *RetVal;
654     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
655         !isa<ConstantInt>(RetVal))
656       return false;
657     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
658   }
659   return true;
660 }
661 
662 void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
663                                          uint64_t TheRetVal) {
664   for (auto Call : CSInfo.CallSites)
665     Call.replaceAndErase(
666         "uniform-ret-val", FnName, RemarksEnabled,
667         ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
668 }
669 
670 bool DevirtModule::tryUniformRetValOpt(
671     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo) {
672   // Uniform return value optimization. If all functions return the same
673   // constant, replace all calls with that constant.
674   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
675   for (const VirtualCallTarget &Target : TargetsForSlot)
676     if (Target.RetVal != TheRetVal)
677       return false;
678 
679   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
680   if (RemarksEnabled)
681     for (auto &&Target : TargetsForSlot)
682       Target.WasDevirt = true;
683   return true;
684 }
685 
686 void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
687                                         bool IsOne,
688                                         Constant *UniqueMemberAddr) {
689   for (auto &&Call : CSInfo.CallSites) {
690     IRBuilder<> B(Call.CS.getInstruction());
691     Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
692                               Call.VTable, UniqueMemberAddr);
693     Cmp = B.CreateZExt(Cmp, Call.CS->getType());
694     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp);
695   }
696 }
697 
698 bool DevirtModule::tryUniqueRetValOpt(
699     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
700     CallSiteInfo &CSInfo) {
701   // IsOne controls whether we look for a 0 or a 1.
702   auto tryUniqueRetValOptFor = [&](bool IsOne) {
703     const TypeMemberInfo *UniqueMember = nullptr;
704     for (const VirtualCallTarget &Target : TargetsForSlot) {
705       if (Target.RetVal == (IsOne ? 1 : 0)) {
706         if (UniqueMember)
707           return false;
708         UniqueMember = Target.TM;
709       }
710     }
711 
712     // We should have found a unique member or bailed out by now. We already
713     // checked for a uniform return value in tryUniformRetValOpt.
714     assert(UniqueMember);
715 
716     // Replace each call with the comparison.
717     Constant *UniqueMemberAddr =
718         ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
719     UniqueMemberAddr = ConstantExpr::getGetElementPtr(
720         Int8Ty, UniqueMemberAddr,
721         ConstantInt::get(Int64Ty, UniqueMember->Offset));
722 
723     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
724                          UniqueMemberAddr);
725 
726     // Update devirtualization statistics for targets.
727     if (RemarksEnabled)
728       for (auto &&Target : TargetsForSlot)
729         Target.WasDevirt = true;
730 
731     return true;
732   };
733 
734   if (BitWidth == 1) {
735     if (tryUniqueRetValOptFor(true))
736       return true;
737     if (tryUniqueRetValOptFor(false))
738       return true;
739   }
740   return false;
741 }
742 
743 void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
744                                          Constant *Byte, Constant *Bit) {
745   for (auto Call : CSInfo.CallSites) {
746     auto *RetType = cast<IntegerType>(Call.CS.getType());
747     IRBuilder<> B(Call.CS.getInstruction());
748     Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
749     if (RetType->getBitWidth() == 1) {
750       Value *Bits = B.CreateLoad(Addr);
751       Value *BitsAndBit = B.CreateAnd(Bits, Bit);
752       auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
753       Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
754                            IsBitSet);
755     } else {
756       Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
757       Value *Val = B.CreateLoad(RetType, ValAddr);
758       Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val);
759     }
760   }
761 }
762 
763 bool DevirtModule::tryVirtualConstProp(
764     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
765     VTableSlotInfo &SlotInfo) {
766   // This only works if the function returns an integer.
767   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
768   if (!RetType)
769     return false;
770   unsigned BitWidth = RetType->getBitWidth();
771   if (BitWidth > 64)
772     return false;
773 
774   // Make sure that each function is defined, does not access memory, takes at
775   // least one argument, does not use its first argument (which we assume is
776   // 'this'), and has the same return type.
777   //
778   // Note that we test whether this copy of the function is readnone, rather
779   // than testing function attributes, which must hold for any copy of the
780   // function, even a less optimized version substituted at link time. This is
781   // sound because the virtual constant propagation optimizations effectively
782   // inline all implementations of the virtual function into each call site,
783   // rather than using function attributes to perform local optimization.
784   for (VirtualCallTarget &Target : TargetsForSlot) {
785     if (Target.Fn->isDeclaration() ||
786         computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
787             MAK_ReadNone ||
788         Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
789         Target.Fn->getReturnType() != RetType)
790       return false;
791   }
792 
793   for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
794     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
795       continue;
796 
797     if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second))
798       continue;
799 
800     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
801       continue;
802 
803     // Find an allocation offset in bits in all vtables associated with the
804     // type.
805     uint64_t AllocBefore =
806         findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
807     uint64_t AllocAfter =
808         findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);
809 
810     // Calculate the total amount of padding needed to store a value at both
811     // ends of the object.
812     uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
813     for (auto &&Target : TargetsForSlot) {
814       TotalPaddingBefore += std::max<int64_t>(
815           (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
816       TotalPaddingAfter += std::max<int64_t>(
817           (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
818     }
819 
820     // If the amount of padding is too large, give up.
821     // FIXME: do something smarter here.
822     if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
823       continue;
824 
825     // Calculate the offset to the value as a (possibly negative) byte offset
826     // and (if applicable) a bit offset, and store the values in the targets.
827     int64_t OffsetByte;
828     uint64_t OffsetBit;
829     if (TotalPaddingBefore <= TotalPaddingAfter)
830       setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
831                             OffsetBit);
832     else
833       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
834                            OffsetBit);
835 
836     if (RemarksEnabled)
837       for (auto &&Target : TargetsForSlot)
838         Target.WasDevirt = true;
839 
840     // Rewrite each call to a load from OffsetByte/OffsetBit.
841     Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
842     Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
843     applyVirtualConstProp(CSByConstantArg.second,
844                           TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
845   }
846   return true;
847 }
848 
849 void DevirtModule::rebuildGlobal(VTableBits &B) {
850   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
851     return;
852 
853   // Align each byte array to pointer width.
854   unsigned PointerSize = M.getDataLayout().getPointerSize();
855   B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
856   B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));
857 
858   // Before was stored in reverse order; flip it now.
859   for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
860     std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);
861 
862   // Build an anonymous global containing the before bytes, followed by the
863   // original initializer, followed by the after bytes.
864   auto NewInit = ConstantStruct::getAnon(
865       {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
866        B.GV->getInitializer(),
867        ConstantDataArray::get(M.getContext(), B.After.Bytes)});
868   auto NewGV =
869       new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
870                          GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
871   NewGV->setSection(B.GV->getSection());
872   NewGV->setComdat(B.GV->getComdat());
873 
874   // Copy the original vtable's metadata to the anonymous global, adjusting
875   // offsets as required.
876   NewGV->copyMetadata(B.GV, B.Before.Bytes.size());
877 
878   // Build an alias named after the original global, pointing at the second
879   // element (the original initializer).
880   auto Alias = GlobalAlias::create(
881       B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
882       ConstantExpr::getGetElementPtr(
883           NewInit->getType(), NewGV,
884           ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
885                                ConstantInt::get(Int32Ty, 1)}),
886       &M);
887   Alias->setVisibility(B.GV->getVisibility());
888   Alias->takeName(B.GV);
889 
890   B.GV->replaceAllUsesWith(Alias);
891   B.GV->eraseFromParent();
892 }
893 
894 bool DevirtModule::areRemarksEnabled() {
895   const auto &FL = M.getFunctionList();
896   if (FL.empty())
897     return false;
898   const Function &Fn = FL.front();
899 
900   const auto &BBL = Fn.getBasicBlockList();
901   if (BBL.empty())
902     return false;
903   auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
904   return DI.isEnabled();
905 }
906 
907 void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
908                                      Function *AssumeFunc) {
909   // Find all virtual calls via a virtual table pointer %p under an assumption
910   // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
911   // points to a member of the type identifier %md. Group calls by (type ID,
912   // offset) pair (effectively the identity of the virtual function) and store
913   // to CallSlots.
914   DenseSet<Value *> SeenPtrs;
915   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
916        I != E;) {
917     auto CI = dyn_cast<CallInst>(I->getUser());
918     ++I;
919     if (!CI)
920       continue;
921 
922     // Search for virtual calls based on %p and add them to DevirtCalls.
923     SmallVector<DevirtCallSite, 1> DevirtCalls;
924     SmallVector<CallInst *, 1> Assumes;
925     findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
926 
927     // If we found any, add them to CallSlots. Only do this if we haven't seen
928     // the vtable pointer before, as it may have been CSE'd with pointers from
929     // other call sites, and we don't want to process call sites multiple times.
930     if (!Assumes.empty()) {
931       Metadata *TypeId =
932           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
933       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
934       if (SeenPtrs.insert(Ptr).second) {
935         for (DevirtCallSite Call : DevirtCalls) {
936           CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
937                                                        Call.CS, nullptr);
938         }
939       }
940     }
941 
942     // We no longer need the assumes or the type test.
943     for (auto Assume : Assumes)
944       Assume->eraseFromParent();
945     // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
946     // may use the vtable argument later.
947     if (CI->use_empty())
948       CI->eraseFromParent();
949   }
950 }
951 
// Rewrites every call to llvm.type.checked.load into an explicit GEP + load of
// the function pointer (byte offset applied to an i8 pointer) plus a separate
// llvm.type.test call. The call's extractvalue users are redirected to those
// two values, and the devirtualizable call sites reached through the loaded
// pointer are recorded in CallSlots. Each new type test gets an "unsafe use"
// counter in NumUnsafeUsesForTypeTest; run() replaces type tests whose counter
// is zero with 'true'.
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (auto I = TypeCheckedLoadFunc->use_begin(),
            E = TypeCheckedLoadFunc->use_end();
       I != E;) {
    // Advance before CI is erased at the end of this iteration, which would
    // otherwise invalidate I.
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    // Operands: vtable pointer, offset (used below as an i8 GEP index), and
    // the type identifier as metadata.
    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    // LoadedPtrs / Preds receive the users of CI that extract the loaded
    // pointer and the type-check result respectively (they are later RAUW'd
    // with LoadedValue / TypeTestCall); HasNonCallUses is set by the helper.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI);

    // Start by generating "pessimistic" code that explicitly loads the function
    // pointer from the vtable and performs the type check. If possible, we will
    // eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    // NOTE(review): the insertvalues are placed before CI, so this is only
    // valid if LoadedValue/TypeTestCall were also emitted at CI. That presumably
    // holds because findDevirtualizableCallsForTypeCheckedLoad sets
    // HasNonCallUses for such extra uses; confirm in that helper.
    if (!CI->use_empty()) {
      Value *Pair = UndefValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    // A pointer to this counter is handed to addCallSite below, so it must
    // stay valid after later insertions into NumUnsafeUsesForTypeTest.
    // NOTE(review): presumably that map is node-based (e.g. std::map); confirm
    // against its declaration.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the type
    // check, as one of those users may eventually call the pointer. Increment
    // the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}
1030 
1031 bool DevirtModule::run() {
1032   Function *TypeTestFunc =
1033       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
1034   Function *TypeCheckedLoadFunc =
1035       M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
1036   Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
1037 
1038   if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
1039        AssumeFunc->use_empty()) &&
1040       (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
1041     return false;
1042 
1043   if (TypeTestFunc && AssumeFunc)
1044     scanTypeTestUsers(TypeTestFunc, AssumeFunc);
1045 
1046   if (TypeCheckedLoadFunc)
1047     scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);
1048 
1049   // Rebuild type metadata into a map for easy lookup.
1050   std::vector<VTableBits> Bits;
1051   DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
1052   buildTypeIdentifierMap(Bits, TypeIdMap);
1053   if (TypeIdMap.empty())
1054     return true;
1055 
1056   // For each (type, offset) pair:
1057   bool DidVirtualConstProp = false;
1058   std::map<std::string, Function*> DevirtTargets;
1059   for (auto &S : CallSlots) {
1060     // Search each of the members of the type identifier for the virtual
1061     // function implementation at offset S.first.ByteOffset, and add to
1062     // TargetsForSlot.
1063     std::vector<VirtualCallTarget> TargetsForSlot;
1064     if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
1065                                    S.first.ByteOffset))
1066       continue;
1067 
1068     if (!trySingleImplDevirt(TargetsForSlot, S.second) &&
1069         tryVirtualConstProp(TargetsForSlot, S.second))
1070         DidVirtualConstProp = true;
1071 
1072     // Collect functions devirtualized at least for one call site for stats.
1073     if (RemarksEnabled)
1074       for (const auto &T : TargetsForSlot)
1075         if (T.WasDevirt)
1076           DevirtTargets[T.Fn->getName()] = T.Fn;
1077   }
1078 
1079   if (RemarksEnabled) {
1080     // Generate remarks for each devirtualized function.
1081     for (const auto &DT : DevirtTargets) {
1082       Function *F = DT.second;
1083       DISubprogram *SP = F->getSubprogram();
1084       emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,
1085                              Twine("devirtualized ") + F->getName());
1086     }
1087   }
1088 
1089   // If we were able to eliminate all unsafe uses for a type checked load,
1090   // eliminate the type test by replacing it with true.
1091   if (TypeCheckedLoadFunc) {
1092     auto True = ConstantInt::getTrue(M.getContext());
1093     for (auto &&U : NumUnsafeUsesForTypeTest) {
1094       if (U.second == 0) {
1095         U.first->replaceAllUsesWith(True);
1096         U.first->eraseFromParent();
1097       }
1098     }
1099   }
1100 
1101   // Rebuild each global we touched as part of virtual constant propagation to
1102   // include the before and after bytes.
1103   if (DidVirtualConstProp)
1104     for (VTableBits &B : Bits)
1105       rebuildGlobal(B);
1106 
1107   return true;
1108 }
1109