1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// 2df49d1bbSPeter Collingbourne // 3df49d1bbSPeter Collingbourne // The LLVM Compiler Infrastructure 4df49d1bbSPeter Collingbourne // 5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source 6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details. 7df49d1bbSPeter Collingbourne // 8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 9df49d1bbSPeter Collingbourne // 10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases 117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This 12df49d1bbSPeter Collingbourne // includes the following: 13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single 14df49d1bbSPeter Collingbourne // possible callee, replace all calls with a direct call to that callee. 15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an 16df49d1bbSPeter Collingbourne // integer <=64 bits and all possible callees are readnone, for each class and 17df49d1bbSPeter Collingbourne // each list of constant arguments: evaluate the function, store the return 18df49d1bbSPeter Collingbourne // value alongside the virtual table, and rewrite each virtual call as a load 19df49d1bbSPeter Collingbourne // from the virtual table. 20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant 21df49d1bbSPeter Collingbourne // propagation hold and each function returns the same constant value, replace 22df49d1bbSPeter Collingbourne // each virtual call with that constant. 23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions 24df49d1bbSPeter Collingbourne // for virtual constant propagation hold and a single vtable's function 25df49d1bbSPeter Collingbourne // returns 0, or a single vtable's function returns 1, replace each virtual 26df49d1bbSPeter Collingbourne // call with a comparison of the vptr against that vtable's address. 27df49d1bbSPeter Collingbourne // 28df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 29df49d1bbSPeter Collingbourne 30df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h" 31b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h" 32cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h" 33cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h" 34df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h" 35cdc71612SEugene Zelenko #include "llvm/ADT/iterator_range.h" 36df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h" 37cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h" 3837317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h" 3937317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h" 407efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h" 41df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h" 42df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h" 43df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h" 44b05e06e4SIvan Krasin #include "llvm/IR/DebugInfoMetadata.h" 45cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h" 46cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h" 475474645dSIvan Krasin #include "llvm/IR/DiagnosticInfo.h" 48cdc71612SEugene Zelenko #include "llvm/IR/Function.h" 49cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h" 50cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h" 51df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h" 52cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h" 53cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h" 54df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h" 55df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h" 56cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h" 57cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h" 58df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h" 592b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h" 60df49d1bbSPeter Collingbourne #include "llvm/Pass.h" 61cdc71612SEugene Zelenko #include "llvm/PassRegistry.h" 62cdc71612SEugene Zelenko #include "llvm/PassSupport.h" 63cdc71612SEugene Zelenko #include "llvm/Support/Casting.h" 642b33f653SPeter Collingbourne #include "llvm/Support/Error.h" 652b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h" 66cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h" 67b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h" 6837317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h" 69df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h" 70cdc71612SEugene Zelenko #include <algorithm> 71cdc71612SEugene Zelenko #include <cstddef> 72cdc71612SEugene Zelenko #include <map> 73df49d1bbSPeter Collingbourne #include <set> 74cdc71612SEugene Zelenko #include <string> 75df49d1bbSPeter Collingbourne 76df49d1bbSPeter Collingbourne using namespace llvm; 77df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt; 78df49d1bbSPeter Collingbourne 79df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt" 80df49d1bbSPeter Collingbourne 812b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction( 822b33f653SPeter Collingbourne "wholeprogramdevirt-summary-action", 832b33f653SPeter Collingbourne cl::desc("What to do with the summary when running this pass"), 842b33f653SPeter Collingbourne cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), 852b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Import, "import", 862b33f653SPeter Collingbourne "Import typeid resolutions from summary and globals"), 872b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Export, "export", 882b33f653SPeter Collingbourne "Export typeid resolutions to summary and globals")), 892b33f653SPeter Collingbourne cl::Hidden); 902b33f653SPeter Collingbourne 912b33f653SPeter Collingbourne static cl::opt<std::string> ClReadSummary( 922b33f653SPeter Collingbourne "wholeprogramdevirt-read-summary", 932b33f653SPeter Collingbourne cl::desc("Read summary from given YAML file before running pass"), 942b33f653SPeter Collingbourne cl::Hidden); 952b33f653SPeter Collingbourne 962b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary( 972b33f653SPeter Collingbourne "wholeprogramdevirt-write-summary", 982b33f653SPeter Collingbourne cl::desc("Write summary to given YAML file after running pass"), 992b33f653SPeter Collingbourne cl::Hidden); 1002b33f653SPeter Collingbourne 101df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If 102df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an 103df49d1bbSPeter Collingbourne // offset after the object. 104df49d1bbSPeter Collingbourne uint64_t 105df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, 106df49d1bbSPeter Collingbourne bool IsAfter, uint64_t Size) { 107df49d1bbSPeter Collingbourne // Find a minimum offset taking into account only vtable sizes. 108df49d1bbSPeter Collingbourne uint64_t MinByte = 0; 109df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 110df49d1bbSPeter Collingbourne if (IsAfter) 111df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minAfterBytes()); 112df49d1bbSPeter Collingbourne else 113df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minBeforeBytes()); 114df49d1bbSPeter Collingbourne } 115df49d1bbSPeter Collingbourne 116df49d1bbSPeter Collingbourne // Build a vector of arrays of bytes covering, for each target, a slice of the 117df49d1bbSPeter Collingbourne // used region (see AccumBitVector::BytesUsed in 118df49d1bbSPeter Collingbourne // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively, 119df49d1bbSPeter Collingbourne // this aligns the used regions to start at MinByte. 120df49d1bbSPeter Collingbourne // 121df49d1bbSPeter Collingbourne // In this example, A, B and C are vtables, # is a byte already allocated for 122df49d1bbSPeter Collingbourne // a virtual function pointer, AAAA... (etc.) are the used regions for the 123df49d1bbSPeter Collingbourne // vtables and Offset(X) is the value computed for the Offset variable below 124df49d1bbSPeter Collingbourne // for X. 125df49d1bbSPeter Collingbourne // 126df49d1bbSPeter Collingbourne // Offset(A) 127df49d1bbSPeter Collingbourne // | | 128df49d1bbSPeter Collingbourne // |MinByte 129df49d1bbSPeter Collingbourne // A: ################AAAAAAAA|AAAAAAAA 130df49d1bbSPeter Collingbourne // B: ########BBBBBBBBBBBBBBBB|BBBB 131df49d1bbSPeter Collingbourne // C: ########################|CCCCCCCCCCCCCCCC 132df49d1bbSPeter Collingbourne // | Offset(B) | 133df49d1bbSPeter Collingbourne // 134df49d1bbSPeter Collingbourne // This code produces the slices of A, B and C that appear after the divider 135df49d1bbSPeter Collingbourne // at MinByte. 136df49d1bbSPeter Collingbourne std::vector<ArrayRef<uint8_t>> Used; 137df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 1387efd7506SPeter Collingbourne ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed 1397efd7506SPeter Collingbourne : Target.TM->Bits->Before.BytesUsed; 140df49d1bbSPeter Collingbourne uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() 141df49d1bbSPeter Collingbourne : MinByte - Target.minBeforeBytes(); 142df49d1bbSPeter Collingbourne 143df49d1bbSPeter Collingbourne // Disregard used regions that are smaller than Offset. These are 144df49d1bbSPeter Collingbourne // effectively all-free regions that do not need to be checked. 145df49d1bbSPeter Collingbourne if (VTUsed.size() > Offset) 146df49d1bbSPeter Collingbourne Used.push_back(VTUsed.slice(Offset)); 147df49d1bbSPeter Collingbourne } 148df49d1bbSPeter Collingbourne 149df49d1bbSPeter Collingbourne if (Size == 1) { 150df49d1bbSPeter Collingbourne // Find a free bit in each member of Used. 151df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 152df49d1bbSPeter Collingbourne uint8_t BitsUsed = 0; 153df49d1bbSPeter Collingbourne for (auto &&B : Used) 154df49d1bbSPeter Collingbourne if (I < B.size()) 155df49d1bbSPeter Collingbourne BitsUsed |= B[I]; 156df49d1bbSPeter Collingbourne if (BitsUsed != 0xff) 157df49d1bbSPeter Collingbourne return (MinByte + I) * 8 + 158df49d1bbSPeter Collingbourne countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined); 159df49d1bbSPeter Collingbourne } 160df49d1bbSPeter Collingbourne } else { 161df49d1bbSPeter Collingbourne // Find a free (Size/8) byte region in each member of Used. 162df49d1bbSPeter Collingbourne // FIXME: see if alignment helps. 163df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 164df49d1bbSPeter Collingbourne for (auto &&B : Used) { 165df49d1bbSPeter Collingbourne unsigned Byte = 0; 166df49d1bbSPeter Collingbourne while ((I + Byte) < B.size() && Byte < (Size / 8)) { 167df49d1bbSPeter Collingbourne if (B[I + Byte]) 168df49d1bbSPeter Collingbourne goto NextI; 169df49d1bbSPeter Collingbourne ++Byte; 170df49d1bbSPeter Collingbourne } 171df49d1bbSPeter Collingbourne } 172df49d1bbSPeter Collingbourne return (MinByte + I) * 8; 173df49d1bbSPeter Collingbourne NextI:; 174df49d1bbSPeter Collingbourne } 175df49d1bbSPeter Collingbourne } 176df49d1bbSPeter Collingbourne } 177df49d1bbSPeter Collingbourne 178df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues( 179df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore, 180df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 181df49d1bbSPeter Collingbourne if (BitWidth == 1) 182df49d1bbSPeter Collingbourne OffsetByte = -(AllocBefore / 8 + 1); 183df49d1bbSPeter Collingbourne else 184df49d1bbSPeter Collingbourne OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8); 185df49d1bbSPeter Collingbourne OffsetBit = AllocBefore % 8; 186df49d1bbSPeter Collingbourne 187df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 188df49d1bbSPeter Collingbourne if (BitWidth == 1) 189df49d1bbSPeter Collingbourne Target.setBeforeBit(AllocBefore); 190df49d1bbSPeter Collingbourne else 191df49d1bbSPeter Collingbourne Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8); 192df49d1bbSPeter Collingbourne } 193df49d1bbSPeter Collingbourne } 194df49d1bbSPeter Collingbourne 195df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues( 196df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter, 197df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 198df49d1bbSPeter Collingbourne if (BitWidth == 1) 199df49d1bbSPeter Collingbourne OffsetByte = AllocAfter / 8; 200df49d1bbSPeter Collingbourne else 201df49d1bbSPeter Collingbourne OffsetByte = (AllocAfter + 7) / 8; 202df49d1bbSPeter Collingbourne OffsetBit = AllocAfter % 8; 203df49d1bbSPeter Collingbourne 204df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 205df49d1bbSPeter Collingbourne if (BitWidth == 1) 206df49d1bbSPeter Collingbourne Target.setAfterBit(AllocAfter); 207df49d1bbSPeter Collingbourne else 208df49d1bbSPeter Collingbourne Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8); 209df49d1bbSPeter Collingbourne } 210df49d1bbSPeter Collingbourne } 211df49d1bbSPeter Collingbourne 2127efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) 2137efd7506SPeter Collingbourne : Fn(Fn), TM(TM), 21489439a79SIvan Krasin IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} 215df49d1bbSPeter Collingbourne 216df49d1bbSPeter Collingbourne namespace { 217df49d1bbSPeter Collingbourne 2187efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual 219df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to 220df49d1bbSPeter Collingbourne // the virtual function pointer. 221df49d1bbSPeter Collingbourne struct VTableSlot { 2227efd7506SPeter Collingbourne Metadata *TypeID; 223df49d1bbSPeter Collingbourne uint64_t ByteOffset; 224df49d1bbSPeter Collingbourne }; 225df49d1bbSPeter Collingbourne 226cdc71612SEugene Zelenko } // end anonymous namespace 227df49d1bbSPeter Collingbourne 2289b656527SPeter Collingbourne namespace llvm { 2299b656527SPeter Collingbourne 230df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> { 231df49d1bbSPeter Collingbourne static VTableSlot getEmptyKey() { 232df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getEmptyKey(), 233df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getEmptyKey()}; 234df49d1bbSPeter Collingbourne } 235df49d1bbSPeter Collingbourne static VTableSlot getTombstoneKey() { 236df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getTombstoneKey(), 237df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getTombstoneKey()}; 238df49d1bbSPeter Collingbourne } 239df49d1bbSPeter Collingbourne static unsigned getHashValue(const VTableSlot &I) { 2407efd7506SPeter Collingbourne return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ 241df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); 242df49d1bbSPeter Collingbourne } 243df49d1bbSPeter Collingbourne static bool isEqual(const VTableSlot &LHS, 244df49d1bbSPeter Collingbourne const VTableSlot &RHS) { 2457efd7506SPeter Collingbourne return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; 246df49d1bbSPeter Collingbourne } 247df49d1bbSPeter Collingbourne }; 248df49d1bbSPeter Collingbourne 249cdc71612SEugene Zelenko } // end namespace llvm 2509b656527SPeter Collingbourne 251df49d1bbSPeter Collingbourne namespace { 252df49d1bbSPeter Collingbourne 253df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is 254df49d1bbSPeter Collingbourne // the indirect virtual call. 255df49d1bbSPeter Collingbourne struct VirtualCallSite { 256df49d1bbSPeter Collingbourne Value *VTable; 257df49d1bbSPeter Collingbourne CallSite CS; 258df49d1bbSPeter Collingbourne 2590312f614SPeter Collingbourne // If non-null, this field points to the associated unsafe use count stored in 2600312f614SPeter Collingbourne // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description 2610312f614SPeter Collingbourne // of that field for details. 2620312f614SPeter Collingbourne unsigned *NumUnsafeUses; 2630312f614SPeter Collingbourne 264f3403fd2SIvan Krasin void emitRemark(const Twine &OptName, const Twine &TargetName) { 2655474645dSIvan Krasin Function *F = CS.getCaller(); 266f3403fd2SIvan Krasin emitOptimizationRemark( 267f3403fd2SIvan Krasin F->getContext(), DEBUG_TYPE, *F, 2685474645dSIvan Krasin CS.getInstruction()->getDebugLoc(), 269f3403fd2SIvan Krasin OptName + ": devirtualized a call to " + TargetName); 2705474645dSIvan Krasin } 2715474645dSIvan Krasin 272f3403fd2SIvan Krasin void replaceAndErase(const Twine &OptName, const Twine &TargetName, 273f3403fd2SIvan Krasin bool RemarksEnabled, Value *New) { 274f3403fd2SIvan Krasin if (RemarksEnabled) 275f3403fd2SIvan Krasin emitRemark(OptName, TargetName); 276df49d1bbSPeter Collingbourne CS->replaceAllUsesWith(New); 277df49d1bbSPeter Collingbourne if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { 278df49d1bbSPeter Collingbourne BranchInst::Create(II->getNormalDest(), CS.getInstruction()); 279df49d1bbSPeter Collingbourne II->getUnwindDest()->removePredecessor(II->getParent()); 280df49d1bbSPeter Collingbourne } 281df49d1bbSPeter Collingbourne CS->eraseFromParent(); 2820312f614SPeter Collingbourne // This use is no longer unsafe. 2830312f614SPeter Collingbourne if (NumUnsafeUses) 2840312f614SPeter Collingbourne --*NumUnsafeUses; 285df49d1bbSPeter Collingbourne } 286df49d1bbSPeter Collingbourne }; 287df49d1bbSPeter Collingbourne 28850cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list 28950cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the 29050cbd7ccSPeter Collingbourne // VTableSlotInfo class. 29150cbd7ccSPeter Collingbourne struct CallSiteInfo { 29250cbd7ccSPeter Collingbourne std::vector<VirtualCallSite> CallSites; 29350cbd7ccSPeter Collingbourne }; 29450cbd7ccSPeter Collingbourne 29550cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot. 29650cbd7ccSPeter Collingbourne struct VTableSlotInfo { 29750cbd7ccSPeter Collingbourne // The set of call sites which do not have all constant integer arguments 29850cbd7ccSPeter Collingbourne // (excluding "this"). 29950cbd7ccSPeter Collingbourne CallSiteInfo CSInfo; 30050cbd7ccSPeter Collingbourne 30150cbd7ccSPeter Collingbourne // The set of call sites with all constant integer arguments (excluding 30250cbd7ccSPeter Collingbourne // "this"), grouped by argument list. 30350cbd7ccSPeter Collingbourne std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; 30450cbd7ccSPeter Collingbourne 30550cbd7ccSPeter Collingbourne void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); 30650cbd7ccSPeter Collingbourne 30750cbd7ccSPeter Collingbourne private: 30850cbd7ccSPeter Collingbourne CallSiteInfo &findCallSiteInfo(CallSite CS); 30950cbd7ccSPeter Collingbourne }; 31050cbd7ccSPeter Collingbourne 31150cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { 31250cbd7ccSPeter Collingbourne std::vector<uint64_t> Args; 31350cbd7ccSPeter Collingbourne auto *CI = dyn_cast<IntegerType>(CS.getType()); 31450cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) 31550cbd7ccSPeter Collingbourne return CSInfo; 31650cbd7ccSPeter Collingbourne for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { 31750cbd7ccSPeter Collingbourne auto *CI = dyn_cast<ConstantInt>(Arg); 31850cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64) 31950cbd7ccSPeter Collingbourne return CSInfo; 32050cbd7ccSPeter Collingbourne Args.push_back(CI->getZExtValue()); 32150cbd7ccSPeter Collingbourne } 32250cbd7ccSPeter Collingbourne return ConstCSInfo[Args]; 32350cbd7ccSPeter Collingbourne } 32450cbd7ccSPeter Collingbourne 32550cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, 32650cbd7ccSPeter Collingbourne unsigned *NumUnsafeUses) { 32750cbd7ccSPeter Collingbourne findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); 32850cbd7ccSPeter Collingbourne } 32950cbd7ccSPeter Collingbourne 330df49d1bbSPeter Collingbourne struct DevirtModule { 331df49d1bbSPeter Collingbourne Module &M; 33237317f12SPeter Collingbourne function_ref<AAResults &(Function &)> AARGetter; 3332b33f653SPeter Collingbourne 3342b33f653SPeter Collingbourne PassSummaryAction Action; 3352b33f653SPeter Collingbourne ModuleSummaryIndex *Summary; 3362b33f653SPeter Collingbourne 337df49d1bbSPeter Collingbourne IntegerType *Int8Ty; 338df49d1bbSPeter Collingbourne PointerType *Int8PtrTy; 339df49d1bbSPeter Collingbourne IntegerType *Int32Ty; 34050cbd7ccSPeter Collingbourne IntegerType *Int64Ty; 341df49d1bbSPeter Collingbourne 342f3403fd2SIvan Krasin bool RemarksEnabled; 343f3403fd2SIvan Krasin 34450cbd7ccSPeter Collingbourne MapVector<VTableSlot, VTableSlotInfo> CallSlots; 345df49d1bbSPeter Collingbourne 3460312f614SPeter Collingbourne // This map keeps track of the number of "unsafe" uses of a loaded function 3470312f614SPeter Collingbourne // pointer. The key is the associated llvm.type.test intrinsic call generated 3480312f614SPeter Collingbourne // by this pass. An unsafe use is one that calls the loaded function pointer 3490312f614SPeter Collingbourne // directly. Every time we eliminate an unsafe use (for example, by 3500312f614SPeter Collingbourne // devirtualizing it or by applying virtual constant propagation), we 3510312f614SPeter Collingbourne // decrement the value stored in this map. If a value reaches zero, we can 3520312f614SPeter Collingbourne // eliminate the type check by RAUWing the associated llvm.type.test call with 3530312f614SPeter Collingbourne // true. 3540312f614SPeter Collingbourne std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; 3550312f614SPeter Collingbourne 35637317f12SPeter Collingbourne DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, 35737317f12SPeter Collingbourne PassSummaryAction Action, ModuleSummaryIndex *Summary) 35837317f12SPeter Collingbourne : M(M), AARGetter(AARGetter), Action(Action), Summary(Summary), 3592b33f653SPeter Collingbourne Int8Ty(Type::getInt8Ty(M.getContext())), 360df49d1bbSPeter Collingbourne Int8PtrTy(Type::getInt8PtrTy(M.getContext())), 361f3403fd2SIvan Krasin Int32Ty(Type::getInt32Ty(M.getContext())), 36250cbd7ccSPeter Collingbourne Int64Ty(Type::getInt64Ty(M.getContext())), 363f3403fd2SIvan Krasin RemarksEnabled(areRemarksEnabled()) {} 364f3403fd2SIvan Krasin 365f3403fd2SIvan Krasin bool areRemarksEnabled(); 366df49d1bbSPeter Collingbourne 3670312f614SPeter Collingbourne void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); 3680312f614SPeter Collingbourne void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); 3690312f614SPeter Collingbourne 3707efd7506SPeter Collingbourne void buildTypeIdentifierMap( 3717efd7506SPeter Collingbourne std::vector<VTableBits> &Bits, 3727efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); 3738786754cSPeter Collingbourne Constant *getPointerAtOffset(Constant *I, uint64_t Offset); 3747efd7506SPeter Collingbourne bool 3757efd7506SPeter Collingbourne tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, 3767efd7506SPeter Collingbourne const std::set<TypeMemberInfo> &TypeMemberInfos, 377df49d1bbSPeter Collingbourne uint64_t ByteOffset); 37850cbd7ccSPeter Collingbourne 37950cbd7ccSPeter Collingbourne void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn); 380f3403fd2SIvan Krasin bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 38150cbd7ccSPeter Collingbourne VTableSlotInfo &SlotInfo); 38250cbd7ccSPeter Collingbourne 383df49d1bbSPeter Collingbourne bool tryEvaluateFunctionsWithArgs( 384df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 38550cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args); 38650cbd7ccSPeter Collingbourne 38750cbd7ccSPeter Collingbourne void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 38850cbd7ccSPeter Collingbourne uint64_t TheRetVal); 38950cbd7ccSPeter Collingbourne bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 39050cbd7ccSPeter Collingbourne CallSiteInfo &CSInfo); 39150cbd7ccSPeter Collingbourne 39250cbd7ccSPeter Collingbourne void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, 39350cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr); 394df49d1bbSPeter Collingbourne bool tryUniqueRetValOpt(unsigned BitWidth, 395f3403fd2SIvan Krasin MutableArrayRef<VirtualCallTarget> TargetsForSlot, 39650cbd7ccSPeter Collingbourne CallSiteInfo &CSInfo); 39750cbd7ccSPeter Collingbourne 39850cbd7ccSPeter Collingbourne void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 39950cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit); 400df49d1bbSPeter Collingbourne bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 40150cbd7ccSPeter Collingbourne VTableSlotInfo &SlotInfo); 402df49d1bbSPeter Collingbourne 403df49d1bbSPeter Collingbourne void rebuildGlobal(VTableBits &B); 404df49d1bbSPeter Collingbourne 405df49d1bbSPeter Collingbourne bool run(); 4062b33f653SPeter Collingbourne 4072b33f653SPeter Collingbourne // Lower the module using the action and summary passed as command line 4082b33f653SPeter Collingbourne // arguments. For testing purposes only. 40937317f12SPeter Collingbourne static bool runForTesting(Module &M, 41037317f12SPeter Collingbourne function_ref<AAResults &(Function &)> AARGetter); 411df49d1bbSPeter Collingbourne }; 412df49d1bbSPeter Collingbourne 413df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass { 414df49d1bbSPeter Collingbourne static char ID; 415cdc71612SEugene Zelenko 4162b33f653SPeter Collingbourne bool UseCommandLine = false; 4172b33f653SPeter Collingbourne 4182b33f653SPeter Collingbourne PassSummaryAction Action; 4192b33f653SPeter Collingbourne ModuleSummaryIndex *Summary; 4202b33f653SPeter Collingbourne 4212b33f653SPeter Collingbourne WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { 4222b33f653SPeter Collingbourne initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 4232b33f653SPeter Collingbourne } 4242b33f653SPeter Collingbourne 4252b33f653SPeter Collingbourne WholeProgramDevirt(PassSummaryAction Action, ModuleSummaryIndex *Summary) 4262b33f653SPeter Collingbourne : ModulePass(ID), Action(Action), Summary(Summary) { 427df49d1bbSPeter Collingbourne initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 428df49d1bbSPeter Collingbourne } 429cdc71612SEugene Zelenko 430cdc71612SEugene Zelenko bool runOnModule(Module &M) override { 431aa641a51SAndrew Kaylor if (skipModule(M)) 432aa641a51SAndrew Kaylor return false; 4332b33f653SPeter Collingbourne if (UseCommandLine) 43437317f12SPeter Collingbourne return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); 43537317f12SPeter Collingbourne return DevirtModule(M, LegacyAARGetter(*this), Action, Summary).run(); 43637317f12SPeter Collingbourne } 43737317f12SPeter Collingbourne 43837317f12SPeter Collingbourne void getAnalysisUsage(AnalysisUsage &AU) const override { 43937317f12SPeter Collingbourne AU.addRequired<AssumptionCacheTracker>(); 44037317f12SPeter Collingbourne AU.addRequired<TargetLibraryInfoWrapperPass>(); 441aa641a51SAndrew Kaylor } 442df49d1bbSPeter Collingbourne }; 443df49d1bbSPeter Collingbourne 444cdc71612SEugene Zelenko } // end anonymous namespace 445df49d1bbSPeter Collingbourne 44637317f12SPeter Collingbourne INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", 44737317f12SPeter Collingbourne "Whole program devirtualization", false, false) 44837317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 44937317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 45037317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", 451df49d1bbSPeter Collingbourne "Whole program devirtualization", false, false) 452df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0; 453df49d1bbSPeter Collingbourne 4542b33f653SPeter Collingbourne ModulePass *llvm::createWholeProgramDevirtPass(PassSummaryAction Action, 4552b33f653SPeter Collingbourne ModuleSummaryIndex *Summary) { 4562b33f653SPeter Collingbourne return new WholeProgramDevirt(Action, Summary); 457df49d1bbSPeter Collingbourne } 458df49d1bbSPeter Collingbourne 459164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M, 46037317f12SPeter Collingbourne ModuleAnalysisManager &AM) { 46137317f12SPeter Collingbourne auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 46237317f12SPeter Collingbourne auto AARGetter = [&](Function &F) -> AAResults & { 46337317f12SPeter Collingbourne return FAM.getResult<AAManager>(F); 46437317f12SPeter Collingbourne }; 46537317f12SPeter Collingbourne if (!DevirtModule(M, AARGetter, PassSummaryAction::None, nullptr).run()) 466d737dd2eSDavide Italiano return PreservedAnalyses::all(); 467d737dd2eSDavide Italiano return PreservedAnalyses::none(); 468d737dd2eSDavide Italiano } 469d737dd2eSDavide Italiano 47037317f12SPeter Collingbourne bool DevirtModule::runForTesting( 47137317f12SPeter Collingbourne Module &M, function_ref<AAResults &(Function &)> AARGetter) { 4722b33f653SPeter Collingbourne ModuleSummaryIndex Summary; 4732b33f653SPeter Collingbourne 4742b33f653SPeter Collingbourne // Handle the command-line summary arguments. This code is for testing 4752b33f653SPeter Collingbourne // purposes only, so we handle errors directly. 4762b33f653SPeter Collingbourne if (!ClReadSummary.empty()) { 4772b33f653SPeter Collingbourne ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + 4782b33f653SPeter Collingbourne ": "); 4792b33f653SPeter Collingbourne auto ReadSummaryFile = 4802b33f653SPeter Collingbourne ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); 4812b33f653SPeter Collingbourne 4822b33f653SPeter Collingbourne yaml::Input In(ReadSummaryFile->getBuffer()); 4832b33f653SPeter Collingbourne In >> Summary; 4842b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(In.error())); 4852b33f653SPeter Collingbourne } 4862b33f653SPeter Collingbourne 48737317f12SPeter Collingbourne bool Changed = DevirtModule(M, AARGetter, ClSummaryAction, &Summary).run(); 4882b33f653SPeter Collingbourne 4892b33f653SPeter Collingbourne if (!ClWriteSummary.empty()) { 4902b33f653SPeter Collingbourne ExitOnError ExitOnErr( 4912b33f653SPeter Collingbourne "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); 4922b33f653SPeter Collingbourne std::error_code EC; 4932b33f653SPeter Collingbourne raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); 4942b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(EC)); 4952b33f653SPeter Collingbourne 4962b33f653SPeter Collingbourne yaml::Output Out(OS); 4972b33f653SPeter Collingbourne Out << Summary; 4982b33f653SPeter Collingbourne } 4992b33f653SPeter Collingbourne 5002b33f653SPeter Collingbourne return Changed; 5012b33f653SPeter Collingbourne } 5022b33f653SPeter Collingbourne 5037efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap( 504df49d1bbSPeter Collingbourne std::vector<VTableBits> &Bits, 5057efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { 506df49d1bbSPeter Collingbourne DenseMap<GlobalVariable *, VTableBits *> GVToBits; 5077efd7506SPeter Collingbourne Bits.reserve(M.getGlobalList().size()); 5087efd7506SPeter Collingbourne SmallVector<MDNode *, 2> Types; 5097efd7506SPeter Collingbourne for (GlobalVariable &GV : M.globals()) { 5107efd7506SPeter Collingbourne Types.clear(); 5117efd7506SPeter Collingbourne GV.getMetadata(LLVMContext::MD_type, Types); 5127efd7506SPeter Collingbourne if (Types.empty()) 513df49d1bbSPeter Collingbourne continue; 514df49d1bbSPeter Collingbourne 5157efd7506SPeter Collingbourne VTableBits *&BitsPtr = GVToBits[&GV]; 5167efd7506SPeter Collingbourne if (!BitsPtr) { 5177efd7506SPeter Collingbourne Bits.emplace_back(); 5187efd7506SPeter Collingbourne Bits.back().GV = &GV; 5197efd7506SPeter Collingbourne Bits.back().ObjectSize = 5207efd7506SPeter Collingbourne M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); 5217efd7506SPeter Collingbourne BitsPtr = &Bits.back(); 5227efd7506SPeter Collingbourne } 5237efd7506SPeter Collingbourne 5247efd7506SPeter Collingbourne for (MDNode *Type : Types) { 5257efd7506SPeter Collingbourne auto TypeID = Type->getOperand(1).get(); 526df49d1bbSPeter Collingbourne 527df49d1bbSPeter Collingbourne uint64_t Offset = 528df49d1bbSPeter Collingbourne cast<ConstantInt>( 5297efd7506SPeter Collingbourne cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) 530df49d1bbSPeter Collingbourne ->getZExtValue(); 531df49d1bbSPeter Collingbourne 5327efd7506SPeter Collingbourne TypeIdMap[TypeID].insert({BitsPtr, Offset}); 533df49d1bbSPeter Collingbourne } 534df49d1bbSPeter Collingbourne } 535df49d1bbSPeter Collingbourne } 536df49d1bbSPeter Collingbourne 5378786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) { 5388786754cSPeter Collingbourne if (I->getType()->isPointerTy()) { 5398786754cSPeter Collingbourne if (Offset == 0) 5408786754cSPeter Collingbourne return I; 5418786754cSPeter Collingbourne return nullptr; 5428786754cSPeter Collingbourne } 5438786754cSPeter Collingbourne 5447a1e5bbeSPeter Collingbourne const DataLayout &DL = M.getDataLayout(); 5457a1e5bbeSPeter Collingbourne 5467a1e5bbeSPeter Collingbourne if (auto *C = dyn_cast<ConstantStruct>(I)) { 5477a1e5bbeSPeter Collingbourne const StructLayout *SL = DL.getStructLayout(C->getType()); 5487a1e5bbeSPeter Collingbourne if (Offset >= SL->getSizeInBytes()) 5497a1e5bbeSPeter Collingbourne return nullptr; 5507a1e5bbeSPeter Collingbourne 5518786754cSPeter Collingbourne unsigned Op = SL->getElementContainingOffset(Offset); 5528786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 5538786754cSPeter Collingbourne Offset - SL->getElementOffset(Op)); 5548786754cSPeter Collingbourne } 5558786754cSPeter Collingbourne if (auto *C = dyn_cast<ConstantArray>(I)) { 5567a1e5bbeSPeter Collingbourne ArrayType *VTableTy = C->getType(); 5577a1e5bbeSPeter Collingbourne uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); 5587a1e5bbeSPeter Collingbourne 5598786754cSPeter Collingbourne unsigned Op = Offset / ElemSize; 5607a1e5bbeSPeter Collingbourne if (Op >= C->getNumOperands()) 5617a1e5bbeSPeter Collingbourne return nullptr; 5627a1e5bbeSPeter Collingbourne 5638786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 5648786754cSPeter Collingbourne Offset % ElemSize); 5658786754cSPeter Collingbourne } 5668786754cSPeter Collingbourne return nullptr; 5677a1e5bbeSPeter Collingbourne } 5687a1e5bbeSPeter Collingbourne 569df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets( 570df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> &TargetsForSlot, 5717efd7506SPeter Collingbourne const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { 5727efd7506SPeter Collingbourne for (const TypeMemberInfo &TM : TypeMemberInfos) { 5737efd7506SPeter Collingbourne if (!TM.Bits->GV->isConstant()) 574df49d1bbSPeter Collingbourne return false; 575df49d1bbSPeter Collingbourne 5768786754cSPeter Collingbourne Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), 5778786754cSPeter Collingbourne TM.Offset + ByteOffset); 5788786754cSPeter Collingbourne if (!Ptr) 579df49d1bbSPeter Collingbourne return false; 580df49d1bbSPeter Collingbourne 5818786754cSPeter Collingbourne auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); 582df49d1bbSPeter Collingbourne if (!Fn) 583df49d1bbSPeter Collingbourne return false; 584df49d1bbSPeter Collingbourne 585df49d1bbSPeter Collingbourne // We can disregard __cxa_pure_virtual as a possible call target, as 586df49d1bbSPeter Collingbourne // calls to pure virtuals are UB. 587df49d1bbSPeter Collingbourne if (Fn->getName() == "__cxa_pure_virtual") 588df49d1bbSPeter Collingbourne continue; 589df49d1bbSPeter Collingbourne 5907efd7506SPeter Collingbourne TargetsForSlot.push_back({Fn, &TM}); 591df49d1bbSPeter Collingbourne } 592df49d1bbSPeter Collingbourne 593df49d1bbSPeter Collingbourne // Give up if we couldn't find any targets. 594df49d1bbSPeter Collingbourne return !TargetsForSlot.empty(); 595df49d1bbSPeter Collingbourne } 596df49d1bbSPeter Collingbourne 59750cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, 59850cbd7ccSPeter Collingbourne Constant *TheFn) { 59950cbd7ccSPeter Collingbourne auto Apply = [&](CallSiteInfo &CSInfo) { 60050cbd7ccSPeter Collingbourne for (auto &&VCallSite : CSInfo.CallSites) { 601f3403fd2SIvan Krasin if (RemarksEnabled) 602f3403fd2SIvan Krasin VCallSite.emitRemark("single-impl", TheFn->getName()); 603df49d1bbSPeter Collingbourne VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( 604df49d1bbSPeter Collingbourne TheFn, VCallSite.CS.getCalledValue()->getType())); 6050312f614SPeter Collingbourne // This use is no longer unsafe. 6060312f614SPeter Collingbourne if (VCallSite.NumUnsafeUses) 6070312f614SPeter Collingbourne --*VCallSite.NumUnsafeUses; 608df49d1bbSPeter Collingbourne } 60950cbd7ccSPeter Collingbourne }; 61050cbd7ccSPeter Collingbourne Apply(SlotInfo.CSInfo); 61150cbd7ccSPeter Collingbourne for (auto &P : SlotInfo.ConstCSInfo) 61250cbd7ccSPeter Collingbourne Apply(P.second); 61350cbd7ccSPeter Collingbourne } 61450cbd7ccSPeter Collingbourne 61550cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt( 61650cbd7ccSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 61750cbd7ccSPeter Collingbourne VTableSlotInfo &SlotInfo) { 61850cbd7ccSPeter Collingbourne // See if the program contains a single implementation of this virtual 61950cbd7ccSPeter Collingbourne // function. 62050cbd7ccSPeter Collingbourne Function *TheFn = TargetsForSlot[0].Fn; 62150cbd7ccSPeter Collingbourne for (auto &&Target : TargetsForSlot) 62250cbd7ccSPeter Collingbourne if (TheFn != Target.Fn) 62350cbd7ccSPeter Collingbourne return false; 62450cbd7ccSPeter Collingbourne 62550cbd7ccSPeter Collingbourne // If so, update each call site to call that implementation directly. 62650cbd7ccSPeter Collingbourne if (RemarksEnabled) 62750cbd7ccSPeter Collingbourne TargetsForSlot[0].WasDevirt = true; 62850cbd7ccSPeter Collingbourne applySingleImplDevirt(SlotInfo, TheFn); 629df49d1bbSPeter Collingbourne return true; 630df49d1bbSPeter Collingbourne } 631df49d1bbSPeter Collingbourne 632df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs( 633df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 63450cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args) { 635df49d1bbSPeter Collingbourne // Evaluate each function and store the result in each target's RetVal 636df49d1bbSPeter Collingbourne // field. 637df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 638df49d1bbSPeter Collingbourne if (Target.Fn->arg_size() != Args.size() + 1) 639df49d1bbSPeter Collingbourne return false; 640df49d1bbSPeter Collingbourne 641df49d1bbSPeter Collingbourne Evaluator Eval(M.getDataLayout(), nullptr); 642df49d1bbSPeter Collingbourne SmallVector<Constant *, 2> EvalArgs; 643df49d1bbSPeter Collingbourne EvalArgs.push_back( 644df49d1bbSPeter Collingbourne Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); 64550cbd7ccSPeter Collingbourne for (unsigned I = 0; I != Args.size(); ++I) { 64650cbd7ccSPeter Collingbourne auto *ArgTy = dyn_cast<IntegerType>( 64750cbd7ccSPeter Collingbourne Target.Fn->getFunctionType()->getParamType(I + 1)); 64850cbd7ccSPeter Collingbourne if (!ArgTy) 64950cbd7ccSPeter Collingbourne return false; 65050cbd7ccSPeter Collingbourne EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); 65150cbd7ccSPeter Collingbourne } 65250cbd7ccSPeter Collingbourne 653df49d1bbSPeter Collingbourne Constant *RetVal; 654df49d1bbSPeter Collingbourne if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || 655df49d1bbSPeter Collingbourne !isa<ConstantInt>(RetVal)) 656df49d1bbSPeter Collingbourne return false; 657df49d1bbSPeter Collingbourne Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); 658df49d1bbSPeter Collingbourne } 659df49d1bbSPeter Collingbourne return true; 660df49d1bbSPeter Collingbourne } 661df49d1bbSPeter Collingbourne 66250cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 66350cbd7ccSPeter Collingbourne uint64_t TheRetVal) { 66450cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) 66550cbd7ccSPeter Collingbourne Call.replaceAndErase( 66650cbd7ccSPeter Collingbourne "uniform-ret-val", FnName, RemarksEnabled, 66750cbd7ccSPeter Collingbourne ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); 66850cbd7ccSPeter Collingbourne } 66950cbd7ccSPeter Collingbourne 670df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt( 67150cbd7ccSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo) { 672df49d1bbSPeter Collingbourne // Uniform return value optimization. If all functions return the same 673df49d1bbSPeter Collingbourne // constant, replace all calls with that constant. 674df49d1bbSPeter Collingbourne uint64_t TheRetVal = TargetsForSlot[0].RetVal; 675df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) 676df49d1bbSPeter Collingbourne if (Target.RetVal != TheRetVal) 677df49d1bbSPeter Collingbourne return false; 678df49d1bbSPeter Collingbourne 67950cbd7ccSPeter Collingbourne applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); 680f3403fd2SIvan Krasin if (RemarksEnabled) 681f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 682f3403fd2SIvan Krasin Target.WasDevirt = true; 683df49d1bbSPeter Collingbourne return true; 684df49d1bbSPeter Collingbourne } 685df49d1bbSPeter Collingbourne 68650cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 68750cbd7ccSPeter Collingbourne bool IsOne, 68850cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr) { 68950cbd7ccSPeter Collingbourne for (auto &&Call : CSInfo.CallSites) { 69050cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 69150cbd7ccSPeter Collingbourne Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, 69250cbd7ccSPeter Collingbourne Call.VTable, UniqueMemberAddr); 69350cbd7ccSPeter Collingbourne Cmp = B.CreateZExt(Cmp, Call.CS->getType()); 69450cbd7ccSPeter Collingbourne Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); 69550cbd7ccSPeter Collingbourne } 69650cbd7ccSPeter Collingbourne } 69750cbd7ccSPeter Collingbourne 698df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt( 699f3403fd2SIvan Krasin unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, 70050cbd7ccSPeter Collingbourne CallSiteInfo &CSInfo) { 701df49d1bbSPeter Collingbourne // IsOne controls whether we look for a 0 or a 1. 702df49d1bbSPeter Collingbourne auto tryUniqueRetValOptFor = [&](bool IsOne) { 703cdc71612SEugene Zelenko const TypeMemberInfo *UniqueMember = nullptr; 704df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) { 7053866cc5fSPeter Collingbourne if (Target.RetVal == (IsOne ? 1 : 0)) { 7067efd7506SPeter Collingbourne if (UniqueMember) 707df49d1bbSPeter Collingbourne return false; 7087efd7506SPeter Collingbourne UniqueMember = Target.TM; 709df49d1bbSPeter Collingbourne } 710df49d1bbSPeter Collingbourne } 711df49d1bbSPeter Collingbourne 7127efd7506SPeter Collingbourne // We should have found a unique member or bailed out by now. We already 713df49d1bbSPeter Collingbourne // checked for a uniform return value in tryUniformRetValOpt. 7147efd7506SPeter Collingbourne assert(UniqueMember); 715df49d1bbSPeter Collingbourne 716df49d1bbSPeter Collingbourne // Replace each call with the comparison. 71750cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr = 71850cbd7ccSPeter Collingbourne ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); 71950cbd7ccSPeter Collingbourne UniqueMemberAddr = ConstantExpr::getGetElementPtr( 72050cbd7ccSPeter Collingbourne Int8Ty, UniqueMemberAddr, 72150cbd7ccSPeter Collingbourne ConstantInt::get(Int64Ty, UniqueMember->Offset)); 72250cbd7ccSPeter Collingbourne 72350cbd7ccSPeter Collingbourne applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, 72450cbd7ccSPeter Collingbourne UniqueMemberAddr); 72550cbd7ccSPeter Collingbourne 726f3403fd2SIvan Krasin // Update devirtualization statistics for targets. 727f3403fd2SIvan Krasin if (RemarksEnabled) 728f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 729f3403fd2SIvan Krasin Target.WasDevirt = true; 730f3403fd2SIvan Krasin 731df49d1bbSPeter Collingbourne return true; 732df49d1bbSPeter Collingbourne }; 733df49d1bbSPeter Collingbourne 734df49d1bbSPeter Collingbourne if (BitWidth == 1) { 735df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(true)) 736df49d1bbSPeter Collingbourne return true; 737df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(false)) 738df49d1bbSPeter Collingbourne return true; 739df49d1bbSPeter Collingbourne } 740df49d1bbSPeter Collingbourne return false; 741df49d1bbSPeter Collingbourne } 742df49d1bbSPeter Collingbourne 74350cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 74450cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit) { 74550cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) { 74650cbd7ccSPeter Collingbourne auto *RetType = cast<IntegerType>(Call.CS.getType()); 74750cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 74850cbd7ccSPeter Collingbourne Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); 74950cbd7ccSPeter Collingbourne if (RetType->getBitWidth() == 1) { 75050cbd7ccSPeter Collingbourne Value *Bits = B.CreateLoad(Addr); 75150cbd7ccSPeter Collingbourne Value *BitsAndBit = B.CreateAnd(Bits, Bit); 75250cbd7ccSPeter Collingbourne auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); 75350cbd7ccSPeter Collingbourne Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, 75450cbd7ccSPeter Collingbourne IsBitSet); 75550cbd7ccSPeter Collingbourne } else { 75650cbd7ccSPeter Collingbourne Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); 75750cbd7ccSPeter Collingbourne Value *Val = B.CreateLoad(RetType, ValAddr); 75850cbd7ccSPeter Collingbourne Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); 75950cbd7ccSPeter Collingbourne } 76050cbd7ccSPeter Collingbourne } 76150cbd7ccSPeter Collingbourne } 76250cbd7ccSPeter Collingbourne 763df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp( 764df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 76550cbd7ccSPeter Collingbourne VTableSlotInfo &SlotInfo) { 766df49d1bbSPeter Collingbourne // This only works if the function returns an integer. 767df49d1bbSPeter Collingbourne auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); 768df49d1bbSPeter Collingbourne if (!RetType) 769df49d1bbSPeter Collingbourne return false; 770df49d1bbSPeter Collingbourne unsigned BitWidth = RetType->getBitWidth(); 771df49d1bbSPeter Collingbourne if (BitWidth > 64) 772df49d1bbSPeter Collingbourne return false; 773df49d1bbSPeter Collingbourne 77417febdbbSPeter Collingbourne // Make sure that each function is defined, does not access memory, takes at 77517febdbbSPeter Collingbourne // least one argument, does not use its first argument (which we assume is 77617febdbbSPeter Collingbourne // 'this'), and has the same return type. 77737317f12SPeter Collingbourne // 77837317f12SPeter Collingbourne // Note that we test whether this copy of the function is readnone, rather 77937317f12SPeter Collingbourne // than testing function attributes, which must hold for any copy of the 78037317f12SPeter Collingbourne // function, even a less optimized version substituted at link time. This is 78137317f12SPeter Collingbourne // sound because the virtual constant propagation optimizations effectively 78237317f12SPeter Collingbourne // inline all implementations of the virtual function into each call site, 78337317f12SPeter Collingbourne // rather than using function attributes to perform local optimization. 784df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 78537317f12SPeter Collingbourne if (Target.Fn->isDeclaration() || 78637317f12SPeter Collingbourne computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != 78737317f12SPeter Collingbourne MAK_ReadNone || 78817febdbbSPeter Collingbourne Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || 789df49d1bbSPeter Collingbourne Target.Fn->getReturnType() != RetType) 790df49d1bbSPeter Collingbourne return false; 791df49d1bbSPeter Collingbourne } 792df49d1bbSPeter Collingbourne 79350cbd7ccSPeter Collingbourne for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { 794df49d1bbSPeter Collingbourne if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) 795df49d1bbSPeter Collingbourne continue; 796df49d1bbSPeter Collingbourne 79750cbd7ccSPeter Collingbourne if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second)) 798df49d1bbSPeter Collingbourne continue; 799df49d1bbSPeter Collingbourne 800df49d1bbSPeter Collingbourne if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) 801df49d1bbSPeter Collingbourne continue; 802df49d1bbSPeter Collingbourne 8037efd7506SPeter Collingbourne // Find an allocation offset in bits in all vtables associated with the 8047efd7506SPeter Collingbourne // type. 805df49d1bbSPeter Collingbourne uint64_t AllocBefore = 806df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); 807df49d1bbSPeter Collingbourne uint64_t AllocAfter = 808df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth); 809df49d1bbSPeter Collingbourne 810df49d1bbSPeter Collingbourne // Calculate the total amount of padding needed to store a value at both 811df49d1bbSPeter Collingbourne // ends of the object. 812df49d1bbSPeter Collingbourne uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0; 813df49d1bbSPeter Collingbourne for (auto &&Target : TargetsForSlot) { 814df49d1bbSPeter Collingbourne TotalPaddingBefore += std::max<int64_t>( 815df49d1bbSPeter Collingbourne (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0); 816df49d1bbSPeter Collingbourne TotalPaddingAfter += std::max<int64_t>( 817df49d1bbSPeter Collingbourne (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0); 818df49d1bbSPeter Collingbourne } 819df49d1bbSPeter Collingbourne 820df49d1bbSPeter Collingbourne // If the amount of padding is too large, give up. 821df49d1bbSPeter Collingbourne // FIXME: do something smarter here. 822df49d1bbSPeter Collingbourne if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128) 823df49d1bbSPeter Collingbourne continue; 824df49d1bbSPeter Collingbourne 825df49d1bbSPeter Collingbourne // Calculate the offset to the value as a (possibly negative) byte offset 826df49d1bbSPeter Collingbourne // and (if applicable) a bit offset, and store the values in the targets. 827df49d1bbSPeter Collingbourne int64_t OffsetByte; 828df49d1bbSPeter Collingbourne uint64_t OffsetBit; 829df49d1bbSPeter Collingbourne if (TotalPaddingBefore <= TotalPaddingAfter) 830df49d1bbSPeter Collingbourne setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte, 831df49d1bbSPeter Collingbourne OffsetBit); 832df49d1bbSPeter Collingbourne else 833df49d1bbSPeter Collingbourne setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, 834df49d1bbSPeter Collingbourne OffsetBit); 835df49d1bbSPeter Collingbourne 836f3403fd2SIvan Krasin if (RemarksEnabled) 837f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 838f3403fd2SIvan Krasin Target.WasDevirt = true; 839f3403fd2SIvan Krasin 840df49d1bbSPeter Collingbourne // Rewrite each call to a load from OffsetByte/OffsetBit. 841184773d8SPeter Collingbourne Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); 84250cbd7ccSPeter Collingbourne Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); 84350cbd7ccSPeter Collingbourne applyVirtualConstProp(CSByConstantArg.second, 84450cbd7ccSPeter Collingbourne TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); 845df49d1bbSPeter Collingbourne } 846df49d1bbSPeter Collingbourne return true; 847df49d1bbSPeter Collingbourne } 848df49d1bbSPeter Collingbourne 849df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) { 850df49d1bbSPeter Collingbourne if (B.Before.Bytes.empty() && B.After.Bytes.empty()) 851df49d1bbSPeter Collingbourne return; 852df49d1bbSPeter Collingbourne 853df49d1bbSPeter Collingbourne // Align each byte array to pointer width. 854df49d1bbSPeter Collingbourne unsigned PointerSize = M.getDataLayout().getPointerSize(); 855df49d1bbSPeter Collingbourne B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize)); 856df49d1bbSPeter Collingbourne B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize)); 857df49d1bbSPeter Collingbourne 858df49d1bbSPeter Collingbourne // Before was stored in reverse order; flip it now. 859df49d1bbSPeter Collingbourne for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) 860df49d1bbSPeter Collingbourne std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]); 861df49d1bbSPeter Collingbourne 862df49d1bbSPeter Collingbourne // Build an anonymous global containing the before bytes, followed by the 863df49d1bbSPeter Collingbourne // original initializer, followed by the after bytes. 864df49d1bbSPeter Collingbourne auto NewInit = ConstantStruct::getAnon( 865df49d1bbSPeter Collingbourne {ConstantDataArray::get(M.getContext(), B.Before.Bytes), 866df49d1bbSPeter Collingbourne B.GV->getInitializer(), 867df49d1bbSPeter Collingbourne ConstantDataArray::get(M.getContext(), B.After.Bytes)}); 868df49d1bbSPeter Collingbourne auto NewGV = 869df49d1bbSPeter Collingbourne new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(), 870df49d1bbSPeter Collingbourne GlobalVariable::PrivateLinkage, NewInit, "", B.GV); 871df49d1bbSPeter Collingbourne NewGV->setSection(B.GV->getSection()); 872df49d1bbSPeter Collingbourne NewGV->setComdat(B.GV->getComdat()); 873df49d1bbSPeter Collingbourne 8740312f614SPeter Collingbourne // Copy the original vtable's metadata to the anonymous global, adjusting 8750312f614SPeter Collingbourne // offsets as required. 8760312f614SPeter Collingbourne NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); 8770312f614SPeter Collingbourne 878df49d1bbSPeter Collingbourne // Build an alias named after the original global, pointing at the second 879df49d1bbSPeter Collingbourne // element (the original initializer). 880df49d1bbSPeter Collingbourne auto Alias = GlobalAlias::create( 881df49d1bbSPeter Collingbourne B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", 882df49d1bbSPeter Collingbourne ConstantExpr::getGetElementPtr( 883df49d1bbSPeter Collingbourne NewInit->getType(), NewGV, 884df49d1bbSPeter Collingbourne ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), 885df49d1bbSPeter Collingbourne ConstantInt::get(Int32Ty, 1)}), 886df49d1bbSPeter Collingbourne &M); 887df49d1bbSPeter Collingbourne Alias->setVisibility(B.GV->getVisibility()); 888df49d1bbSPeter Collingbourne Alias->takeName(B.GV); 889df49d1bbSPeter Collingbourne 890df49d1bbSPeter Collingbourne B.GV->replaceAllUsesWith(Alias); 891df49d1bbSPeter Collingbourne B.GV->eraseFromParent(); 892df49d1bbSPeter Collingbourne } 893df49d1bbSPeter Collingbourne 894f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() { 895f3403fd2SIvan Krasin const auto &FL = M.getFunctionList(); 896f3403fd2SIvan Krasin if (FL.empty()) 897f3403fd2SIvan Krasin return false; 898f3403fd2SIvan Krasin const Function &Fn = FL.front(); 89904758ba3SAdam Nemet auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), ""); 900f3403fd2SIvan Krasin return DI.isEnabled(); 901f3403fd2SIvan Krasin } 902f3403fd2SIvan Krasin 9030312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, 9040312f614SPeter Collingbourne Function *AssumeFunc) { 905df49d1bbSPeter Collingbourne // Find all virtual calls via a virtual table pointer %p under an assumption 9067efd7506SPeter Collingbourne // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p 9077efd7506SPeter Collingbourne // points to a member of the type identifier %md. Group calls by (type ID, 9087efd7506SPeter Collingbourne // offset) pair (effectively the identity of the virtual function) and store 9097efd7506SPeter Collingbourne // to CallSlots. 910df49d1bbSPeter Collingbourne DenseSet<Value *> SeenPtrs; 9117efd7506SPeter Collingbourne for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); 912df49d1bbSPeter Collingbourne I != E;) { 913df49d1bbSPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 914df49d1bbSPeter Collingbourne ++I; 915df49d1bbSPeter Collingbourne if (!CI) 916df49d1bbSPeter Collingbourne continue; 917df49d1bbSPeter Collingbourne 918ccdc225cSPeter Collingbourne // Search for virtual calls based on %p and add them to DevirtCalls. 919ccdc225cSPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 920df49d1bbSPeter Collingbourne SmallVector<CallInst *, 1> Assumes; 9210312f614SPeter Collingbourne findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); 922df49d1bbSPeter Collingbourne 923ccdc225cSPeter Collingbourne // If we found any, add them to CallSlots. Only do this if we haven't seen 924ccdc225cSPeter Collingbourne // the vtable pointer before, as it may have been CSE'd with pointers from 925ccdc225cSPeter Collingbourne // other call sites, and we don't want to process call sites multiple times. 926df49d1bbSPeter Collingbourne if (!Assumes.empty()) { 9277efd7506SPeter Collingbourne Metadata *TypeId = 928df49d1bbSPeter Collingbourne cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); 929df49d1bbSPeter Collingbourne Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); 930ccdc225cSPeter Collingbourne if (SeenPtrs.insert(Ptr).second) { 931ccdc225cSPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 93250cbd7ccSPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), 93350cbd7ccSPeter Collingbourne Call.CS, nullptr); 934ccdc225cSPeter Collingbourne } 935ccdc225cSPeter Collingbourne } 936df49d1bbSPeter Collingbourne } 937df49d1bbSPeter Collingbourne 9387efd7506SPeter Collingbourne // We no longer need the assumes or the type test. 939df49d1bbSPeter Collingbourne for (auto Assume : Assumes) 940df49d1bbSPeter Collingbourne Assume->eraseFromParent(); 941df49d1bbSPeter Collingbourne // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we 942df49d1bbSPeter Collingbourne // may use the vtable argument later. 943df49d1bbSPeter Collingbourne if (CI->use_empty()) 944df49d1bbSPeter Collingbourne CI->eraseFromParent(); 945df49d1bbSPeter Collingbourne } 9460312f614SPeter Collingbourne } 9470312f614SPeter Collingbourne 9480312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { 9490312f614SPeter Collingbourne Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); 9500312f614SPeter Collingbourne 9510312f614SPeter Collingbourne for (auto I = TypeCheckedLoadFunc->use_begin(), 9520312f614SPeter Collingbourne E = TypeCheckedLoadFunc->use_end(); 9530312f614SPeter Collingbourne I != E;) { 9540312f614SPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 9550312f614SPeter Collingbourne ++I; 9560312f614SPeter Collingbourne if (!CI) 9570312f614SPeter Collingbourne continue; 9580312f614SPeter Collingbourne 9590312f614SPeter Collingbourne Value *Ptr = CI->getArgOperand(0); 9600312f614SPeter Collingbourne Value *Offset = CI->getArgOperand(1); 9610312f614SPeter Collingbourne Value *TypeIdValue = CI->getArgOperand(2); 9620312f614SPeter Collingbourne Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); 9630312f614SPeter Collingbourne 9640312f614SPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 9650312f614SPeter Collingbourne SmallVector<Instruction *, 1> LoadedPtrs; 9660312f614SPeter Collingbourne SmallVector<Instruction *, 1> Preds; 9670312f614SPeter Collingbourne bool HasNonCallUses = false; 9680312f614SPeter Collingbourne findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, 9690312f614SPeter Collingbourne HasNonCallUses, CI); 9700312f614SPeter Collingbourne 9710312f614SPeter Collingbourne // Start by generating "pessimistic" code that explicitly loads the function 9720312f614SPeter Collingbourne // pointer from the vtable and performs the type check. If possible, we will 9730312f614SPeter Collingbourne // eliminate the load and the type check later. 9740312f614SPeter Collingbourne 9750312f614SPeter Collingbourne // If possible, only generate the load at the point where it is used. 9760312f614SPeter Collingbourne // This helps avoid unnecessary spills. 9770312f614SPeter Collingbourne IRBuilder<> LoadB( 9780312f614SPeter Collingbourne (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); 9790312f614SPeter Collingbourne Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); 9800312f614SPeter Collingbourne Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); 9810312f614SPeter Collingbourne Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); 9820312f614SPeter Collingbourne 9830312f614SPeter Collingbourne for (Instruction *LoadedPtr : LoadedPtrs) { 9840312f614SPeter Collingbourne LoadedPtr->replaceAllUsesWith(LoadedValue); 9850312f614SPeter Collingbourne LoadedPtr->eraseFromParent(); 9860312f614SPeter Collingbourne } 9870312f614SPeter Collingbourne 9880312f614SPeter Collingbourne // Likewise for the type test. 9890312f614SPeter Collingbourne IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); 9900312f614SPeter Collingbourne CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); 9910312f614SPeter Collingbourne 9920312f614SPeter Collingbourne for (Instruction *Pred : Preds) { 9930312f614SPeter Collingbourne Pred->replaceAllUsesWith(TypeTestCall); 9940312f614SPeter Collingbourne Pred->eraseFromParent(); 9950312f614SPeter Collingbourne } 9960312f614SPeter Collingbourne 9970312f614SPeter Collingbourne // We have already erased any extractvalue instructions that refer to the 9980312f614SPeter Collingbourne // intrinsic call, but the intrinsic may have other non-extractvalue uses 9990312f614SPeter Collingbourne // (although this is unlikely). In that case, explicitly build a pair and 10000312f614SPeter Collingbourne // RAUW it. 10010312f614SPeter Collingbourne if (!CI->use_empty()) { 10020312f614SPeter Collingbourne Value *Pair = UndefValue::get(CI->getType()); 10030312f614SPeter Collingbourne IRBuilder<> B(CI); 10040312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); 10050312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); 10060312f614SPeter Collingbourne CI->replaceAllUsesWith(Pair); 10070312f614SPeter Collingbourne } 10080312f614SPeter Collingbourne 10090312f614SPeter Collingbourne // The number of unsafe uses is initially the number of uses. 10100312f614SPeter Collingbourne auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; 10110312f614SPeter Collingbourne NumUnsafeUses = DevirtCalls.size(); 10120312f614SPeter Collingbourne 10130312f614SPeter Collingbourne // If the function pointer has a non-call user, we cannot eliminate the type 10140312f614SPeter Collingbourne // check, as one of those users may eventually call the pointer. Increment 10150312f614SPeter Collingbourne // the unsafe use count to make sure it cannot reach zero. 10160312f614SPeter Collingbourne if (HasNonCallUses) 10170312f614SPeter Collingbourne ++NumUnsafeUses; 10180312f614SPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 101950cbd7ccSPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, 102050cbd7ccSPeter Collingbourne &NumUnsafeUses); 10210312f614SPeter Collingbourne } 10220312f614SPeter Collingbourne 10230312f614SPeter Collingbourne CI->eraseFromParent(); 10240312f614SPeter Collingbourne } 10250312f614SPeter Collingbourne } 10260312f614SPeter Collingbourne 10270312f614SPeter Collingbourne bool DevirtModule::run() { 10280312f614SPeter Collingbourne Function *TypeTestFunc = 10290312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_test)); 10300312f614SPeter Collingbourne Function *TypeCheckedLoadFunc = 10310312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); 10320312f614SPeter Collingbourne Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); 10330312f614SPeter Collingbourne 10340312f614SPeter Collingbourne if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || 10350312f614SPeter Collingbourne AssumeFunc->use_empty()) && 10360312f614SPeter Collingbourne (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) 10370312f614SPeter Collingbourne return false; 10380312f614SPeter Collingbourne 10390312f614SPeter Collingbourne if (TypeTestFunc && AssumeFunc) 10400312f614SPeter Collingbourne scanTypeTestUsers(TypeTestFunc, AssumeFunc); 10410312f614SPeter Collingbourne 10420312f614SPeter Collingbourne if (TypeCheckedLoadFunc) 10430312f614SPeter Collingbourne scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); 1044df49d1bbSPeter Collingbourne 10457efd7506SPeter Collingbourne // Rebuild type metadata into a map for easy lookup. 1046df49d1bbSPeter Collingbourne std::vector<VTableBits> Bits; 10477efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; 10487efd7506SPeter Collingbourne buildTypeIdentifierMap(Bits, TypeIdMap); 10497efd7506SPeter Collingbourne if (TypeIdMap.empty()) 1050df49d1bbSPeter Collingbourne return true; 1051df49d1bbSPeter Collingbourne 10527efd7506SPeter Collingbourne // For each (type, offset) pair: 1053df49d1bbSPeter Collingbourne bool DidVirtualConstProp = false; 1054f3403fd2SIvan Krasin std::map<std::string, Function*> DevirtTargets; 1055df49d1bbSPeter Collingbourne for (auto &S : CallSlots) { 10567efd7506SPeter Collingbourne // Search each of the members of the type identifier for the virtual 10577efd7506SPeter Collingbourne // function implementation at offset S.first.ByteOffset, and add to 10587efd7506SPeter Collingbourne // TargetsForSlot. 1059df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> TargetsForSlot; 10607efd7506SPeter Collingbourne if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], 1061df49d1bbSPeter Collingbourne S.first.ByteOffset)) 1062df49d1bbSPeter Collingbourne continue; 1063df49d1bbSPeter Collingbourne 1064f3403fd2SIvan Krasin if (!trySingleImplDevirt(TargetsForSlot, S.second) && 1065f3403fd2SIvan Krasin tryVirtualConstProp(TargetsForSlot, S.second)) 1066f3403fd2SIvan Krasin DidVirtualConstProp = true; 1067f3403fd2SIvan Krasin 1068f3403fd2SIvan Krasin // Collect functions devirtualized at least for one call site for stats. 1069f3403fd2SIvan Krasin if (RemarksEnabled) 1070f3403fd2SIvan Krasin for (const auto &T : TargetsForSlot) 1071f3403fd2SIvan Krasin if (T.WasDevirt) 1072f3403fd2SIvan Krasin DevirtTargets[T.Fn->getName()] = T.Fn; 1073b05e06e4SIvan Krasin } 1074df49d1bbSPeter Collingbourne 1075f3403fd2SIvan Krasin if (RemarksEnabled) { 1076f3403fd2SIvan Krasin // Generate remarks for each devirtualized function. 1077f3403fd2SIvan Krasin for (const auto &DT : DevirtTargets) { 1078f3403fd2SIvan Krasin Function *F = DT.second; 1079f3403fd2SIvan Krasin DISubprogram *SP = F->getSubprogram(); 1080*7bc978b5SJustin Bogner emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP, 1081f3403fd2SIvan Krasin Twine("devirtualized ") + F->getName()); 1082b05e06e4SIvan Krasin } 1083df49d1bbSPeter Collingbourne } 1084df49d1bbSPeter Collingbourne 10850312f614SPeter Collingbourne // If we were able to eliminate all unsafe uses for a type checked load, 10860312f614SPeter Collingbourne // eliminate the type test by replacing it with true. 10870312f614SPeter Collingbourne if (TypeCheckedLoadFunc) { 10880312f614SPeter Collingbourne auto True = ConstantInt::getTrue(M.getContext()); 10890312f614SPeter Collingbourne for (auto &&U : NumUnsafeUsesForTypeTest) { 10900312f614SPeter Collingbourne if (U.second == 0) { 10910312f614SPeter Collingbourne U.first->replaceAllUsesWith(True); 10920312f614SPeter Collingbourne U.first->eraseFromParent(); 10930312f614SPeter Collingbourne } 10940312f614SPeter Collingbourne } 10950312f614SPeter Collingbourne } 10960312f614SPeter Collingbourne 1097df49d1bbSPeter Collingbourne // Rebuild each global we touched as part of virtual constant propagation to 1098df49d1bbSPeter Collingbourne // include the before and after bytes. 1099df49d1bbSPeter Collingbourne if (DidVirtualConstProp) 1100df49d1bbSPeter Collingbourne for (VTableBits &B : Bits) 1101df49d1bbSPeter Collingbourne rebuildGlobal(B); 1102df49d1bbSPeter Collingbourne 1103df49d1bbSPeter Collingbourne return true; 1104df49d1bbSPeter Collingbourne } 1105