1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// 2df49d1bbSPeter Collingbourne // 3df49d1bbSPeter Collingbourne // The LLVM Compiler Infrastructure 4df49d1bbSPeter Collingbourne // 5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source 6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details. 7df49d1bbSPeter Collingbourne // 8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 9df49d1bbSPeter Collingbourne // 10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases 117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This 12df49d1bbSPeter Collingbourne // includes the following: 13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single 14df49d1bbSPeter Collingbourne // possible callee, replace all calls with a direct call to that callee. 15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an 16df49d1bbSPeter Collingbourne // integer <=64 bits and all possible callees are readnone, for each class and 17df49d1bbSPeter Collingbourne // each list of constant arguments: evaluate the function, store the return 18df49d1bbSPeter Collingbourne // value alongside the virtual table, and rewrite each virtual call as a load 19df49d1bbSPeter Collingbourne // from the virtual table. 20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant 21df49d1bbSPeter Collingbourne // propagation hold and each function returns the same constant value, replace 22df49d1bbSPeter Collingbourne // each virtual call with that constant. 
23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions 24df49d1bbSPeter Collingbourne // for virtual constant propagation hold and a single vtable's function 25df49d1bbSPeter Collingbourne // returns 0, or a single vtable's function returns 1, replace each virtual 26df49d1bbSPeter Collingbourne // call with a comparison of the vptr against that vtable's address. 27df49d1bbSPeter Collingbourne // 28b406baaeSPeter Collingbourne // This pass is intended to be used during the regular and thin LTO pipelines. 29b406baaeSPeter Collingbourne // During regular LTO, the pass determines the best optimization for each 30b406baaeSPeter Collingbourne // virtual call and applies the resolutions directly to virtual calls that are 31b406baaeSPeter Collingbourne // eligible for virtual call optimization (i.e. calls that use either of the 32b406baaeSPeter Collingbourne // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During 33b406baaeSPeter Collingbourne // ThinLTO, the pass operates in two phases: 34b406baaeSPeter Collingbourne // - Export phase: this is run during the thin link over a single merged module 35b406baaeSPeter Collingbourne // that contains all vtables with !type metadata that participate in the link. 36b406baaeSPeter Collingbourne // The pass computes a resolution for each virtual call and stores it in the 37b406baaeSPeter Collingbourne // type identifier summary. 38b406baaeSPeter Collingbourne // - Import phase: this is run during the thin backends over the individual 39b406baaeSPeter Collingbourne // modules. The pass applies the resolutions previously computed during the 40b406baaeSPeter Collingbourne // export phase to each eligible virtual call.
41b406baaeSPeter Collingbourne // 42df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 43df49d1bbSPeter Collingbourne 44df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h" 45b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h" 46cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h" 47cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h" 48df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h" 49df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h" 50cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h" 516bda14b3SChandler Carruth #include "llvm/ADT/iterator_range.h" 5237317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h" 5337317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h" 540965da20SAdam Nemet #include "llvm/Analysis/OptimizationRemarkEmitter.h" 557efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h" 56df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h" 57df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h" 58df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h" 59cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h" 60cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h" 61cdc71612SEugene Zelenko #include "llvm/IR/Function.h" 62cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h" 63cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h" 64df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h" 65cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h" 66cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h" 67df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h" 68df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h" 69cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h" 70cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h" 71df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h" 722b33f653SPeter Collingbourne #include 
"llvm/IR/ModuleSummaryIndexYAML.h" 73df49d1bbSPeter Collingbourne #include "llvm/Pass.h" 74cdc71612SEugene Zelenko #include "llvm/PassRegistry.h" 75cdc71612SEugene Zelenko #include "llvm/PassSupport.h" 76cdc71612SEugene Zelenko #include "llvm/Support/Casting.h" 772b33f653SPeter Collingbourne #include "llvm/Support/Error.h" 782b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h" 79cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h" 80b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h" 8137317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h" 82df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h" 83cdc71612SEugene Zelenko #include <algorithm> 84cdc71612SEugene Zelenko #include <cstddef> 85cdc71612SEugene Zelenko #include <map> 86df49d1bbSPeter Collingbourne #include <set> 87cdc71612SEugene Zelenko #include <string> 88df49d1bbSPeter Collingbourne 89df49d1bbSPeter Collingbourne using namespace llvm; 90df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt; 91df49d1bbSPeter Collingbourne 92df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt" 93df49d1bbSPeter Collingbourne 942b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction( 952b33f653SPeter Collingbourne "wholeprogramdevirt-summary-action", 962b33f653SPeter Collingbourne cl::desc("What to do with the summary when running this pass"), 972b33f653SPeter Collingbourne cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), 982b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Import, "import", 992b33f653SPeter Collingbourne "Import typeid resolutions from summary and globals"), 1002b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Export, "export", 1012b33f653SPeter Collingbourne "Export typeid resolutions to summary and globals")), 1022b33f653SPeter Collingbourne cl::Hidden); 1032b33f653SPeter Collingbourne 1042b33f653SPeter Collingbourne static cl::opt<std::string> 
ClReadSummary( 1052b33f653SPeter Collingbourne "wholeprogramdevirt-read-summary", 1062b33f653SPeter Collingbourne cl::desc("Read summary from given YAML file before running pass"), 1072b33f653SPeter Collingbourne cl::Hidden); 1082b33f653SPeter Collingbourne 1092b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary( 1102b33f653SPeter Collingbourne "wholeprogramdevirt-write-summary", 1112b33f653SPeter Collingbourne cl::desc("Write summary to given YAML file after running pass"), 1122b33f653SPeter Collingbourne cl::Hidden); 1132b33f653SPeter Collingbourne 1149cb59b92SVitaly Buka static cl::opt<unsigned> 1159cb59b92SVitaly Buka ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, 1169cb59b92SVitaly Buka cl::init(10), cl::ZeroOrMore, 11766f53d71SVitaly Buka cl::desc("Maximum number of call targets per " 11866f53d71SVitaly Buka "call site to enable branch funnels")); 11966f53d71SVitaly Buka 120df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If 121df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset after the object, otherwise look for an 122df49d1bbSPeter Collingbourne // offset before the object. 123df49d1bbSPeter Collingbourne uint64_t 124df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, 125df49d1bbSPeter Collingbourne bool IsAfter, uint64_t Size) { 126df49d1bbSPeter Collingbourne // Find a minimum offset taking into account only vtable sizes.
127df49d1bbSPeter Collingbourne uint64_t MinByte = 0; 128df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 129df49d1bbSPeter Collingbourne if (IsAfter) 130df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minAfterBytes()); 131df49d1bbSPeter Collingbourne else 132df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minBeforeBytes()); 133df49d1bbSPeter Collingbourne } 134df49d1bbSPeter Collingbourne 135df49d1bbSPeter Collingbourne // Build a vector of arrays of bytes covering, for each target, a slice of the 136df49d1bbSPeter Collingbourne // used region (see AccumBitVector::BytesUsed in 137df49d1bbSPeter Collingbourne // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively, 138df49d1bbSPeter Collingbourne // this aligns the used regions to start at MinByte. 139df49d1bbSPeter Collingbourne // 140df49d1bbSPeter Collingbourne // In this example, A, B and C are vtables, # is a byte already allocated for 141df49d1bbSPeter Collingbourne // a virtual function pointer, AAAA... (etc.) are the used regions for the 142df49d1bbSPeter Collingbourne // vtables and Offset(X) is the value computed for the Offset variable below 143df49d1bbSPeter Collingbourne // for X. 144df49d1bbSPeter Collingbourne // 145df49d1bbSPeter Collingbourne // Offset(A) 146df49d1bbSPeter Collingbourne // | | 147df49d1bbSPeter Collingbourne // |MinByte 148df49d1bbSPeter Collingbourne // A: ################AAAAAAAA|AAAAAAAA 149df49d1bbSPeter Collingbourne // B: ########BBBBBBBBBBBBBBBB|BBBB 150df49d1bbSPeter Collingbourne // C: ########################|CCCCCCCCCCCCCCCC 151df49d1bbSPeter Collingbourne // | Offset(B) | 152df49d1bbSPeter Collingbourne // 153df49d1bbSPeter Collingbourne // This code produces the slices of A, B and C that appear after the divider 154df49d1bbSPeter Collingbourne // at MinByte. 
155df49d1bbSPeter Collingbourne std::vector<ArrayRef<uint8_t>> Used; 156df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 1577efd7506SPeter Collingbourne ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed 1587efd7506SPeter Collingbourne : Target.TM->Bits->Before.BytesUsed; 159df49d1bbSPeter Collingbourne uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() 160df49d1bbSPeter Collingbourne : MinByte - Target.minBeforeBytes(); 161df49d1bbSPeter Collingbourne 162df49d1bbSPeter Collingbourne // Disregard used regions that are smaller than Offset. These are 163df49d1bbSPeter Collingbourne // effectively all-free regions that do not need to be checked. 164df49d1bbSPeter Collingbourne if (VTUsed.size() > Offset) 165df49d1bbSPeter Collingbourne Used.push_back(VTUsed.slice(Offset)); 166df49d1bbSPeter Collingbourne } 167df49d1bbSPeter Collingbourne 168df49d1bbSPeter Collingbourne if (Size == 1) { 169df49d1bbSPeter Collingbourne // Find a free bit in each member of Used. 170df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 171df49d1bbSPeter Collingbourne uint8_t BitsUsed = 0; 172df49d1bbSPeter Collingbourne for (auto &&B : Used) 173df49d1bbSPeter Collingbourne if (I < B.size()) 174df49d1bbSPeter Collingbourne BitsUsed |= B[I]; 175df49d1bbSPeter Collingbourne if (BitsUsed != 0xff) 176df49d1bbSPeter Collingbourne return (MinByte + I) * 8 + 177df49d1bbSPeter Collingbourne countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined); 178df49d1bbSPeter Collingbourne } 179df49d1bbSPeter Collingbourne } else { 180df49d1bbSPeter Collingbourne // Find a free (Size/8) byte region in each member of Used. 181df49d1bbSPeter Collingbourne // FIXME: see if alignment helps. 
182df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 183df49d1bbSPeter Collingbourne for (auto &&B : Used) { 184df49d1bbSPeter Collingbourne unsigned Byte = 0; 185df49d1bbSPeter Collingbourne while ((I + Byte) < B.size() && Byte < (Size / 8)) { 186df49d1bbSPeter Collingbourne if (B[I + Byte]) 187df49d1bbSPeter Collingbourne goto NextI; 188df49d1bbSPeter Collingbourne ++Byte; 189df49d1bbSPeter Collingbourne } 190df49d1bbSPeter Collingbourne } 191df49d1bbSPeter Collingbourne return (MinByte + I) * 8; 192df49d1bbSPeter Collingbourne NextI:; 193df49d1bbSPeter Collingbourne } 194df49d1bbSPeter Collingbourne } 195df49d1bbSPeter Collingbourne } 196df49d1bbSPeter Collingbourne 197df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues( 198df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore, 199df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 200df49d1bbSPeter Collingbourne if (BitWidth == 1) 201df49d1bbSPeter Collingbourne OffsetByte = -(AllocBefore / 8 + 1); 202df49d1bbSPeter Collingbourne else 203df49d1bbSPeter Collingbourne OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8); 204df49d1bbSPeter Collingbourne OffsetBit = AllocBefore % 8; 205df49d1bbSPeter Collingbourne 206df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 207df49d1bbSPeter Collingbourne if (BitWidth == 1) 208df49d1bbSPeter Collingbourne Target.setBeforeBit(AllocBefore); 209df49d1bbSPeter Collingbourne else 210df49d1bbSPeter Collingbourne Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8); 211df49d1bbSPeter Collingbourne } 212df49d1bbSPeter Collingbourne } 213df49d1bbSPeter Collingbourne 214df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues( 215df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter, 216df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 
217df49d1bbSPeter Collingbourne if (BitWidth == 1) 218df49d1bbSPeter Collingbourne OffsetByte = AllocAfter / 8; 219df49d1bbSPeter Collingbourne else 220df49d1bbSPeter Collingbourne OffsetByte = (AllocAfter + 7) / 8; 221df49d1bbSPeter Collingbourne OffsetBit = AllocAfter % 8; 222df49d1bbSPeter Collingbourne 223df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 224df49d1bbSPeter Collingbourne if (BitWidth == 1) 225df49d1bbSPeter Collingbourne Target.setAfterBit(AllocAfter); 226df49d1bbSPeter Collingbourne else 227df49d1bbSPeter Collingbourne Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8); 228df49d1bbSPeter Collingbourne } 229df49d1bbSPeter Collingbourne } 230df49d1bbSPeter Collingbourne 2317efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) 2327efd7506SPeter Collingbourne : Fn(Fn), TM(TM), 23389439a79SIvan Krasin IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} 234df49d1bbSPeter Collingbourne 235df49d1bbSPeter Collingbourne namespace { 236df49d1bbSPeter Collingbourne 2377efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual 238df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to 239df49d1bbSPeter Collingbourne // the virtual function pointer. 
240df49d1bbSPeter Collingbourne struct VTableSlot { 2417efd7506SPeter Collingbourne Metadata *TypeID; 242df49d1bbSPeter Collingbourne uint64_t ByteOffset; 243df49d1bbSPeter Collingbourne }; 244df49d1bbSPeter Collingbourne 245cdc71612SEugene Zelenko } // end anonymous namespace 246df49d1bbSPeter Collingbourne 2479b656527SPeter Collingbourne namespace llvm { 2489b656527SPeter Collingbourne 249df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> { 250df49d1bbSPeter Collingbourne static VTableSlot getEmptyKey() { 251df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getEmptyKey(), 252df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getEmptyKey()}; 253df49d1bbSPeter Collingbourne } 254df49d1bbSPeter Collingbourne static VTableSlot getTombstoneKey() { 255df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getTombstoneKey(), 256df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getTombstoneKey()}; 257df49d1bbSPeter Collingbourne } 258df49d1bbSPeter Collingbourne static unsigned getHashValue(const VTableSlot &I) { 2597efd7506SPeter Collingbourne return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ 260df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); 261df49d1bbSPeter Collingbourne } 262df49d1bbSPeter Collingbourne static bool isEqual(const VTableSlot &LHS, 263df49d1bbSPeter Collingbourne const VTableSlot &RHS) { 2647efd7506SPeter Collingbourne return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; 265df49d1bbSPeter Collingbourne } 266df49d1bbSPeter Collingbourne }; 267df49d1bbSPeter Collingbourne 268cdc71612SEugene Zelenko } // end namespace llvm 2699b656527SPeter Collingbourne 270df49d1bbSPeter Collingbourne namespace { 271df49d1bbSPeter Collingbourne 272df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is 273df49d1bbSPeter Collingbourne // the indirect virtual call. 
274df49d1bbSPeter Collingbourne struct VirtualCallSite { 275df49d1bbSPeter Collingbourne Value *VTable; 276df49d1bbSPeter Collingbourne CallSite CS; 277df49d1bbSPeter Collingbourne 2780312f614SPeter Collingbourne // If non-null, this field points to the associated unsafe use count stored in 2790312f614SPeter Collingbourne // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description 2800312f614SPeter Collingbourne // of that field for details. 2810312f614SPeter Collingbourne unsigned *NumUnsafeUses; 2820312f614SPeter Collingbourne 283e963c89dSSam Elliott void 284e963c89dSSam Elliott emitRemark(const StringRef OptName, const StringRef TargetName, 285e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { 2865474645dSIvan Krasin Function *F = CS.getCaller(); 287e963c89dSSam Elliott DebugLoc DLoc = CS->getDebugLoc(); 288e963c89dSSam Elliott BasicBlock *Block = CS.getParent(); 289e963c89dSSam Elliott 290e963c89dSSam Elliott using namespace ore; 2919110cb45SPeter Collingbourne OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) 2929110cb45SPeter Collingbourne << NV("Optimization", OptName) 2939110cb45SPeter Collingbourne << ": devirtualized a call to " 294e963c89dSSam Elliott << NV("FunctionName", TargetName)); 295e963c89dSSam Elliott } 296e963c89dSSam Elliott 297e963c89dSSam Elliott void replaceAndErase( 298e963c89dSSam Elliott const StringRef OptName, const StringRef TargetName, bool RemarksEnabled, 299e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 300e963c89dSSam Elliott Value *New) { 301f3403fd2SIvan Krasin if (RemarksEnabled) 302e963c89dSSam Elliott emitRemark(OptName, TargetName, OREGetter); 303df49d1bbSPeter Collingbourne CS->replaceAllUsesWith(New); 304df49d1bbSPeter Collingbourne if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { 305df49d1bbSPeter Collingbourne BranchInst::Create(II->getNormalDest(), CS.getInstruction()); 306df49d1bbSPeter 
Collingbourne II->getUnwindDest()->removePredecessor(II->getParent()); 307df49d1bbSPeter Collingbourne } 308df49d1bbSPeter Collingbourne CS->eraseFromParent(); 3090312f614SPeter Collingbourne // This use is no longer unsafe. 3100312f614SPeter Collingbourne if (NumUnsafeUses) 3110312f614SPeter Collingbourne --*NumUnsafeUses; 312df49d1bbSPeter Collingbourne } 313df49d1bbSPeter Collingbourne }; 314df49d1bbSPeter Collingbourne 31550cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list 31650cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the 31750cbd7ccSPeter Collingbourne // VTableSlotInfo class. 31850cbd7ccSPeter Collingbourne struct CallSiteInfo { 319b406baaeSPeter Collingbourne /// The set of call sites for this slot. Used during regular LTO and the 320b406baaeSPeter Collingbourne /// import phase of ThinLTO (as well as the export phase of ThinLTO for any 321b406baaeSPeter Collingbourne /// call sites that appear in the merged module itself); in each of these 322b406baaeSPeter Collingbourne /// cases we are directly operating on the call sites at the IR level. 32350cbd7ccSPeter Collingbourne std::vector<VirtualCallSite> CallSites; 324b406baaeSPeter Collingbourne 3252974856aSPeter Collingbourne /// Whether all call sites represented by this CallSiteInfo, including those 3262974856aSPeter Collingbourne /// in summaries, have been devirtualized. This starts off as true because a 3272974856aSPeter Collingbourne /// default constructed CallSiteInfo represents no call sites. 3282974856aSPeter Collingbourne bool AllCallSitesDevirted = true; 3292974856aSPeter Collingbourne 330b406baaeSPeter Collingbourne // These fields are used during the export phase of ThinLTO and reflect 331b406baaeSPeter Collingbourne // information collected from function summaries. 
332b406baaeSPeter Collingbourne 3332325bb34SPeter Collingbourne /// Whether any function summary contains an llvm.assume(llvm.type.test) for 3342325bb34SPeter Collingbourne /// this slot. 3352974856aSPeter Collingbourne bool SummaryHasTypeTestAssumeUsers = false; 3362325bb34SPeter Collingbourne 337b406baaeSPeter Collingbourne /// CFI-specific: a vector containing the list of function summaries that use 338b406baaeSPeter Collingbourne /// the llvm.type.checked.load intrinsic and therefore will require 339b406baaeSPeter Collingbourne /// resolutions for llvm.type.test in order to implement CFI checks if 340b406baaeSPeter Collingbourne /// devirtualization was unsuccessful. If devirtualization was successful, the 34159675ba0SPeter Collingbourne /// pass will clear this vector by calling markDevirt(). If at the end of the 34259675ba0SPeter Collingbourne /// pass the vector is non-empty, we will need to add a use of llvm.type.test 34359675ba0SPeter Collingbourne /// to each of the function summaries in the vector. 
344b406baaeSPeter Collingbourne std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; 3452325bb34SPeter Collingbourne 3462325bb34SPeter Collingbourne bool isExported() const { 3472325bb34SPeter Collingbourne return SummaryHasTypeTestAssumeUsers || 3482325bb34SPeter Collingbourne !SummaryTypeCheckedLoadUsers.empty(); 3492325bb34SPeter Collingbourne } 35059675ba0SPeter Collingbourne 3512974856aSPeter Collingbourne void markSummaryHasTypeTestAssumeUsers() { 3522974856aSPeter Collingbourne SummaryHasTypeTestAssumeUsers = true; 3532974856aSPeter Collingbourne AllCallSitesDevirted = false; 3542974856aSPeter Collingbourne } 3552974856aSPeter Collingbourne 3562974856aSPeter Collingbourne void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { 3572974856aSPeter Collingbourne SummaryTypeCheckedLoadUsers.push_back(FS); 3582974856aSPeter Collingbourne AllCallSitesDevirted = false; 3592974856aSPeter Collingbourne } 3602974856aSPeter Collingbourne 3612974856aSPeter Collingbourne void markDevirt() { 3622974856aSPeter Collingbourne AllCallSitesDevirted = true; 3632974856aSPeter Collingbourne 3642974856aSPeter Collingbourne // As explained in the comment for SummaryTypeCheckedLoadUsers. 3652974856aSPeter Collingbourne SummaryTypeCheckedLoadUsers.clear(); 3662974856aSPeter Collingbourne } 36750cbd7ccSPeter Collingbourne }; 36850cbd7ccSPeter Collingbourne 36950cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot. 37050cbd7ccSPeter Collingbourne struct VTableSlotInfo { 37150cbd7ccSPeter Collingbourne // The set of call sites which do not have all constant integer arguments 37250cbd7ccSPeter Collingbourne // (excluding "this"). 37350cbd7ccSPeter Collingbourne CallSiteInfo CSInfo; 37450cbd7ccSPeter Collingbourne 37550cbd7ccSPeter Collingbourne // The set of call sites with all constant integer arguments (excluding 37650cbd7ccSPeter Collingbourne // "this"), grouped by argument list. 
37750cbd7ccSPeter Collingbourne std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; 37850cbd7ccSPeter Collingbourne 37950cbd7ccSPeter Collingbourne void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); 38050cbd7ccSPeter Collingbourne 38150cbd7ccSPeter Collingbourne private: 38250cbd7ccSPeter Collingbourne CallSiteInfo &findCallSiteInfo(CallSite CS); 38350cbd7ccSPeter Collingbourne }; 38450cbd7ccSPeter Collingbourne 38550cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { 38650cbd7ccSPeter Collingbourne std::vector<uint64_t> Args; 38750cbd7ccSPeter Collingbourne auto *CI = dyn_cast<IntegerType>(CS.getType()); 38850cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) 38950cbd7ccSPeter Collingbourne return CSInfo; 39050cbd7ccSPeter Collingbourne for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { 39150cbd7ccSPeter Collingbourne auto *CI = dyn_cast<ConstantInt>(Arg); 39250cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64) 39350cbd7ccSPeter Collingbourne return CSInfo; 39450cbd7ccSPeter Collingbourne Args.push_back(CI->getZExtValue()); 39550cbd7ccSPeter Collingbourne } 39650cbd7ccSPeter Collingbourne return ConstCSInfo[Args]; 39750cbd7ccSPeter Collingbourne } 39850cbd7ccSPeter Collingbourne 39950cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, 40050cbd7ccSPeter Collingbourne unsigned *NumUnsafeUses) { 4012974856aSPeter Collingbourne auto &CSI = findCallSiteInfo(CS); 4022974856aSPeter Collingbourne CSI.AllCallSitesDevirted = false; 4032974856aSPeter Collingbourne CSI.CallSites.push_back({VTable, CS, NumUnsafeUses}); 40450cbd7ccSPeter Collingbourne } 40550cbd7ccSPeter Collingbourne 406df49d1bbSPeter Collingbourne struct DevirtModule { 407df49d1bbSPeter Collingbourne Module &M; 40837317f12SPeter Collingbourne function_ref<AAResults &(Function &)> AARGetter; 4092b33f653SPeter Collingbourne 410f7691d8bSPeter Collingbourne 
ModuleSummaryIndex *ExportSummary; 411f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary; 4122b33f653SPeter Collingbourne 413df49d1bbSPeter Collingbourne IntegerType *Int8Ty; 414df49d1bbSPeter Collingbourne PointerType *Int8PtrTy; 415df49d1bbSPeter Collingbourne IntegerType *Int32Ty; 41650cbd7ccSPeter Collingbourne IntegerType *Int64Ty; 41714dcf02fSPeter Collingbourne IntegerType *IntPtrTy; 418df49d1bbSPeter Collingbourne 419f3403fd2SIvan Krasin bool RemarksEnabled; 420e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter; 421f3403fd2SIvan Krasin 42250cbd7ccSPeter Collingbourne MapVector<VTableSlot, VTableSlotInfo> CallSlots; 423df49d1bbSPeter Collingbourne 4240312f614SPeter Collingbourne // This map keeps track of the number of "unsafe" uses of a loaded function 4250312f614SPeter Collingbourne // pointer. The key is the associated llvm.type.test intrinsic call generated 4260312f614SPeter Collingbourne // by this pass. An unsafe use is one that calls the loaded function pointer 4270312f614SPeter Collingbourne // directly. Every time we eliminate an unsafe use (for example, by 4280312f614SPeter Collingbourne // devirtualizing it or by applying virtual constant propagation), we 4290312f614SPeter Collingbourne // decrement the value stored in this map. If a value reaches zero, we can 4300312f614SPeter Collingbourne // eliminate the type check by RAUWing the associated llvm.type.test call with 4310312f614SPeter Collingbourne // true. 
4320312f614SPeter Collingbourne std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; 4330312f614SPeter Collingbourne 43437317f12SPeter Collingbourne DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, 435e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 436f7691d8bSPeter Collingbourne ModuleSummaryIndex *ExportSummary, 437f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) 438f7691d8bSPeter Collingbourne : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), 439f7691d8bSPeter Collingbourne ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), 440df49d1bbSPeter Collingbourne Int8PtrTy(Type::getInt8PtrTy(M.getContext())), 441f3403fd2SIvan Krasin Int32Ty(Type::getInt32Ty(M.getContext())), 44250cbd7ccSPeter Collingbourne Int64Ty(Type::getInt64Ty(M.getContext())), 44314dcf02fSPeter Collingbourne IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), 444e963c89dSSam Elliott RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) { 445f7691d8bSPeter Collingbourne assert(!(ExportSummary && ImportSummary)); 446f7691d8bSPeter Collingbourne } 447f3403fd2SIvan Krasin 448f3403fd2SIvan Krasin bool areRemarksEnabled(); 449df49d1bbSPeter Collingbourne 4500312f614SPeter Collingbourne void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); 4510312f614SPeter Collingbourne void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); 4520312f614SPeter Collingbourne 4537efd7506SPeter Collingbourne void buildTypeIdentifierMap( 4547efd7506SPeter Collingbourne std::vector<VTableBits> &Bits, 4557efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); 4568786754cSPeter Collingbourne Constant *getPointerAtOffset(Constant *I, uint64_t Offset); 4577efd7506SPeter Collingbourne bool 4587efd7506SPeter Collingbourne tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, 4597efd7506SPeter Collingbourne const 
std::set<TypeMemberInfo> &TypeMemberInfos, 460df49d1bbSPeter Collingbourne uint64_t ByteOffset); 46150cbd7ccSPeter Collingbourne 4622325bb34SPeter Collingbourne void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, 4632325bb34SPeter Collingbourne bool &IsExported); 464f3403fd2SIvan Krasin bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 4652325bb34SPeter Collingbourne VTableSlotInfo &SlotInfo, 4662325bb34SPeter Collingbourne WholeProgramDevirtResolution *Res); 46750cbd7ccSPeter Collingbourne 4682974856aSPeter Collingbourne void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT, 4692974856aSPeter Collingbourne bool &IsExported); 4702974856aSPeter Collingbourne void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 4712974856aSPeter Collingbourne VTableSlotInfo &SlotInfo, 4722974856aSPeter Collingbourne WholeProgramDevirtResolution *Res, VTableSlot Slot); 4732974856aSPeter Collingbourne 474df49d1bbSPeter Collingbourne bool tryEvaluateFunctionsWithArgs( 475df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 47650cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args); 47750cbd7ccSPeter Collingbourne 47850cbd7ccSPeter Collingbourne void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 47950cbd7ccSPeter Collingbourne uint64_t TheRetVal); 48050cbd7ccSPeter Collingbourne bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 48177a8d563SPeter Collingbourne CallSiteInfo &CSInfo, 48277a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res); 48350cbd7ccSPeter Collingbourne 48459675ba0SPeter Collingbourne // Returns the global symbol name that is used to export information about the 48559675ba0SPeter Collingbourne // given vtable slot and list of arguments. 
48659675ba0SPeter Collingbourne std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, 48759675ba0SPeter Collingbourne StringRef Name); 48859675ba0SPeter Collingbourne 489b15a35e6SPeter Collingbourne bool shouldExportConstantsAsAbsoluteSymbols(); 490b15a35e6SPeter Collingbourne 49159675ba0SPeter Collingbourne // This function is called during the export phase to create a symbol 49259675ba0SPeter Collingbourne // definition containing information about the given vtable slot and list of 49359675ba0SPeter Collingbourne // arguments. 49459675ba0SPeter Collingbourne void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, 49559675ba0SPeter Collingbourne Constant *C); 496b15a35e6SPeter Collingbourne void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, 497b15a35e6SPeter Collingbourne uint32_t Const, uint32_t &Storage); 49859675ba0SPeter Collingbourne 49959675ba0SPeter Collingbourne // This function is called during the import phase to create a reference to 50059675ba0SPeter Collingbourne // the symbol definition created during the export phase. 
50159675ba0SPeter Collingbourne Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 502b15a35e6SPeter Collingbourne StringRef Name); 503b15a35e6SPeter Collingbourne Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 504b15a35e6SPeter Collingbourne StringRef Name, IntegerType *IntTy, 505b15a35e6SPeter Collingbourne uint32_t Storage); 50659675ba0SPeter Collingbourne 5072974856aSPeter Collingbourne Constant *getMemberAddr(const TypeMemberInfo *M); 5082974856aSPeter Collingbourne 50950cbd7ccSPeter Collingbourne void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, 51050cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr); 511df49d1bbSPeter Collingbourne bool tryUniqueRetValOpt(unsigned BitWidth, 512f3403fd2SIvan Krasin MutableArrayRef<VirtualCallTarget> TargetsForSlot, 51359675ba0SPeter Collingbourne CallSiteInfo &CSInfo, 51459675ba0SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res, 51559675ba0SPeter Collingbourne VTableSlot Slot, ArrayRef<uint64_t> Args); 51650cbd7ccSPeter Collingbourne 51750cbd7ccSPeter Collingbourne void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 51850cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit); 519df49d1bbSPeter Collingbourne bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 52077a8d563SPeter Collingbourne VTableSlotInfo &SlotInfo, 52159675ba0SPeter Collingbourne WholeProgramDevirtResolution *Res, VTableSlot Slot); 522df49d1bbSPeter Collingbourne 523df49d1bbSPeter Collingbourne void rebuildGlobal(VTableBits &B); 524df49d1bbSPeter Collingbourne 5256d284fabSPeter Collingbourne // Apply the summary resolution for Slot to all virtual calls in SlotInfo. 
5266d284fabSPeter Collingbourne void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); 5276d284fabSPeter Collingbourne 5286d284fabSPeter Collingbourne // If we were able to eliminate all unsafe uses for a type checked load, 5296d284fabSPeter Collingbourne // eliminate the associated type tests by replacing them with true. 5306d284fabSPeter Collingbourne void removeRedundantTypeTests(); 5316d284fabSPeter Collingbourne 532df49d1bbSPeter Collingbourne bool run(); 5332b33f653SPeter Collingbourne 5342b33f653SPeter Collingbourne // Lower the module using the action and summary passed as command line 5352b33f653SPeter Collingbourne // arguments. For testing purposes only. 536e963c89dSSam Elliott static bool runForTesting( 537e963c89dSSam Elliott Module &M, function_ref<AAResults &(Function &)> AARGetter, 538e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter); 539df49d1bbSPeter Collingbourne }; 540df49d1bbSPeter Collingbourne 541df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass { 542df49d1bbSPeter Collingbourne static char ID; 543cdc71612SEugene Zelenko 5442b33f653SPeter Collingbourne bool UseCommandLine = false; 5452b33f653SPeter Collingbourne 546f7691d8bSPeter Collingbourne ModuleSummaryIndex *ExportSummary; 547f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary; 5482b33f653SPeter Collingbourne 5492b33f653SPeter Collingbourne WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { 5502b33f653SPeter Collingbourne initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 5512b33f653SPeter Collingbourne } 5522b33f653SPeter Collingbourne 553f7691d8bSPeter Collingbourne WholeProgramDevirt(ModuleSummaryIndex *ExportSummary, 554f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) 555f7691d8bSPeter Collingbourne : ModulePass(ID), ExportSummary(ExportSummary), 556f7691d8bSPeter Collingbourne ImportSummary(ImportSummary) { 557df49d1bbSPeter Collingbourne 
initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 558df49d1bbSPeter Collingbourne } 559cdc71612SEugene Zelenko 560cdc71612SEugene Zelenko bool runOnModule(Module &M) override { 561aa641a51SAndrew Kaylor if (skipModule(M)) 562aa641a51SAndrew Kaylor return false; 563e963c89dSSam Elliott 5649110cb45SPeter Collingbourne // In the new pass manager, we can request the optimization 5659110cb45SPeter Collingbourne // remark emitter pass on a per-function-basis, which the 5669110cb45SPeter Collingbourne // OREGetter will do for us. 5679110cb45SPeter Collingbourne // In the old pass manager, this is harder, so we just build 5689110cb45SPeter Collingbourne // an optimization remark emitter on the fly, when we need it. 5699110cb45SPeter Collingbourne std::unique_ptr<OptimizationRemarkEmitter> ORE; 5709110cb45SPeter Collingbourne auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { 5719110cb45SPeter Collingbourne ORE = make_unique<OptimizationRemarkEmitter>(F); 5729110cb45SPeter Collingbourne return *ORE; 5739110cb45SPeter Collingbourne }; 574e963c89dSSam Elliott 5752b33f653SPeter Collingbourne if (UseCommandLine) 576e963c89dSSam Elliott return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter); 577e963c89dSSam Elliott 578e963c89dSSam Elliott return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary, 579e963c89dSSam Elliott ImportSummary) 580f7691d8bSPeter Collingbourne .run(); 58137317f12SPeter Collingbourne } 58237317f12SPeter Collingbourne 58337317f12SPeter Collingbourne void getAnalysisUsage(AnalysisUsage &AU) const override { 58437317f12SPeter Collingbourne AU.addRequired<AssumptionCacheTracker>(); 58537317f12SPeter Collingbourne AU.addRequired<TargetLibraryInfoWrapperPass>(); 586aa641a51SAndrew Kaylor } 587df49d1bbSPeter Collingbourne }; 588df49d1bbSPeter Collingbourne 589cdc71612SEugene Zelenko } // end anonymous namespace 590df49d1bbSPeter Collingbourne 59137317f12SPeter Collingbourne 
INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", 59237317f12SPeter Collingbourne "Whole program devirtualization", false, false) 59337317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 59437317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 59537317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", 596df49d1bbSPeter Collingbourne "Whole program devirtualization", false, false) 597df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0; 598df49d1bbSPeter Collingbourne 599f7691d8bSPeter Collingbourne ModulePass * 600f7691d8bSPeter Collingbourne llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary, 601f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) { 602f7691d8bSPeter Collingbourne return new WholeProgramDevirt(ExportSummary, ImportSummary); 603df49d1bbSPeter Collingbourne } 604df49d1bbSPeter Collingbourne 605164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M, 60637317f12SPeter Collingbourne ModuleAnalysisManager &AM) { 60737317f12SPeter Collingbourne auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 60837317f12SPeter Collingbourne auto AARGetter = [&](Function &F) -> AAResults & { 60937317f12SPeter Collingbourne return FAM.getResult<AAManager>(F); 61037317f12SPeter Collingbourne }; 611e963c89dSSam Elliott auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { 612e963c89dSSam Elliott return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 613e963c89dSSam Elliott }; 614e963c89dSSam Elliott if (!DevirtModule(M, AARGetter, OREGetter, nullptr, nullptr).run()) 615d737dd2eSDavide Italiano return PreservedAnalyses::all(); 616d737dd2eSDavide Italiano return PreservedAnalyses::none(); 617d737dd2eSDavide Italiano } 618d737dd2eSDavide Italiano 61937317f12SPeter Collingbourne bool DevirtModule::runForTesting( 620e963c89dSSam Elliott Module &M, 
function_ref<AAResults &(Function &)> AARGetter, 621e963c89dSSam Elliott function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { 622*4ffc3e78STeresa Johnson ModuleSummaryIndex Summary(/*HaveGVs=*/false); 6232b33f653SPeter Collingbourne 6242b33f653SPeter Collingbourne // Handle the command-line summary arguments. This code is for testing 6252b33f653SPeter Collingbourne // purposes only, so we handle errors directly. 6262b33f653SPeter Collingbourne if (!ClReadSummary.empty()) { 6272b33f653SPeter Collingbourne ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + 6282b33f653SPeter Collingbourne ": "); 6292b33f653SPeter Collingbourne auto ReadSummaryFile = 6302b33f653SPeter Collingbourne ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); 6312b33f653SPeter Collingbourne 6322b33f653SPeter Collingbourne yaml::Input In(ReadSummaryFile->getBuffer()); 6332b33f653SPeter Collingbourne In >> Summary; 6342b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(In.error())); 6352b33f653SPeter Collingbourne } 6362b33f653SPeter Collingbourne 637f7691d8bSPeter Collingbourne bool Changed = 638f7691d8bSPeter Collingbourne DevirtModule( 639e963c89dSSam Elliott M, AARGetter, OREGetter, 640f7691d8bSPeter Collingbourne ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, 641f7691d8bSPeter Collingbourne ClSummaryAction == PassSummaryAction::Import ? 
&Summary : nullptr) 642f7691d8bSPeter Collingbourne .run(); 6432b33f653SPeter Collingbourne 6442b33f653SPeter Collingbourne if (!ClWriteSummary.empty()) { 6452b33f653SPeter Collingbourne ExitOnError ExitOnErr( 6462b33f653SPeter Collingbourne "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); 6472b33f653SPeter Collingbourne std::error_code EC; 6482b33f653SPeter Collingbourne raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); 6492b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(EC)); 6502b33f653SPeter Collingbourne 6512b33f653SPeter Collingbourne yaml::Output Out(OS); 6522b33f653SPeter Collingbourne Out << Summary; 6532b33f653SPeter Collingbourne } 6542b33f653SPeter Collingbourne 6552b33f653SPeter Collingbourne return Changed; 6562b33f653SPeter Collingbourne } 6572b33f653SPeter Collingbourne 6587efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap( 659df49d1bbSPeter Collingbourne std::vector<VTableBits> &Bits, 6607efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { 661df49d1bbSPeter Collingbourne DenseMap<GlobalVariable *, VTableBits *> GVToBits; 6627efd7506SPeter Collingbourne Bits.reserve(M.getGlobalList().size()); 6637efd7506SPeter Collingbourne SmallVector<MDNode *, 2> Types; 6647efd7506SPeter Collingbourne for (GlobalVariable &GV : M.globals()) { 6657efd7506SPeter Collingbourne Types.clear(); 6667efd7506SPeter Collingbourne GV.getMetadata(LLVMContext::MD_type, Types); 6677efd7506SPeter Collingbourne if (Types.empty()) 668df49d1bbSPeter Collingbourne continue; 669df49d1bbSPeter Collingbourne 6707efd7506SPeter Collingbourne VTableBits *&BitsPtr = GVToBits[&GV]; 6717efd7506SPeter Collingbourne if (!BitsPtr) { 6727efd7506SPeter Collingbourne Bits.emplace_back(); 6737efd7506SPeter Collingbourne Bits.back().GV = &GV; 6747efd7506SPeter Collingbourne Bits.back().ObjectSize = 6757efd7506SPeter Collingbourne M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); 6767efd7506SPeter 
Collingbourne BitsPtr = &Bits.back(); 6777efd7506SPeter Collingbourne } 6787efd7506SPeter Collingbourne 6797efd7506SPeter Collingbourne for (MDNode *Type : Types) { 6807efd7506SPeter Collingbourne auto TypeID = Type->getOperand(1).get(); 681df49d1bbSPeter Collingbourne 682df49d1bbSPeter Collingbourne uint64_t Offset = 683df49d1bbSPeter Collingbourne cast<ConstantInt>( 6847efd7506SPeter Collingbourne cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) 685df49d1bbSPeter Collingbourne ->getZExtValue(); 686df49d1bbSPeter Collingbourne 6877efd7506SPeter Collingbourne TypeIdMap[TypeID].insert({BitsPtr, Offset}); 688df49d1bbSPeter Collingbourne } 689df49d1bbSPeter Collingbourne } 690df49d1bbSPeter Collingbourne } 691df49d1bbSPeter Collingbourne 6928786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) { 6938786754cSPeter Collingbourne if (I->getType()->isPointerTy()) { 6948786754cSPeter Collingbourne if (Offset == 0) 6958786754cSPeter Collingbourne return I; 6968786754cSPeter Collingbourne return nullptr; 6978786754cSPeter Collingbourne } 6988786754cSPeter Collingbourne 6997a1e5bbeSPeter Collingbourne const DataLayout &DL = M.getDataLayout(); 7007a1e5bbeSPeter Collingbourne 7017a1e5bbeSPeter Collingbourne if (auto *C = dyn_cast<ConstantStruct>(I)) { 7027a1e5bbeSPeter Collingbourne const StructLayout *SL = DL.getStructLayout(C->getType()); 7037a1e5bbeSPeter Collingbourne if (Offset >= SL->getSizeInBytes()) 7047a1e5bbeSPeter Collingbourne return nullptr; 7057a1e5bbeSPeter Collingbourne 7068786754cSPeter Collingbourne unsigned Op = SL->getElementContainingOffset(Offset); 7078786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 7088786754cSPeter Collingbourne Offset - SL->getElementOffset(Op)); 7098786754cSPeter Collingbourne } 7108786754cSPeter Collingbourne if (auto *C = dyn_cast<ConstantArray>(I)) { 7117a1e5bbeSPeter Collingbourne ArrayType *VTableTy = C->getType(); 7127a1e5bbeSPeter 
Collingbourne uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); 7137a1e5bbeSPeter Collingbourne 7148786754cSPeter Collingbourne unsigned Op = Offset / ElemSize; 7157a1e5bbeSPeter Collingbourne if (Op >= C->getNumOperands()) 7167a1e5bbeSPeter Collingbourne return nullptr; 7177a1e5bbeSPeter Collingbourne 7188786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 7198786754cSPeter Collingbourne Offset % ElemSize); 7208786754cSPeter Collingbourne } 7218786754cSPeter Collingbourne return nullptr; 7227a1e5bbeSPeter Collingbourne } 7237a1e5bbeSPeter Collingbourne 724df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets( 725df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> &TargetsForSlot, 7267efd7506SPeter Collingbourne const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { 7277efd7506SPeter Collingbourne for (const TypeMemberInfo &TM : TypeMemberInfos) { 7287efd7506SPeter Collingbourne if (!TM.Bits->GV->isConstant()) 729df49d1bbSPeter Collingbourne return false; 730df49d1bbSPeter Collingbourne 7318786754cSPeter Collingbourne Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), 7328786754cSPeter Collingbourne TM.Offset + ByteOffset); 7338786754cSPeter Collingbourne if (!Ptr) 734df49d1bbSPeter Collingbourne return false; 735df49d1bbSPeter Collingbourne 7368786754cSPeter Collingbourne auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); 737df49d1bbSPeter Collingbourne if (!Fn) 738df49d1bbSPeter Collingbourne return false; 739df49d1bbSPeter Collingbourne 740df49d1bbSPeter Collingbourne // We can disregard __cxa_pure_virtual as a possible call target, as 741df49d1bbSPeter Collingbourne // calls to pure virtuals are UB. 
742df49d1bbSPeter Collingbourne if (Fn->getName() == "__cxa_pure_virtual") 743df49d1bbSPeter Collingbourne continue; 744df49d1bbSPeter Collingbourne 7457efd7506SPeter Collingbourne TargetsForSlot.push_back({Fn, &TM}); 746df49d1bbSPeter Collingbourne } 747df49d1bbSPeter Collingbourne 748df49d1bbSPeter Collingbourne // Give up if we couldn't find any targets. 749df49d1bbSPeter Collingbourne return !TargetsForSlot.empty(); 750df49d1bbSPeter Collingbourne } 751df49d1bbSPeter Collingbourne 75250cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, 7532325bb34SPeter Collingbourne Constant *TheFn, bool &IsExported) { 75450cbd7ccSPeter Collingbourne auto Apply = [&](CallSiteInfo &CSInfo) { 75550cbd7ccSPeter Collingbourne for (auto &&VCallSite : CSInfo.CallSites) { 756f3403fd2SIvan Krasin if (RemarksEnabled) 757e963c89dSSam Elliott VCallSite.emitRemark("single-impl", TheFn->getName(), OREGetter); 758df49d1bbSPeter Collingbourne VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( 759df49d1bbSPeter Collingbourne TheFn, VCallSite.CS.getCalledValue()->getType())); 7600312f614SPeter Collingbourne // This use is no longer unsafe. 
7610312f614SPeter Collingbourne if (VCallSite.NumUnsafeUses) 7620312f614SPeter Collingbourne --*VCallSite.NumUnsafeUses; 763df49d1bbSPeter Collingbourne } 7642974856aSPeter Collingbourne if (CSInfo.isExported()) 7652325bb34SPeter Collingbourne IsExported = true; 76659675ba0SPeter Collingbourne CSInfo.markDevirt(); 76750cbd7ccSPeter Collingbourne }; 76850cbd7ccSPeter Collingbourne Apply(SlotInfo.CSInfo); 76950cbd7ccSPeter Collingbourne for (auto &P : SlotInfo.ConstCSInfo) 77050cbd7ccSPeter Collingbourne Apply(P.second); 77150cbd7ccSPeter Collingbourne } 77250cbd7ccSPeter Collingbourne 77350cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt( 77450cbd7ccSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 7752325bb34SPeter Collingbourne VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) { 77650cbd7ccSPeter Collingbourne // See if the program contains a single implementation of this virtual 77750cbd7ccSPeter Collingbourne // function. 77850cbd7ccSPeter Collingbourne Function *TheFn = TargetsForSlot[0].Fn; 77950cbd7ccSPeter Collingbourne for (auto &&Target : TargetsForSlot) 78050cbd7ccSPeter Collingbourne if (TheFn != Target.Fn) 78150cbd7ccSPeter Collingbourne return false; 78250cbd7ccSPeter Collingbourne 78350cbd7ccSPeter Collingbourne // If so, update each call site to call that implementation directly. 78450cbd7ccSPeter Collingbourne if (RemarksEnabled) 78550cbd7ccSPeter Collingbourne TargetsForSlot[0].WasDevirt = true; 7862325bb34SPeter Collingbourne 7872325bb34SPeter Collingbourne bool IsExported = false; 7882325bb34SPeter Collingbourne applySingleImplDevirt(SlotInfo, TheFn, IsExported); 7892325bb34SPeter Collingbourne if (!IsExported) 7902325bb34SPeter Collingbourne return false; 7912325bb34SPeter Collingbourne 7922325bb34SPeter Collingbourne // If the only implementation has local linkage, we must promote to external 7932325bb34SPeter Collingbourne // to make it visible to thin LTO objects. 
We can only get here during the 7942325bb34SPeter Collingbourne // ThinLTO export phase. 7952325bb34SPeter Collingbourne if (TheFn->hasLocalLinkage()) { 79688a58cf9SPeter Collingbourne std::string NewName = (TheFn->getName() + "$merged").str(); 79788a58cf9SPeter Collingbourne 79888a58cf9SPeter Collingbourne // Since we are renaming the function, any comdats with the same name must 79988a58cf9SPeter Collingbourne // also be renamed. This is required when targeting COFF, as the comdat name 80088a58cf9SPeter Collingbourne // must match one of the names of the symbols in the comdat. 80188a58cf9SPeter Collingbourne if (Comdat *C = TheFn->getComdat()) { 80288a58cf9SPeter Collingbourne if (C->getName() == TheFn->getName()) { 80388a58cf9SPeter Collingbourne Comdat *NewC = M.getOrInsertComdat(NewName); 80488a58cf9SPeter Collingbourne NewC->setSelectionKind(C->getSelectionKind()); 80588a58cf9SPeter Collingbourne for (GlobalObject &GO : M.global_objects()) 80688a58cf9SPeter Collingbourne if (GO.getComdat() == C) 80788a58cf9SPeter Collingbourne GO.setComdat(NewC); 80888a58cf9SPeter Collingbourne } 80988a58cf9SPeter Collingbourne } 81088a58cf9SPeter Collingbourne 8112325bb34SPeter Collingbourne TheFn->setLinkage(GlobalValue::ExternalLinkage); 8122325bb34SPeter Collingbourne TheFn->setVisibility(GlobalValue::HiddenVisibility); 81388a58cf9SPeter Collingbourne TheFn->setName(NewName); 8142325bb34SPeter Collingbourne } 8152325bb34SPeter Collingbourne 8162325bb34SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::SingleImpl; 8172325bb34SPeter Collingbourne Res->SingleImplName = TheFn->getName(); 8182325bb34SPeter Collingbourne 819df49d1bbSPeter Collingbourne return true; 820df49d1bbSPeter Collingbourne } 821df49d1bbSPeter Collingbourne 8222974856aSPeter Collingbourne void DevirtModule::tryICallBranchFunnel( 8232974856aSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 8242974856aSPeter Collingbourne 
WholeProgramDevirtResolution *Res, VTableSlot Slot) { 8252974856aSPeter Collingbourne Triple T(M.getTargetTriple()); 8262974856aSPeter Collingbourne if (T.getArch() != Triple::x86_64) 8272974856aSPeter Collingbourne return; 8282974856aSPeter Collingbourne 82966f53d71SVitaly Buka if (TargetsForSlot.size() > ClThreshold) 8302974856aSPeter Collingbourne return; 8312974856aSPeter Collingbourne 8322974856aSPeter Collingbourne bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted; 8332974856aSPeter Collingbourne if (!HasNonDevirt) 8342974856aSPeter Collingbourne for (auto &P : SlotInfo.ConstCSInfo) 8352974856aSPeter Collingbourne if (!P.second.AllCallSitesDevirted) { 8362974856aSPeter Collingbourne HasNonDevirt = true; 8372974856aSPeter Collingbourne break; 8382974856aSPeter Collingbourne } 8392974856aSPeter Collingbourne 8402974856aSPeter Collingbourne if (!HasNonDevirt) 8412974856aSPeter Collingbourne return; 8422974856aSPeter Collingbourne 8432974856aSPeter Collingbourne FunctionType *FT = 8442974856aSPeter Collingbourne FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true); 8452974856aSPeter Collingbourne Function *JT; 8462974856aSPeter Collingbourne if (isa<MDString>(Slot.TypeID)) { 8472974856aSPeter Collingbourne JT = Function::Create(FT, Function::ExternalLinkage, 8482974856aSPeter Collingbourne getGlobalName(Slot, {}, "branch_funnel"), &M); 8492974856aSPeter Collingbourne JT->setVisibility(GlobalValue::HiddenVisibility); 8502974856aSPeter Collingbourne } else { 8512974856aSPeter Collingbourne JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M); 8522974856aSPeter Collingbourne } 8532974856aSPeter Collingbourne JT->addAttribute(1, Attribute::Nest); 8542974856aSPeter Collingbourne 8552974856aSPeter Collingbourne std::vector<Value *> JTArgs; 8562974856aSPeter Collingbourne JTArgs.push_back(JT->arg_begin()); 8572974856aSPeter Collingbourne for (auto &T : TargetsForSlot) { 8582974856aSPeter Collingbourne 
JTArgs.push_back(getMemberAddr(T.TM)); 8592974856aSPeter Collingbourne JTArgs.push_back(T.Fn); 8602974856aSPeter Collingbourne } 8612974856aSPeter Collingbourne 8622974856aSPeter Collingbourne BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr); 8632974856aSPeter Collingbourne Constant *Intr = 8642974856aSPeter Collingbourne Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {}); 8652974856aSPeter Collingbourne 8662974856aSPeter Collingbourne auto *CI = CallInst::Create(Intr, JTArgs, "", BB); 8672974856aSPeter Collingbourne CI->setTailCallKind(CallInst::TCK_MustTail); 8682974856aSPeter Collingbourne ReturnInst::Create(M.getContext(), nullptr, BB); 8692974856aSPeter Collingbourne 8702974856aSPeter Collingbourne bool IsExported = false; 8712974856aSPeter Collingbourne applyICallBranchFunnel(SlotInfo, JT, IsExported); 8722974856aSPeter Collingbourne if (IsExported) 8732974856aSPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::BranchFunnel; 8742974856aSPeter Collingbourne } 8752974856aSPeter Collingbourne 8762974856aSPeter Collingbourne void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, 8772974856aSPeter Collingbourne Constant *JT, bool &IsExported) { 8782974856aSPeter Collingbourne auto Apply = [&](CallSiteInfo &CSInfo) { 8792974856aSPeter Collingbourne if (CSInfo.isExported()) 8802974856aSPeter Collingbourne IsExported = true; 8812974856aSPeter Collingbourne if (CSInfo.AllCallSitesDevirted) 8822974856aSPeter Collingbourne return; 8832974856aSPeter Collingbourne for (auto &&VCallSite : CSInfo.CallSites) { 8842974856aSPeter Collingbourne CallSite CS = VCallSite.CS; 8852974856aSPeter Collingbourne 8862974856aSPeter Collingbourne // Jump tables are only profitable if the retpoline mitigation is enabled. 
8872974856aSPeter Collingbourne Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features"); 8882974856aSPeter Collingbourne if (FSAttr.hasAttribute(Attribute::None) || 8892974856aSPeter Collingbourne !FSAttr.getValueAsString().contains("+retpoline")) 8902974856aSPeter Collingbourne continue; 8912974856aSPeter Collingbourne 8922974856aSPeter Collingbourne if (RemarksEnabled) 8932974856aSPeter Collingbourne VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter); 8942974856aSPeter Collingbourne 8952974856aSPeter Collingbourne // Pass the address of the vtable in the nest register, which is r10 on 8962974856aSPeter Collingbourne // x86_64. 8972974856aSPeter Collingbourne std::vector<Type *> NewArgs; 8982974856aSPeter Collingbourne NewArgs.push_back(Int8PtrTy); 8992974856aSPeter Collingbourne for (Type *T : CS.getFunctionType()->params()) 9002974856aSPeter Collingbourne NewArgs.push_back(T); 9012974856aSPeter Collingbourne PointerType *NewFT = PointerType::getUnqual( 9022974856aSPeter Collingbourne FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs, 9032974856aSPeter Collingbourne CS.getFunctionType()->isVarArg())); 9042974856aSPeter Collingbourne 9052974856aSPeter Collingbourne IRBuilder<> IRB(CS.getInstruction()); 9062974856aSPeter Collingbourne std::vector<Value *> Args; 9072974856aSPeter Collingbourne Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); 9082974856aSPeter Collingbourne for (unsigned I = 0; I != CS.getNumArgOperands(); ++I) 9092974856aSPeter Collingbourne Args.push_back(CS.getArgOperand(I)); 9102974856aSPeter Collingbourne 9112974856aSPeter Collingbourne CallSite NewCS; 9122974856aSPeter Collingbourne if (CS.isCall()) 9132974856aSPeter Collingbourne NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args); 9142974856aSPeter Collingbourne else 9152974856aSPeter Collingbourne NewCS = IRB.CreateInvoke( 9162974856aSPeter Collingbourne IRB.CreateBitCast(JT, NewFT), 9172974856aSPeter Collingbourne 
cast<InvokeInst>(CS.getInstruction())->getNormalDest(), 9182974856aSPeter Collingbourne cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args); 9192974856aSPeter Collingbourne NewCS.setCallingConv(CS.getCallingConv()); 9202974856aSPeter Collingbourne 9212974856aSPeter Collingbourne AttributeList Attrs = CS.getAttributes(); 9222974856aSPeter Collingbourne std::vector<AttributeSet> NewArgAttrs; 9232974856aSPeter Collingbourne NewArgAttrs.push_back(AttributeSet::get( 9242974856aSPeter Collingbourne M.getContext(), ArrayRef<Attribute>{Attribute::get( 9252974856aSPeter Collingbourne M.getContext(), Attribute::Nest)})); 9262974856aSPeter Collingbourne for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) 9272974856aSPeter Collingbourne NewArgAttrs.push_back(Attrs.getParamAttributes(I)); 9282974856aSPeter Collingbourne NewCS.setAttributes( 9292974856aSPeter Collingbourne AttributeList::get(M.getContext(), Attrs.getFnAttributes(), 9302974856aSPeter Collingbourne Attrs.getRetAttributes(), NewArgAttrs)); 9312974856aSPeter Collingbourne 9322974856aSPeter Collingbourne CS->replaceAllUsesWith(NewCS.getInstruction()); 9332974856aSPeter Collingbourne CS->eraseFromParent(); 9342974856aSPeter Collingbourne 9352974856aSPeter Collingbourne // This use is no longer unsafe. 9362974856aSPeter Collingbourne if (VCallSite.NumUnsafeUses) 9372974856aSPeter Collingbourne --*VCallSite.NumUnsafeUses; 9382974856aSPeter Collingbourne } 9392974856aSPeter Collingbourne // Don't mark as devirtualized because there may be callers compiled without 9402974856aSPeter Collingbourne // retpoline mitigation, which would mean that they are lowered to 9412974856aSPeter Collingbourne // llvm.type.test and therefore require an llvm.type.test resolution for the 9422974856aSPeter Collingbourne // type identifier. 
9432974856aSPeter Collingbourne }; 9442974856aSPeter Collingbourne Apply(SlotInfo.CSInfo); 9452974856aSPeter Collingbourne for (auto &P : SlotInfo.ConstCSInfo) 9462974856aSPeter Collingbourne Apply(P.second); 9472974856aSPeter Collingbourne } 9482974856aSPeter Collingbourne 949df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs( 950df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 95150cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args) { 952df49d1bbSPeter Collingbourne // Evaluate each function and store the result in each target's RetVal 953df49d1bbSPeter Collingbourne // field. 954df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 955df49d1bbSPeter Collingbourne if (Target.Fn->arg_size() != Args.size() + 1) 956df49d1bbSPeter Collingbourne return false; 957df49d1bbSPeter Collingbourne 958df49d1bbSPeter Collingbourne Evaluator Eval(M.getDataLayout(), nullptr); 959df49d1bbSPeter Collingbourne SmallVector<Constant *, 2> EvalArgs; 960df49d1bbSPeter Collingbourne EvalArgs.push_back( 961df49d1bbSPeter Collingbourne Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); 96250cbd7ccSPeter Collingbourne for (unsigned I = 0; I != Args.size(); ++I) { 96350cbd7ccSPeter Collingbourne auto *ArgTy = dyn_cast<IntegerType>( 96450cbd7ccSPeter Collingbourne Target.Fn->getFunctionType()->getParamType(I + 1)); 96550cbd7ccSPeter Collingbourne if (!ArgTy) 96650cbd7ccSPeter Collingbourne return false; 96750cbd7ccSPeter Collingbourne EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); 96850cbd7ccSPeter Collingbourne } 96950cbd7ccSPeter Collingbourne 970df49d1bbSPeter Collingbourne Constant *RetVal; 971df49d1bbSPeter Collingbourne if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || 972df49d1bbSPeter Collingbourne !isa<ConstantInt>(RetVal)) 973df49d1bbSPeter Collingbourne return false; 974df49d1bbSPeter Collingbourne Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); 
975df49d1bbSPeter Collingbourne } 976df49d1bbSPeter Collingbourne return true; 977df49d1bbSPeter Collingbourne } 978df49d1bbSPeter Collingbourne 97950cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 98050cbd7ccSPeter Collingbourne uint64_t TheRetVal) { 98150cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) 98250cbd7ccSPeter Collingbourne Call.replaceAndErase( 983e963c89dSSam Elliott "uniform-ret-val", FnName, RemarksEnabled, OREGetter, 98450cbd7ccSPeter Collingbourne ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); 98559675ba0SPeter Collingbourne CSInfo.markDevirt(); 98650cbd7ccSPeter Collingbourne } 98750cbd7ccSPeter Collingbourne 988df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt( 98977a8d563SPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, 99077a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res) { 991df49d1bbSPeter Collingbourne // Uniform return value optimization. If all functions return the same 992df49d1bbSPeter Collingbourne // constant, replace all calls with that constant. 
993df49d1bbSPeter Collingbourne uint64_t TheRetVal = TargetsForSlot[0].RetVal; 994df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) 995df49d1bbSPeter Collingbourne if (Target.RetVal != TheRetVal) 996df49d1bbSPeter Collingbourne return false; 997df49d1bbSPeter Collingbourne 99877a8d563SPeter Collingbourne if (CSInfo.isExported()) { 99977a8d563SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; 100077a8d563SPeter Collingbourne Res->Info = TheRetVal; 100177a8d563SPeter Collingbourne } 100277a8d563SPeter Collingbourne 100350cbd7ccSPeter Collingbourne applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); 1004f3403fd2SIvan Krasin if (RemarksEnabled) 1005f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 1006f3403fd2SIvan Krasin Target.WasDevirt = true; 1007df49d1bbSPeter Collingbourne return true; 1008df49d1bbSPeter Collingbourne } 1009df49d1bbSPeter Collingbourne 101059675ba0SPeter Collingbourne std::string DevirtModule::getGlobalName(VTableSlot Slot, 101159675ba0SPeter Collingbourne ArrayRef<uint64_t> Args, 101259675ba0SPeter Collingbourne StringRef Name) { 101359675ba0SPeter Collingbourne std::string FullName = "__typeid_"; 101459675ba0SPeter Collingbourne raw_string_ostream OS(FullName); 101559675ba0SPeter Collingbourne OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; 101659675ba0SPeter Collingbourne for (uint64_t Arg : Args) 101759675ba0SPeter Collingbourne OS << '_' << Arg; 101859675ba0SPeter Collingbourne OS << '_' << Name; 101959675ba0SPeter Collingbourne return OS.str(); 102059675ba0SPeter Collingbourne } 102159675ba0SPeter Collingbourne 1022b15a35e6SPeter Collingbourne bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() { 1023b15a35e6SPeter Collingbourne Triple T(M.getTargetTriple()); 1024b15a35e6SPeter Collingbourne return (T.getArch() == Triple::x86 || T.getArch() == Triple::x86_64) && 1025b15a35e6SPeter Collingbourne 
T.getObjectFormat() == Triple::ELF; 1026b15a35e6SPeter Collingbourne } 1027b15a35e6SPeter Collingbourne 102859675ba0SPeter Collingbourne void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 102959675ba0SPeter Collingbourne StringRef Name, Constant *C) { 103059675ba0SPeter Collingbourne GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, 103159675ba0SPeter Collingbourne getGlobalName(Slot, Args, Name), C, &M); 103259675ba0SPeter Collingbourne GA->setVisibility(GlobalValue::HiddenVisibility); 103359675ba0SPeter Collingbourne } 103459675ba0SPeter Collingbourne 1035b15a35e6SPeter Collingbourne void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 1036b15a35e6SPeter Collingbourne StringRef Name, uint32_t Const, 1037b15a35e6SPeter Collingbourne uint32_t &Storage) { 1038b15a35e6SPeter Collingbourne if (shouldExportConstantsAsAbsoluteSymbols()) { 1039b15a35e6SPeter Collingbourne exportGlobal( 1040b15a35e6SPeter Collingbourne Slot, Args, Name, 1041b15a35e6SPeter Collingbourne ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy)); 1042b15a35e6SPeter Collingbourne return; 1043b15a35e6SPeter Collingbourne } 1044b15a35e6SPeter Collingbourne 1045b15a35e6SPeter Collingbourne Storage = Const; 1046b15a35e6SPeter Collingbourne } 1047b15a35e6SPeter Collingbourne 104859675ba0SPeter Collingbourne Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 1049b15a35e6SPeter Collingbourne StringRef Name) { 105059675ba0SPeter Collingbourne Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); 105159675ba0SPeter Collingbourne auto *GV = dyn_cast<GlobalVariable>(C); 1052b15a35e6SPeter Collingbourne if (GV) 1053b15a35e6SPeter Collingbourne GV->setVisibility(GlobalValue::HiddenVisibility); 1054b15a35e6SPeter Collingbourne return C; 1055b15a35e6SPeter Collingbourne } 1056b15a35e6SPeter Collingbourne 1057b15a35e6SPeter Collingbourne Constant 
*DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 1058b15a35e6SPeter Collingbourne StringRef Name, IntegerType *IntTy, 1059b15a35e6SPeter Collingbourne uint32_t Storage) { 1060b15a35e6SPeter Collingbourne if (!shouldExportConstantsAsAbsoluteSymbols()) 1061b15a35e6SPeter Collingbourne return ConstantInt::get(IntTy, Storage); 1062b15a35e6SPeter Collingbourne 1063b15a35e6SPeter Collingbourne Constant *C = importGlobal(Slot, Args, Name); 1064b15a35e6SPeter Collingbourne auto *GV = cast<GlobalVariable>(C->stripPointerCasts()); 1065b15a35e6SPeter Collingbourne C = ConstantExpr::getPtrToInt(C, IntTy); 1066b15a35e6SPeter Collingbourne 106714dcf02fSPeter Collingbourne // We only need to set metadata if the global is newly created, in which 106814dcf02fSPeter Collingbourne // case it would not have hidden visibility. 10690deb9a9aSBenjamin Kramer if (GV->hasMetadata(LLVMContext::MD_absolute_symbol)) 107059675ba0SPeter Collingbourne return C; 107114dcf02fSPeter Collingbourne 107214dcf02fSPeter Collingbourne auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { 107314dcf02fSPeter Collingbourne auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); 107414dcf02fSPeter Collingbourne auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); 107514dcf02fSPeter Collingbourne GV->setMetadata(LLVMContext::MD_absolute_symbol, 107614dcf02fSPeter Collingbourne MDNode::get(M.getContext(), {MinC, MaxC})); 107714dcf02fSPeter Collingbourne }; 1078b15a35e6SPeter Collingbourne unsigned AbsWidth = IntTy->getBitWidth(); 107914dcf02fSPeter Collingbourne if (AbsWidth == IntPtrTy->getBitWidth()) 108014dcf02fSPeter Collingbourne SetAbsRange(~0ull, ~0ull); // Full set. 
1081b15a35e6SPeter Collingbourne else 108214dcf02fSPeter Collingbourne SetAbsRange(0, 1ull << AbsWidth); 1083b15a35e6SPeter Collingbourne return C; 108459675ba0SPeter Collingbourne } 108559675ba0SPeter Collingbourne 108650cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 108750cbd7ccSPeter Collingbourne bool IsOne, 108850cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr) { 108950cbd7ccSPeter Collingbourne for (auto &&Call : CSInfo.CallSites) { 109050cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 1091001052a0SPeter Collingbourne Value *Cmp = 1092001052a0SPeter Collingbourne B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, 1093001052a0SPeter Collingbourne B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr); 109450cbd7ccSPeter Collingbourne Cmp = B.CreateZExt(Cmp, Call.CS->getType()); 1095e963c89dSSam Elliott Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter, 1096e963c89dSSam Elliott Cmp); 109750cbd7ccSPeter Collingbourne } 109859675ba0SPeter Collingbourne CSInfo.markDevirt(); 109950cbd7ccSPeter Collingbourne } 110050cbd7ccSPeter Collingbourne 11012974856aSPeter Collingbourne Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { 11022974856aSPeter Collingbourne Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); 11032974856aSPeter Collingbourne return ConstantExpr::getGetElementPtr(Int8Ty, C, 11042974856aSPeter Collingbourne ConstantInt::get(Int64Ty, M->Offset)); 11052974856aSPeter Collingbourne } 11062974856aSPeter Collingbourne 1107df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt( 1108f3403fd2SIvan Krasin unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, 110959675ba0SPeter Collingbourne CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, 111059675ba0SPeter Collingbourne VTableSlot Slot, ArrayRef<uint64_t> Args) { 1111df49d1bbSPeter Collingbourne // IsOne controls whether we 
look for a 0 or a 1. 1112df49d1bbSPeter Collingbourne auto tryUniqueRetValOptFor = [&](bool IsOne) { 1113cdc71612SEugene Zelenko const TypeMemberInfo *UniqueMember = nullptr; 1114df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) { 11153866cc5fSPeter Collingbourne if (Target.RetVal == (IsOne ? 1 : 0)) { 11167efd7506SPeter Collingbourne if (UniqueMember) 1117df49d1bbSPeter Collingbourne return false; 11187efd7506SPeter Collingbourne UniqueMember = Target.TM; 1119df49d1bbSPeter Collingbourne } 1120df49d1bbSPeter Collingbourne } 1121df49d1bbSPeter Collingbourne 11227efd7506SPeter Collingbourne // We should have found a unique member or bailed out by now. We already 1123df49d1bbSPeter Collingbourne // checked for a uniform return value in tryUniformRetValOpt. 11247efd7506SPeter Collingbourne assert(UniqueMember); 1125df49d1bbSPeter Collingbourne 11262974856aSPeter Collingbourne Constant *UniqueMemberAddr = getMemberAddr(UniqueMember); 112759675ba0SPeter Collingbourne if (CSInfo.isExported()) { 112859675ba0SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; 112959675ba0SPeter Collingbourne Res->Info = IsOne; 113059675ba0SPeter Collingbourne 113159675ba0SPeter Collingbourne exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr); 113259675ba0SPeter Collingbourne } 113359675ba0SPeter Collingbourne 113459675ba0SPeter Collingbourne // Replace each call with the comparison. 113550cbd7ccSPeter Collingbourne applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, 113650cbd7ccSPeter Collingbourne UniqueMemberAddr); 113750cbd7ccSPeter Collingbourne 1138f3403fd2SIvan Krasin // Update devirtualization statistics for targets. 
1139f3403fd2SIvan Krasin if (RemarksEnabled) 1140f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 1141f3403fd2SIvan Krasin Target.WasDevirt = true; 1142f3403fd2SIvan Krasin 1143df49d1bbSPeter Collingbourne return true; 1144df49d1bbSPeter Collingbourne }; 1145df49d1bbSPeter Collingbourne 1146df49d1bbSPeter Collingbourne if (BitWidth == 1) { 1147df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(true)) 1148df49d1bbSPeter Collingbourne return true; 1149df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(false)) 1150df49d1bbSPeter Collingbourne return true; 1151df49d1bbSPeter Collingbourne } 1152df49d1bbSPeter Collingbourne return false; 1153df49d1bbSPeter Collingbourne } 1154df49d1bbSPeter Collingbourne 115550cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 115650cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit) { 115750cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) { 115850cbd7ccSPeter Collingbourne auto *RetType = cast<IntegerType>(Call.CS.getType()); 115950cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 1160001052a0SPeter Collingbourne Value *Addr = 1161001052a0SPeter Collingbourne B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); 116250cbd7ccSPeter Collingbourne if (RetType->getBitWidth() == 1) { 116350cbd7ccSPeter Collingbourne Value *Bits = B.CreateLoad(Addr); 116450cbd7ccSPeter Collingbourne Value *BitsAndBit = B.CreateAnd(Bits, Bit); 116550cbd7ccSPeter Collingbourne auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); 116650cbd7ccSPeter Collingbourne Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, 1167e963c89dSSam Elliott OREGetter, IsBitSet); 116850cbd7ccSPeter Collingbourne } else { 116950cbd7ccSPeter Collingbourne Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); 117050cbd7ccSPeter Collingbourne Value *Val = B.CreateLoad(RetType, ValAddr); 1171e963c89dSSam Elliott 
Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, 1172e963c89dSSam Elliott OREGetter, Val); 117350cbd7ccSPeter Collingbourne } 117450cbd7ccSPeter Collingbourne } 117514dcf02fSPeter Collingbourne CSInfo.markDevirt(); 117650cbd7ccSPeter Collingbourne } 117750cbd7ccSPeter Collingbourne 1178df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp( 117959675ba0SPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 118059675ba0SPeter Collingbourne WholeProgramDevirtResolution *Res, VTableSlot Slot) { 1181df49d1bbSPeter Collingbourne // This only works if the function returns an integer. 1182df49d1bbSPeter Collingbourne auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); 1183df49d1bbSPeter Collingbourne if (!RetType) 1184df49d1bbSPeter Collingbourne return false; 1185df49d1bbSPeter Collingbourne unsigned BitWidth = RetType->getBitWidth(); 1186df49d1bbSPeter Collingbourne if (BitWidth > 64) 1187df49d1bbSPeter Collingbourne return false; 1188df49d1bbSPeter Collingbourne 118917febdbbSPeter Collingbourne // Make sure that each function is defined, does not access memory, takes at 119017febdbbSPeter Collingbourne // least one argument, does not use its first argument (which we assume is 119117febdbbSPeter Collingbourne // 'this'), and has the same return type. 119237317f12SPeter Collingbourne // 119337317f12SPeter Collingbourne // Note that we test whether this copy of the function is readnone, rather 119437317f12SPeter Collingbourne // than testing function attributes, which must hold for any copy of the 119537317f12SPeter Collingbourne // function, even a less optimized version substituted at link time. 
This is 119637317f12SPeter Collingbourne // sound because the virtual constant propagation optimizations effectively 119737317f12SPeter Collingbourne // inline all implementations of the virtual function into each call site, 119837317f12SPeter Collingbourne // rather than using function attributes to perform local optimization. 1199df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 120037317f12SPeter Collingbourne if (Target.Fn->isDeclaration() || 120137317f12SPeter Collingbourne computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != 120237317f12SPeter Collingbourne MAK_ReadNone || 120317febdbbSPeter Collingbourne Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || 1204df49d1bbSPeter Collingbourne Target.Fn->getReturnType() != RetType) 1205df49d1bbSPeter Collingbourne return false; 1206df49d1bbSPeter Collingbourne } 1207df49d1bbSPeter Collingbourne 120850cbd7ccSPeter Collingbourne for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { 1209df49d1bbSPeter Collingbourne if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) 1210df49d1bbSPeter Collingbourne continue; 1211df49d1bbSPeter Collingbourne 121277a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; 121377a8d563SPeter Collingbourne if (Res) 121477a8d563SPeter Collingbourne ResByArg = &Res->ResByArg[CSByConstantArg.first]; 121577a8d563SPeter Collingbourne 121677a8d563SPeter Collingbourne if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg)) 1217df49d1bbSPeter Collingbourne continue; 1218df49d1bbSPeter Collingbourne 121959675ba0SPeter Collingbourne if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, 122059675ba0SPeter Collingbourne ResByArg, Slot, CSByConstantArg.first)) 1221df49d1bbSPeter Collingbourne continue; 1222df49d1bbSPeter Collingbourne 12237efd7506SPeter Collingbourne // Find an allocation offset in bits in all vtables associated with the 12247efd7506SPeter 
Collingbourne // type. 1225df49d1bbSPeter Collingbourne uint64_t AllocBefore = 1226df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); 1227df49d1bbSPeter Collingbourne uint64_t AllocAfter = 1228df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth); 1229df49d1bbSPeter Collingbourne 1230df49d1bbSPeter Collingbourne // Calculate the total amount of padding needed to store a value at both 1231df49d1bbSPeter Collingbourne // ends of the object. 1232df49d1bbSPeter Collingbourne uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0; 1233df49d1bbSPeter Collingbourne for (auto &&Target : TargetsForSlot) { 1234df49d1bbSPeter Collingbourne TotalPaddingBefore += std::max<int64_t>( 1235df49d1bbSPeter Collingbourne (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0); 1236df49d1bbSPeter Collingbourne TotalPaddingAfter += std::max<int64_t>( 1237df49d1bbSPeter Collingbourne (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0); 1238df49d1bbSPeter Collingbourne } 1239df49d1bbSPeter Collingbourne 1240df49d1bbSPeter Collingbourne // If the amount of padding is too large, give up. 1241df49d1bbSPeter Collingbourne // FIXME: do something smarter here. 1242df49d1bbSPeter Collingbourne if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128) 1243df49d1bbSPeter Collingbourne continue; 1244df49d1bbSPeter Collingbourne 1245df49d1bbSPeter Collingbourne // Calculate the offset to the value as a (possibly negative) byte offset 1246df49d1bbSPeter Collingbourne // and (if applicable) a bit offset, and store the values in the targets. 
1247df49d1bbSPeter Collingbourne int64_t OffsetByte; 1248df49d1bbSPeter Collingbourne uint64_t OffsetBit; 1249df49d1bbSPeter Collingbourne if (TotalPaddingBefore <= TotalPaddingAfter) 1250df49d1bbSPeter Collingbourne setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte, 1251df49d1bbSPeter Collingbourne OffsetBit); 1252df49d1bbSPeter Collingbourne else 1253df49d1bbSPeter Collingbourne setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, 1254df49d1bbSPeter Collingbourne OffsetBit); 1255df49d1bbSPeter Collingbourne 1256f3403fd2SIvan Krasin if (RemarksEnabled) 1257f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 1258f3403fd2SIvan Krasin Target.WasDevirt = true; 1259f3403fd2SIvan Krasin 126014dcf02fSPeter Collingbourne 126114dcf02fSPeter Collingbourne if (CSByConstantArg.second.isExported()) { 126214dcf02fSPeter Collingbourne ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; 1263b15a35e6SPeter Collingbourne exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte, 1264b15a35e6SPeter Collingbourne ResByArg->Byte); 1265b15a35e6SPeter Collingbourne exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit, 1266b15a35e6SPeter Collingbourne ResByArg->Bit); 126714dcf02fSPeter Collingbourne } 126814dcf02fSPeter Collingbourne 126914dcf02fSPeter Collingbourne // Rewrite each call to a load from OffsetByte/OffsetBit. 
1270b15a35e6SPeter Collingbourne Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); 1271b15a35e6SPeter Collingbourne Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); 127250cbd7ccSPeter Collingbourne applyVirtualConstProp(CSByConstantArg.second, 127350cbd7ccSPeter Collingbourne TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); 1274df49d1bbSPeter Collingbourne } 1275df49d1bbSPeter Collingbourne return true; 1276df49d1bbSPeter Collingbourne } 1277df49d1bbSPeter Collingbourne 1278df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) { 1279df49d1bbSPeter Collingbourne if (B.Before.Bytes.empty() && B.After.Bytes.empty()) 1280df49d1bbSPeter Collingbourne return; 1281df49d1bbSPeter Collingbourne 1282df49d1bbSPeter Collingbourne // Align each byte array to pointer width. 1283df49d1bbSPeter Collingbourne unsigned PointerSize = M.getDataLayout().getPointerSize(); 1284df49d1bbSPeter Collingbourne B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize)); 1285df49d1bbSPeter Collingbourne B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize)); 1286df49d1bbSPeter Collingbourne 1287df49d1bbSPeter Collingbourne // Before was stored in reverse order; flip it now. 1288df49d1bbSPeter Collingbourne for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) 1289df49d1bbSPeter Collingbourne std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]); 1290df49d1bbSPeter Collingbourne 1291df49d1bbSPeter Collingbourne // Build an anonymous global containing the before bytes, followed by the 1292df49d1bbSPeter Collingbourne // original initializer, followed by the after bytes. 
1293df49d1bbSPeter Collingbourne auto NewInit = ConstantStruct::getAnon( 1294df49d1bbSPeter Collingbourne {ConstantDataArray::get(M.getContext(), B.Before.Bytes), 1295df49d1bbSPeter Collingbourne B.GV->getInitializer(), 1296df49d1bbSPeter Collingbourne ConstantDataArray::get(M.getContext(), B.After.Bytes)}); 1297df49d1bbSPeter Collingbourne auto NewGV = 1298df49d1bbSPeter Collingbourne new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(), 1299df49d1bbSPeter Collingbourne GlobalVariable::PrivateLinkage, NewInit, "", B.GV); 1300df49d1bbSPeter Collingbourne NewGV->setSection(B.GV->getSection()); 1301df49d1bbSPeter Collingbourne NewGV->setComdat(B.GV->getComdat()); 1302df49d1bbSPeter Collingbourne 13030312f614SPeter Collingbourne // Copy the original vtable's metadata to the anonymous global, adjusting 13040312f614SPeter Collingbourne // offsets as required. 13050312f614SPeter Collingbourne NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); 13060312f614SPeter Collingbourne 1307df49d1bbSPeter Collingbourne // Build an alias named after the original global, pointing at the second 1308df49d1bbSPeter Collingbourne // element (the original initializer). 
1309df49d1bbSPeter Collingbourne auto Alias = GlobalAlias::create( 1310df49d1bbSPeter Collingbourne B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", 1311df49d1bbSPeter Collingbourne ConstantExpr::getGetElementPtr( 1312df49d1bbSPeter Collingbourne NewInit->getType(), NewGV, 1313df49d1bbSPeter Collingbourne ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), 1314df49d1bbSPeter Collingbourne ConstantInt::get(Int32Ty, 1)}), 1315df49d1bbSPeter Collingbourne &M); 1316df49d1bbSPeter Collingbourne Alias->setVisibility(B.GV->getVisibility()); 1317df49d1bbSPeter Collingbourne Alias->takeName(B.GV); 1318df49d1bbSPeter Collingbourne 1319df49d1bbSPeter Collingbourne B.GV->replaceAllUsesWith(Alias); 1320df49d1bbSPeter Collingbourne B.GV->eraseFromParent(); 1321df49d1bbSPeter Collingbourne } 1322df49d1bbSPeter Collingbourne 1323f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() { 1324f3403fd2SIvan Krasin const auto &FL = M.getFunctionList(); 1325f3403fd2SIvan Krasin if (FL.empty()) 1326f3403fd2SIvan Krasin return false; 1327f3403fd2SIvan Krasin const Function &Fn = FL.front(); 1328de53bfb9SAdam Nemet 1329de53bfb9SAdam Nemet const auto &BBL = Fn.getBasicBlockList(); 1330de53bfb9SAdam Nemet if (BBL.empty()) 1331de53bfb9SAdam Nemet return false; 1332de53bfb9SAdam Nemet auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); 1333f3403fd2SIvan Krasin return DI.isEnabled(); 1334f3403fd2SIvan Krasin } 1335f3403fd2SIvan Krasin 13360312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, 13370312f614SPeter Collingbourne Function *AssumeFunc) { 1338df49d1bbSPeter Collingbourne // Find all virtual calls via a virtual table pointer %p under an assumption 13397efd7506SPeter Collingbourne // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p 13407efd7506SPeter Collingbourne // points to a member of the type identifier %md. 
Group calls by (type ID, 13417efd7506SPeter Collingbourne // offset) pair (effectively the identity of the virtual function) and store 13427efd7506SPeter Collingbourne // to CallSlots. 1343df49d1bbSPeter Collingbourne DenseSet<Value *> SeenPtrs; 13447efd7506SPeter Collingbourne for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); 1345df49d1bbSPeter Collingbourne I != E;) { 1346df49d1bbSPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 1347df49d1bbSPeter Collingbourne ++I; 1348df49d1bbSPeter Collingbourne if (!CI) 1349df49d1bbSPeter Collingbourne continue; 1350df49d1bbSPeter Collingbourne 1351ccdc225cSPeter Collingbourne // Search for virtual calls based on %p and add them to DevirtCalls. 1352ccdc225cSPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 1353df49d1bbSPeter Collingbourne SmallVector<CallInst *, 1> Assumes; 13540312f614SPeter Collingbourne findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); 1355df49d1bbSPeter Collingbourne 1356ccdc225cSPeter Collingbourne // If we found any, add them to CallSlots. Only do this if we haven't seen 1357ccdc225cSPeter Collingbourne // the vtable pointer before, as it may have been CSE'd with pointers from 1358ccdc225cSPeter Collingbourne // other call sites, and we don't want to process call sites multiple times. 
1359df49d1bbSPeter Collingbourne if (!Assumes.empty()) { 13607efd7506SPeter Collingbourne Metadata *TypeId = 1361df49d1bbSPeter Collingbourne cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); 1362df49d1bbSPeter Collingbourne Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); 1363ccdc225cSPeter Collingbourne if (SeenPtrs.insert(Ptr).second) { 1364ccdc225cSPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 1365001052a0SPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr); 1366ccdc225cSPeter Collingbourne } 1367ccdc225cSPeter Collingbourne } 1368df49d1bbSPeter Collingbourne } 1369df49d1bbSPeter Collingbourne 13707efd7506SPeter Collingbourne // We no longer need the assumes or the type test. 1371df49d1bbSPeter Collingbourne for (auto Assume : Assumes) 1372df49d1bbSPeter Collingbourne Assume->eraseFromParent(); 1373df49d1bbSPeter Collingbourne // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we 1374df49d1bbSPeter Collingbourne // may use the vtable argument later. 
1375df49d1bbSPeter Collingbourne if (CI->use_empty()) 1376df49d1bbSPeter Collingbourne CI->eraseFromParent(); 1377df49d1bbSPeter Collingbourne } 13780312f614SPeter Collingbourne } 13790312f614SPeter Collingbourne 13800312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { 13810312f614SPeter Collingbourne Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); 13820312f614SPeter Collingbourne 13830312f614SPeter Collingbourne for (auto I = TypeCheckedLoadFunc->use_begin(), 13840312f614SPeter Collingbourne E = TypeCheckedLoadFunc->use_end(); 13850312f614SPeter Collingbourne I != E;) { 13860312f614SPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 13870312f614SPeter Collingbourne ++I; 13880312f614SPeter Collingbourne if (!CI) 13890312f614SPeter Collingbourne continue; 13900312f614SPeter Collingbourne 13910312f614SPeter Collingbourne Value *Ptr = CI->getArgOperand(0); 13920312f614SPeter Collingbourne Value *Offset = CI->getArgOperand(1); 13930312f614SPeter Collingbourne Value *TypeIdValue = CI->getArgOperand(2); 13940312f614SPeter Collingbourne Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); 13950312f614SPeter Collingbourne 13960312f614SPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 13970312f614SPeter Collingbourne SmallVector<Instruction *, 1> LoadedPtrs; 13980312f614SPeter Collingbourne SmallVector<Instruction *, 1> Preds; 13990312f614SPeter Collingbourne bool HasNonCallUses = false; 14000312f614SPeter Collingbourne findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, 14010312f614SPeter Collingbourne HasNonCallUses, CI); 14020312f614SPeter Collingbourne 14030312f614SPeter Collingbourne // Start by generating "pessimistic" code that explicitly loads the function 14040312f614SPeter Collingbourne // pointer from the vtable and performs the type check. 
If possible, we will 14050312f614SPeter Collingbourne // eliminate the load and the type check later. 14060312f614SPeter Collingbourne 14070312f614SPeter Collingbourne // If possible, only generate the load at the point where it is used. 14080312f614SPeter Collingbourne // This helps avoid unnecessary spills. 14090312f614SPeter Collingbourne IRBuilder<> LoadB( 14100312f614SPeter Collingbourne (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); 14110312f614SPeter Collingbourne Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); 14120312f614SPeter Collingbourne Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); 14130312f614SPeter Collingbourne Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); 14140312f614SPeter Collingbourne 14150312f614SPeter Collingbourne for (Instruction *LoadedPtr : LoadedPtrs) { 14160312f614SPeter Collingbourne LoadedPtr->replaceAllUsesWith(LoadedValue); 14170312f614SPeter Collingbourne LoadedPtr->eraseFromParent(); 14180312f614SPeter Collingbourne } 14190312f614SPeter Collingbourne 14200312f614SPeter Collingbourne // Likewise for the type test. 14210312f614SPeter Collingbourne IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); 14220312f614SPeter Collingbourne CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); 14230312f614SPeter Collingbourne 14240312f614SPeter Collingbourne for (Instruction *Pred : Preds) { 14250312f614SPeter Collingbourne Pred->replaceAllUsesWith(TypeTestCall); 14260312f614SPeter Collingbourne Pred->eraseFromParent(); 14270312f614SPeter Collingbourne } 14280312f614SPeter Collingbourne 14290312f614SPeter Collingbourne // We have already erased any extractvalue instructions that refer to the 14300312f614SPeter Collingbourne // intrinsic call, but the intrinsic may have other non-extractvalue uses 14310312f614SPeter Collingbourne // (although this is unlikely). 
In that case, explicitly build a pair and 14320312f614SPeter Collingbourne // RAUW it. 14330312f614SPeter Collingbourne if (!CI->use_empty()) { 14340312f614SPeter Collingbourne Value *Pair = UndefValue::get(CI->getType()); 14350312f614SPeter Collingbourne IRBuilder<> B(CI); 14360312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); 14370312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); 14380312f614SPeter Collingbourne CI->replaceAllUsesWith(Pair); 14390312f614SPeter Collingbourne } 14400312f614SPeter Collingbourne 14410312f614SPeter Collingbourne // The number of unsafe uses is initially the number of uses. 14420312f614SPeter Collingbourne auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; 14430312f614SPeter Collingbourne NumUnsafeUses = DevirtCalls.size(); 14440312f614SPeter Collingbourne 14450312f614SPeter Collingbourne // If the function pointer has a non-call user, we cannot eliminate the type 14460312f614SPeter Collingbourne // check, as one of those users may eventually call the pointer. Increment 14470312f614SPeter Collingbourne // the unsafe use count to make sure it cannot reach zero. 
14480312f614SPeter Collingbourne if (HasNonCallUses) 14490312f614SPeter Collingbourne ++NumUnsafeUses; 14500312f614SPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 145150cbd7ccSPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, 145250cbd7ccSPeter Collingbourne &NumUnsafeUses); 14530312f614SPeter Collingbourne } 14540312f614SPeter Collingbourne 14550312f614SPeter Collingbourne CI->eraseFromParent(); 14560312f614SPeter Collingbourne } 14570312f614SPeter Collingbourne } 14580312f614SPeter Collingbourne 14596d284fabSPeter Collingbourne void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { 14609a3f9797SPeter Collingbourne const TypeIdSummary *TidSummary = 1461f7691d8bSPeter Collingbourne ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); 14629a3f9797SPeter Collingbourne if (!TidSummary) 14639a3f9797SPeter Collingbourne return; 14649a3f9797SPeter Collingbourne auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); 14659a3f9797SPeter Collingbourne if (ResI == TidSummary->WPDRes.end()) 14669a3f9797SPeter Collingbourne return; 14679a3f9797SPeter Collingbourne const WholeProgramDevirtResolution &Res = ResI->second; 14686d284fabSPeter Collingbourne 14696d284fabSPeter Collingbourne if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { 14706d284fabSPeter Collingbourne // The type of the function in the declaration is irrelevant because every 14716d284fabSPeter Collingbourne // call site will cast it to the correct type. 1472db11fdfdSMehdi Amini auto *SingleImpl = M.getOrInsertFunction( 147359a2d7b9SSerge Guelton Res.SingleImplName, Type::getVoidTy(M.getContext())); 14746d284fabSPeter Collingbourne 14756d284fabSPeter Collingbourne // This is the import phase so we should not be exporting anything. 
14766d284fabSPeter Collingbourne bool IsExported = false; 14776d284fabSPeter Collingbourne applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); 14786d284fabSPeter Collingbourne assert(!IsExported); 14796d284fabSPeter Collingbourne } 14800152c815SPeter Collingbourne 14810152c815SPeter Collingbourne for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { 14820152c815SPeter Collingbourne auto I = Res.ResByArg.find(CSByConstantArg.first); 14830152c815SPeter Collingbourne if (I == Res.ResByArg.end()) 14840152c815SPeter Collingbourne continue; 14850152c815SPeter Collingbourne auto &ResByArg = I->second; 14860152c815SPeter Collingbourne // FIXME: We should figure out what to do about the "function name" argument 14870152c815SPeter Collingbourne // to the apply* functions, as the function names are unavailable during the 14880152c815SPeter Collingbourne // importing phase. For now we just pass the empty string. This does not 14890152c815SPeter Collingbourne // impact correctness because the function names are just used for remarks. 
14900152c815SPeter Collingbourne switch (ResByArg.TheKind) { 14910152c815SPeter Collingbourne case WholeProgramDevirtResolution::ByArg::UniformRetVal: 14920152c815SPeter Collingbourne applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); 14930152c815SPeter Collingbourne break; 149459675ba0SPeter Collingbourne case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { 149559675ba0SPeter Collingbourne Constant *UniqueMemberAddr = 149659675ba0SPeter Collingbourne importGlobal(Slot, CSByConstantArg.first, "unique_member"); 149759675ba0SPeter Collingbourne applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, 149859675ba0SPeter Collingbourne UniqueMemberAddr); 149959675ba0SPeter Collingbourne break; 150059675ba0SPeter Collingbourne } 150114dcf02fSPeter Collingbourne case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { 1502b15a35e6SPeter Collingbourne Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte", 1503b15a35e6SPeter Collingbourne Int32Ty, ResByArg.Byte); 1504b15a35e6SPeter Collingbourne Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty, 1505b15a35e6SPeter Collingbourne ResByArg.Bit); 150614dcf02fSPeter Collingbourne applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); 15070e6694d1SAdrian Prantl break; 150814dcf02fSPeter Collingbourne } 15090152c815SPeter Collingbourne default: 15100152c815SPeter Collingbourne break; 15110152c815SPeter Collingbourne } 15120152c815SPeter Collingbourne } 15132974856aSPeter Collingbourne 15142974856aSPeter Collingbourne if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) { 15152974856aSPeter Collingbourne auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"), 15162974856aSPeter Collingbourne Type::getVoidTy(M.getContext())); 15172974856aSPeter Collingbourne bool IsExported = false; 15182974856aSPeter Collingbourne applyICallBranchFunnel(SlotInfo, JT, IsExported); 15192974856aSPeter Collingbourne assert(!IsExported); 
15202974856aSPeter Collingbourne } 15216d284fabSPeter Collingbourne } 15226d284fabSPeter Collingbourne 15236d284fabSPeter Collingbourne void DevirtModule::removeRedundantTypeTests() { 15246d284fabSPeter Collingbourne auto True = ConstantInt::getTrue(M.getContext()); 15256d284fabSPeter Collingbourne for (auto &&U : NumUnsafeUsesForTypeTest) { 15266d284fabSPeter Collingbourne if (U.second == 0) { 15276d284fabSPeter Collingbourne U.first->replaceAllUsesWith(True); 15286d284fabSPeter Collingbourne U.first->eraseFromParent(); 15296d284fabSPeter Collingbourne } 15306d284fabSPeter Collingbourne } 15316d284fabSPeter Collingbourne } 15326d284fabSPeter Collingbourne 15330312f614SPeter Collingbourne bool DevirtModule::run() { 15340312f614SPeter Collingbourne Function *TypeTestFunc = 15350312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_test)); 15360312f614SPeter Collingbourne Function *TypeCheckedLoadFunc = 15370312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); 15380312f614SPeter Collingbourne Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); 15390312f614SPeter Collingbourne 1540b406baaeSPeter Collingbourne // Normally if there are no users of the devirtualization intrinsics in the 1541b406baaeSPeter Collingbourne // module, this pass has nothing to do. But if we are exporting, we also need 1542b406baaeSPeter Collingbourne // to handle any users that appear only in the function summaries. 
1543f7691d8bSPeter Collingbourne if (!ExportSummary && 1544b406baaeSPeter Collingbourne (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || 15450312f614SPeter Collingbourne AssumeFunc->use_empty()) && 15460312f614SPeter Collingbourne (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) 15470312f614SPeter Collingbourne return false; 15480312f614SPeter Collingbourne 15490312f614SPeter Collingbourne if (TypeTestFunc && AssumeFunc) 15500312f614SPeter Collingbourne scanTypeTestUsers(TypeTestFunc, AssumeFunc); 15510312f614SPeter Collingbourne 15520312f614SPeter Collingbourne if (TypeCheckedLoadFunc) 15530312f614SPeter Collingbourne scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); 1554df49d1bbSPeter Collingbourne 1555f7691d8bSPeter Collingbourne if (ImportSummary) { 15566d284fabSPeter Collingbourne for (auto &S : CallSlots) 15576d284fabSPeter Collingbourne importResolution(S.first, S.second); 15586d284fabSPeter Collingbourne 15596d284fabSPeter Collingbourne removeRedundantTypeTests(); 15606d284fabSPeter Collingbourne 15616d284fabSPeter Collingbourne // The rest of the code is only necessary when exporting or during regular 15626d284fabSPeter Collingbourne // LTO, so we are done. 15636d284fabSPeter Collingbourne return true; 15646d284fabSPeter Collingbourne } 15656d284fabSPeter Collingbourne 15667efd7506SPeter Collingbourne // Rebuild type metadata into a map for easy lookup. 1567df49d1bbSPeter Collingbourne std::vector<VTableBits> Bits; 15687efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; 15697efd7506SPeter Collingbourne buildTypeIdentifierMap(Bits, TypeIdMap); 15707efd7506SPeter Collingbourne if (TypeIdMap.empty()) 1571df49d1bbSPeter Collingbourne return true; 1572df49d1bbSPeter Collingbourne 1573b406baaeSPeter Collingbourne // Collect information from summary about which calls to try to devirtualize. 
1574f7691d8bSPeter Collingbourne if (ExportSummary) { 1575b406baaeSPeter Collingbourne DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; 1576b406baaeSPeter Collingbourne for (auto &P : TypeIdMap) { 1577b406baaeSPeter Collingbourne if (auto *TypeId = dyn_cast<MDString>(P.first)) 1578b406baaeSPeter Collingbourne MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( 1579b406baaeSPeter Collingbourne TypeId); 1580b406baaeSPeter Collingbourne } 1581b406baaeSPeter Collingbourne 1582f7691d8bSPeter Collingbourne for (auto &P : *ExportSummary) { 15839667b91bSPeter Collingbourne for (auto &S : P.second.SummaryList) { 1584b406baaeSPeter Collingbourne auto *FS = dyn_cast<FunctionSummary>(S.get()); 1585b406baaeSPeter Collingbourne if (!FS) 1586b406baaeSPeter Collingbourne continue; 1587b406baaeSPeter Collingbourne // FIXME: Only add live functions. 15885d8aea10SGeorge Rimar for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { 15895d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VF.GUID]) { 15902974856aSPeter Collingbourne CallSlots[{MD, VF.Offset}] 15912974856aSPeter Collingbourne .CSInfo.markSummaryHasTypeTestAssumeUsers(); 15925d8aea10SGeorge Rimar } 15935d8aea10SGeorge Rimar } 15945d8aea10SGeorge Rimar for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { 15955d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VF.GUID]) { 15962974856aSPeter Collingbourne CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); 15975d8aea10SGeorge Rimar } 15985d8aea10SGeorge Rimar } 1599b406baaeSPeter Collingbourne for (const FunctionSummary::ConstVCall &VC : 16005d8aea10SGeorge Rimar FS->type_test_assume_const_vcalls()) { 16015d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 16022325bb34SPeter Collingbourne CallSlots[{MD, VC.VFunc.Offset}] 16035d8aea10SGeorge Rimar .ConstCSInfo[VC.Args] 16042974856aSPeter Collingbourne .markSummaryHasTypeTestAssumeUsers(); 16055d8aea10SGeorge 
Rimar } 16065d8aea10SGeorge Rimar } 16072325bb34SPeter Collingbourne for (const FunctionSummary::ConstVCall &VC : 16085d8aea10SGeorge Rimar FS->type_checked_load_const_vcalls()) { 16095d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 1610b406baaeSPeter Collingbourne CallSlots[{MD, VC.VFunc.Offset}] 1611b406baaeSPeter Collingbourne .ConstCSInfo[VC.Args] 16122974856aSPeter Collingbourne .addSummaryTypeCheckedLoadUser(FS); 1613b406baaeSPeter Collingbourne } 1614b406baaeSPeter Collingbourne } 1615b406baaeSPeter Collingbourne } 16165d8aea10SGeorge Rimar } 16175d8aea10SGeorge Rimar } 1618b406baaeSPeter Collingbourne 16197efd7506SPeter Collingbourne // For each (type, offset) pair: 1620df49d1bbSPeter Collingbourne bool DidVirtualConstProp = false; 1621f3403fd2SIvan Krasin std::map<std::string, Function*> DevirtTargets; 1622df49d1bbSPeter Collingbourne for (auto &S : CallSlots) { 16237efd7506SPeter Collingbourne // Search each of the members of the type identifier for the virtual 16247efd7506SPeter Collingbourne // function implementation at offset S.first.ByteOffset, and add to 16257efd7506SPeter Collingbourne // TargetsForSlot. 
1626df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> TargetsForSlot; 1627b406baaeSPeter Collingbourne if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], 1628b406baaeSPeter Collingbourne S.first.ByteOffset)) { 16292325bb34SPeter Collingbourne WholeProgramDevirtResolution *Res = nullptr; 1630f7691d8bSPeter Collingbourne if (ExportSummary && isa<MDString>(S.first.TypeID)) 1631f7691d8bSPeter Collingbourne Res = &ExportSummary 16329a3f9797SPeter Collingbourne ->getOrInsertTypeIdSummary( 16339a3f9797SPeter Collingbourne cast<MDString>(S.first.TypeID)->getString()) 16342325bb34SPeter Collingbourne .WPDRes[S.first.ByteOffset]; 16352325bb34SPeter Collingbourne 16362974856aSPeter Collingbourne if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) { 16372974856aSPeter Collingbourne DidVirtualConstProp |= 16382974856aSPeter Collingbourne tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); 16392974856aSPeter Collingbourne 16402974856aSPeter Collingbourne tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first); 16412974856aSPeter Collingbourne } 1642f3403fd2SIvan Krasin 1643f3403fd2SIvan Krasin // Collect functions devirtualized at least for one call site for stats. 1644f3403fd2SIvan Krasin if (RemarksEnabled) 1645f3403fd2SIvan Krasin for (const auto &T : TargetsForSlot) 1646f3403fd2SIvan Krasin if (T.WasDevirt) 1647f3403fd2SIvan Krasin DevirtTargets[T.Fn->getName()] = T.Fn; 1648b05e06e4SIvan Krasin } 1649df49d1bbSPeter Collingbourne 1650b406baaeSPeter Collingbourne // CFI-specific: if we are exporting and any llvm.type.checked.load 1651b406baaeSPeter Collingbourne // intrinsics were *not* devirtualized, we need to add the resulting 1652b406baaeSPeter Collingbourne // llvm.type.test intrinsics to the function summaries so that the 1653b406baaeSPeter Collingbourne // LowerTypeTests pass will export them. 
1654f7691d8bSPeter Collingbourne if (ExportSummary && isa<MDString>(S.first.TypeID)) { 1655b406baaeSPeter Collingbourne auto GUID = 1656b406baaeSPeter Collingbourne GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); 1657b406baaeSPeter Collingbourne for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) 1658b406baaeSPeter Collingbourne FS->addTypeTest(GUID); 1659b406baaeSPeter Collingbourne for (auto &CCS : S.second.ConstCSInfo) 1660b406baaeSPeter Collingbourne for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) 1661b406baaeSPeter Collingbourne FS->addTypeTest(GUID); 1662b406baaeSPeter Collingbourne } 1663b406baaeSPeter Collingbourne } 1664b406baaeSPeter Collingbourne 1665f3403fd2SIvan Krasin if (RemarksEnabled) { 1666f3403fd2SIvan Krasin // Generate remarks for each devirtualized function. 1667f3403fd2SIvan Krasin for (const auto &DT : DevirtTargets) { 1668f3403fd2SIvan Krasin Function *F = DT.second; 1669e963c89dSSam Elliott 1670e963c89dSSam Elliott using namespace ore; 16719110cb45SPeter Collingbourne OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) 16729110cb45SPeter Collingbourne << "devirtualized " 16739110cb45SPeter Collingbourne << NV("FunctionName", F->getName())); 1674b05e06e4SIvan Krasin } 1675df49d1bbSPeter Collingbourne } 1676df49d1bbSPeter Collingbourne 16776d284fabSPeter Collingbourne removeRedundantTypeTests(); 16780312f614SPeter Collingbourne 1679df49d1bbSPeter Collingbourne // Rebuild each global we touched as part of virtual constant propagation to 1680df49d1bbSPeter Collingbourne // include the before and after bytes. 1681df49d1bbSPeter Collingbourne if (DidVirtualConstProp) 1682df49d1bbSPeter Collingbourne for (VTableBits &B : Bits) 1683df49d1bbSPeter Collingbourne rebuildGlobal(B); 1684df49d1bbSPeter Collingbourne 1685df49d1bbSPeter Collingbourne return true; 1686df49d1bbSPeter Collingbourne } 1687