1df49d1bbSPeter Collingbourne //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// 2df49d1bbSPeter Collingbourne // 3df49d1bbSPeter Collingbourne // The LLVM Compiler Infrastructure 4df49d1bbSPeter Collingbourne // 5df49d1bbSPeter Collingbourne // This file is distributed under the University of Illinois Open Source 6df49d1bbSPeter Collingbourne // License. See LICENSE.TXT for details. 7df49d1bbSPeter Collingbourne // 8df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 9df49d1bbSPeter Collingbourne // 10df49d1bbSPeter Collingbourne // This pass implements whole program optimization of virtual calls in cases 117efd7506SPeter Collingbourne // where we know (via !type metadata) that the list of callees is fixed. This 12df49d1bbSPeter Collingbourne // includes the following: 13df49d1bbSPeter Collingbourne // - Single implementation devirtualization: if a virtual call has a single 14df49d1bbSPeter Collingbourne // possible callee, replace all calls with a direct call to that callee. 15df49d1bbSPeter Collingbourne // - Virtual constant propagation: if the virtual function's return type is an 16df49d1bbSPeter Collingbourne // integer <=64 bits and all possible callees are readnone, for each class and 17df49d1bbSPeter Collingbourne // each list of constant arguments: evaluate the function, store the return 18df49d1bbSPeter Collingbourne // value alongside the virtual table, and rewrite each virtual call as a load 19df49d1bbSPeter Collingbourne // from the virtual table. 20df49d1bbSPeter Collingbourne // - Uniform return value optimization: if the conditions for virtual constant 21df49d1bbSPeter Collingbourne // propagation hold and each function returns the same constant value, replace 22df49d1bbSPeter Collingbourne // each virtual call with that constant. 
23df49d1bbSPeter Collingbourne // - Unique return value optimization for i1 return values: if the conditions 24df49d1bbSPeter Collingbourne // for virtual constant propagation hold and a single vtable's function 25df49d1bbSPeter Collingbourne // returns 0, or a single vtable's function returns 1, replace each virtual 26df49d1bbSPeter Collingbourne // call with a comparison of the vptr against that vtable's address. 27df49d1bbSPeter Collingbourne // 28b406baaeSPeter Collingbourne // This pass is intended to be used during the regular and thin LTO pipelines. 29b406baaeSPeter Collingbourne // During regular LTO, the pass determines the best optimization for each 30b406baaeSPeter Collingbourne // virtual call and applies the resolutions directly to virtual calls that are 31b406baaeSPeter Collingbourne // eligible for virtual call optimization (i.e. calls that use either of the 32b406baaeSPeter Collingbourne // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During 33b406baaeSPeter Collingbourne // ThinLTO, the pass operates in two phases: 34b406baaeSPeter Collingbourne // - Export phase: this is run during the thin link over a single merged module 35b406baaeSPeter Collingbourne // that contains all vtables with !type metadata that participate in the link. 36b406baaeSPeter Collingbourne // The pass computes a resolution for each virtual call and stores it in the 37b406baaeSPeter Collingbourne // type identifier summary. 38b406baaeSPeter Collingbourne // - Import phase: this is run during the thin backends over the individual 39b406baaeSPeter Collingbourne // modules. The pass applies the resolutions previously computed during the 40b406baaeSPeter Collingbourne // export phase to each eligible virtual call.
41b406baaeSPeter Collingbourne // 42df49d1bbSPeter Collingbourne //===----------------------------------------------------------------------===// 43df49d1bbSPeter Collingbourne 44df49d1bbSPeter Collingbourne #include "llvm/Transforms/IPO/WholeProgramDevirt.h" 45b550cb17SMehdi Amini #include "llvm/ADT/ArrayRef.h" 46cdc71612SEugene Zelenko #include "llvm/ADT/DenseMap.h" 47cdc71612SEugene Zelenko #include "llvm/ADT/DenseMapInfo.h" 48df49d1bbSPeter Collingbourne #include "llvm/ADT/DenseSet.h" 49cdc71612SEugene Zelenko #include "llvm/ADT/iterator_range.h" 50df49d1bbSPeter Collingbourne #include "llvm/ADT/MapVector.h" 51cdc71612SEugene Zelenko #include "llvm/ADT/SmallVector.h" 5237317f12SPeter Collingbourne #include "llvm/Analysis/AliasAnalysis.h" 5337317f12SPeter Collingbourne #include "llvm/Analysis/BasicAliasAnalysis.h" 547efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h" 55df49d1bbSPeter Collingbourne #include "llvm/IR/CallSite.h" 56df49d1bbSPeter Collingbourne #include "llvm/IR/Constants.h" 57df49d1bbSPeter Collingbourne #include "llvm/IR/DataLayout.h" 58b05e06e4SIvan Krasin #include "llvm/IR/DebugInfoMetadata.h" 59cdc71612SEugene Zelenko #include "llvm/IR/DebugLoc.h" 60cdc71612SEugene Zelenko #include "llvm/IR/DerivedTypes.h" 615474645dSIvan Krasin #include "llvm/IR/DiagnosticInfo.h" 62cdc71612SEugene Zelenko #include "llvm/IR/Function.h" 63cdc71612SEugene Zelenko #include "llvm/IR/GlobalAlias.h" 64cdc71612SEugene Zelenko #include "llvm/IR/GlobalVariable.h" 65df49d1bbSPeter Collingbourne #include "llvm/IR/IRBuilder.h" 66cdc71612SEugene Zelenko #include "llvm/IR/InstrTypes.h" 67cdc71612SEugene Zelenko #include "llvm/IR/Instruction.h" 68df49d1bbSPeter Collingbourne #include "llvm/IR/Instructions.h" 69df49d1bbSPeter Collingbourne #include "llvm/IR/Intrinsics.h" 70cdc71612SEugene Zelenko #include "llvm/IR/LLVMContext.h" 71cdc71612SEugene Zelenko #include "llvm/IR/Metadata.h" 72df49d1bbSPeter Collingbourne #include "llvm/IR/Module.h" 
732b33f653SPeter Collingbourne #include "llvm/IR/ModuleSummaryIndexYAML.h" 74df49d1bbSPeter Collingbourne #include "llvm/Pass.h" 75cdc71612SEugene Zelenko #include "llvm/PassRegistry.h" 76cdc71612SEugene Zelenko #include "llvm/PassSupport.h" 77cdc71612SEugene Zelenko #include "llvm/Support/Casting.h" 782b33f653SPeter Collingbourne #include "llvm/Support/Error.h" 792b33f653SPeter Collingbourne #include "llvm/Support/FileSystem.h" 80cdc71612SEugene Zelenko #include "llvm/Support/MathExtras.h" 81b550cb17SMehdi Amini #include "llvm/Transforms/IPO.h" 8237317f12SPeter Collingbourne #include "llvm/Transforms/IPO/FunctionAttrs.h" 83df49d1bbSPeter Collingbourne #include "llvm/Transforms/Utils/Evaluator.h" 84cdc71612SEugene Zelenko #include <algorithm> 85cdc71612SEugene Zelenko #include <cstddef> 86cdc71612SEugene Zelenko #include <map> 87df49d1bbSPeter Collingbourne #include <set> 88cdc71612SEugene Zelenko #include <string> 89df49d1bbSPeter Collingbourne 90df49d1bbSPeter Collingbourne using namespace llvm; 91df49d1bbSPeter Collingbourne using namespace wholeprogramdevirt; 92df49d1bbSPeter Collingbourne 93df49d1bbSPeter Collingbourne #define DEBUG_TYPE "wholeprogramdevirt" 94df49d1bbSPeter Collingbourne 952b33f653SPeter Collingbourne static cl::opt<PassSummaryAction> ClSummaryAction( 962b33f653SPeter Collingbourne "wholeprogramdevirt-summary-action", 972b33f653SPeter Collingbourne cl::desc("What to do with the summary when running this pass"), 982b33f653SPeter Collingbourne cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), 992b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Import, "import", 1002b33f653SPeter Collingbourne "Import typeid resolutions from summary and globals"), 1012b33f653SPeter Collingbourne clEnumValN(PassSummaryAction::Export, "export", 1022b33f653SPeter Collingbourne "Export typeid resolutions to summary and globals")), 1032b33f653SPeter Collingbourne cl::Hidden); 1042b33f653SPeter Collingbourne 1052b33f653SPeter Collingbourne 
static cl::opt<std::string> ClReadSummary( 1062b33f653SPeter Collingbourne "wholeprogramdevirt-read-summary", 1072b33f653SPeter Collingbourne cl::desc("Read summary from given YAML file before running pass"), 1082b33f653SPeter Collingbourne cl::Hidden); 1092b33f653SPeter Collingbourne 1102b33f653SPeter Collingbourne static cl::opt<std::string> ClWriteSummary( 1112b33f653SPeter Collingbourne "wholeprogramdevirt-write-summary", 1122b33f653SPeter Collingbourne cl::desc("Write summary to given YAML file after running pass"), 1132b33f653SPeter Collingbourne cl::Hidden); 1142b33f653SPeter Collingbourne 115df49d1bbSPeter Collingbourne // Find the minimum offset that we may store a value of size Size bits at. If 116df49d1bbSPeter Collingbourne // IsAfter is set, look for an offset before the object, otherwise look for an 117df49d1bbSPeter Collingbourne // offset after the object. 118df49d1bbSPeter Collingbourne uint64_t 119df49d1bbSPeter Collingbourne wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, 120df49d1bbSPeter Collingbourne bool IsAfter, uint64_t Size) { 121df49d1bbSPeter Collingbourne // Find a minimum offset taking into account only vtable sizes. 122df49d1bbSPeter Collingbourne uint64_t MinByte = 0; 123df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 124df49d1bbSPeter Collingbourne if (IsAfter) 125df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minAfterBytes()); 126df49d1bbSPeter Collingbourne else 127df49d1bbSPeter Collingbourne MinByte = std::max(MinByte, Target.minBeforeBytes()); 128df49d1bbSPeter Collingbourne } 129df49d1bbSPeter Collingbourne 130df49d1bbSPeter Collingbourne // Build a vector of arrays of bytes covering, for each target, a slice of the 131df49d1bbSPeter Collingbourne // used region (see AccumBitVector::BytesUsed in 132df49d1bbSPeter Collingbourne // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. 
Effectively, 133df49d1bbSPeter Collingbourne // this aligns the used regions to start at MinByte. 134df49d1bbSPeter Collingbourne // 135df49d1bbSPeter Collingbourne // In this example, A, B and C are vtables, # is a byte already allocated for 136df49d1bbSPeter Collingbourne // a virtual function pointer, AAAA... (etc.) are the used regions for the 137df49d1bbSPeter Collingbourne // vtables and Offset(X) is the value computed for the Offset variable below 138df49d1bbSPeter Collingbourne // for X. 139df49d1bbSPeter Collingbourne // 140df49d1bbSPeter Collingbourne // Offset(A) 141df49d1bbSPeter Collingbourne // | | 142df49d1bbSPeter Collingbourne // |MinByte 143df49d1bbSPeter Collingbourne // A: ################AAAAAAAA|AAAAAAAA 144df49d1bbSPeter Collingbourne // B: ########BBBBBBBBBBBBBBBB|BBBB 145df49d1bbSPeter Collingbourne // C: ########################|CCCCCCCCCCCCCCCC 146df49d1bbSPeter Collingbourne // | Offset(B) | 147df49d1bbSPeter Collingbourne // 148df49d1bbSPeter Collingbourne // This code produces the slices of A, B and C that appear after the divider 149df49d1bbSPeter Collingbourne // at MinByte. 150df49d1bbSPeter Collingbourne std::vector<ArrayRef<uint8_t>> Used; 151df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : Targets) { 1527efd7506SPeter Collingbourne ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed 1537efd7506SPeter Collingbourne : Target.TM->Bits->Before.BytesUsed; 154df49d1bbSPeter Collingbourne uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() 155df49d1bbSPeter Collingbourne : MinByte - Target.minBeforeBytes(); 156df49d1bbSPeter Collingbourne 157df49d1bbSPeter Collingbourne // Disregard used regions that are smaller than Offset. These are 158df49d1bbSPeter Collingbourne // effectively all-free regions that do not need to be checked. 
159df49d1bbSPeter Collingbourne if (VTUsed.size() > Offset) 160df49d1bbSPeter Collingbourne Used.push_back(VTUsed.slice(Offset)); 161df49d1bbSPeter Collingbourne } 162df49d1bbSPeter Collingbourne 163df49d1bbSPeter Collingbourne if (Size == 1) { 164df49d1bbSPeter Collingbourne // Find a free bit in each member of Used. 165df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 166df49d1bbSPeter Collingbourne uint8_t BitsUsed = 0; 167df49d1bbSPeter Collingbourne for (auto &&B : Used) 168df49d1bbSPeter Collingbourne if (I < B.size()) 169df49d1bbSPeter Collingbourne BitsUsed |= B[I]; 170df49d1bbSPeter Collingbourne if (BitsUsed != 0xff) 171df49d1bbSPeter Collingbourne return (MinByte + I) * 8 + 172df49d1bbSPeter Collingbourne countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined); 173df49d1bbSPeter Collingbourne } 174df49d1bbSPeter Collingbourne } else { 175df49d1bbSPeter Collingbourne // Find a free (Size/8) byte region in each member of Used. 176df49d1bbSPeter Collingbourne // FIXME: see if alignment helps. 
177df49d1bbSPeter Collingbourne for (unsigned I = 0;; ++I) { 178df49d1bbSPeter Collingbourne for (auto &&B : Used) { 179df49d1bbSPeter Collingbourne unsigned Byte = 0; 180df49d1bbSPeter Collingbourne while ((I + Byte) < B.size() && Byte < (Size / 8)) { 181df49d1bbSPeter Collingbourne if (B[I + Byte]) 182df49d1bbSPeter Collingbourne goto NextI; 183df49d1bbSPeter Collingbourne ++Byte; 184df49d1bbSPeter Collingbourne } 185df49d1bbSPeter Collingbourne } 186df49d1bbSPeter Collingbourne return (MinByte + I) * 8; 187df49d1bbSPeter Collingbourne NextI:; 188df49d1bbSPeter Collingbourne } 189df49d1bbSPeter Collingbourne } 190df49d1bbSPeter Collingbourne } 191df49d1bbSPeter Collingbourne 192df49d1bbSPeter Collingbourne void wholeprogramdevirt::setBeforeReturnValues( 193df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore, 194df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 195df49d1bbSPeter Collingbourne if (BitWidth == 1) 196df49d1bbSPeter Collingbourne OffsetByte = -(AllocBefore / 8 + 1); 197df49d1bbSPeter Collingbourne else 198df49d1bbSPeter Collingbourne OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8); 199df49d1bbSPeter Collingbourne OffsetBit = AllocBefore % 8; 200df49d1bbSPeter Collingbourne 201df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 202df49d1bbSPeter Collingbourne if (BitWidth == 1) 203df49d1bbSPeter Collingbourne Target.setBeforeBit(AllocBefore); 204df49d1bbSPeter Collingbourne else 205df49d1bbSPeter Collingbourne Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8); 206df49d1bbSPeter Collingbourne } 207df49d1bbSPeter Collingbourne } 208df49d1bbSPeter Collingbourne 209df49d1bbSPeter Collingbourne void wholeprogramdevirt::setAfterReturnValues( 210df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter, 211df49d1bbSPeter Collingbourne unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 
212df49d1bbSPeter Collingbourne if (BitWidth == 1) 213df49d1bbSPeter Collingbourne OffsetByte = AllocAfter / 8; 214df49d1bbSPeter Collingbourne else 215df49d1bbSPeter Collingbourne OffsetByte = (AllocAfter + 7) / 8; 216df49d1bbSPeter Collingbourne OffsetBit = AllocAfter % 8; 217df49d1bbSPeter Collingbourne 218df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : Targets) { 219df49d1bbSPeter Collingbourne if (BitWidth == 1) 220df49d1bbSPeter Collingbourne Target.setAfterBit(AllocAfter); 221df49d1bbSPeter Collingbourne else 222df49d1bbSPeter Collingbourne Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8); 223df49d1bbSPeter Collingbourne } 224df49d1bbSPeter Collingbourne } 225df49d1bbSPeter Collingbourne 2267efd7506SPeter Collingbourne VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) 2277efd7506SPeter Collingbourne : Fn(Fn), TM(TM), 22889439a79SIvan Krasin IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} 229df49d1bbSPeter Collingbourne 230df49d1bbSPeter Collingbourne namespace { 231df49d1bbSPeter Collingbourne 2327efd7506SPeter Collingbourne // A slot in a set of virtual tables. The TypeID identifies the set of virtual 233df49d1bbSPeter Collingbourne // tables, and the ByteOffset is the offset in bytes from the address point to 234df49d1bbSPeter Collingbourne // the virtual function pointer. 
235df49d1bbSPeter Collingbourne struct VTableSlot { 2367efd7506SPeter Collingbourne Metadata *TypeID; 237df49d1bbSPeter Collingbourne uint64_t ByteOffset; 238df49d1bbSPeter Collingbourne }; 239df49d1bbSPeter Collingbourne 240cdc71612SEugene Zelenko } // end anonymous namespace 241df49d1bbSPeter Collingbourne 2429b656527SPeter Collingbourne namespace llvm { 2439b656527SPeter Collingbourne 244df49d1bbSPeter Collingbourne template <> struct DenseMapInfo<VTableSlot> { 245df49d1bbSPeter Collingbourne static VTableSlot getEmptyKey() { 246df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getEmptyKey(), 247df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getEmptyKey()}; 248df49d1bbSPeter Collingbourne } 249df49d1bbSPeter Collingbourne static VTableSlot getTombstoneKey() { 250df49d1bbSPeter Collingbourne return {DenseMapInfo<Metadata *>::getTombstoneKey(), 251df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getTombstoneKey()}; 252df49d1bbSPeter Collingbourne } 253df49d1bbSPeter Collingbourne static unsigned getHashValue(const VTableSlot &I) { 2547efd7506SPeter Collingbourne return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ 255df49d1bbSPeter Collingbourne DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); 256df49d1bbSPeter Collingbourne } 257df49d1bbSPeter Collingbourne static bool isEqual(const VTableSlot &LHS, 258df49d1bbSPeter Collingbourne const VTableSlot &RHS) { 2597efd7506SPeter Collingbourne return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; 260df49d1bbSPeter Collingbourne } 261df49d1bbSPeter Collingbourne }; 262df49d1bbSPeter Collingbourne 263cdc71612SEugene Zelenko } // end namespace llvm 2649b656527SPeter Collingbourne 265df49d1bbSPeter Collingbourne namespace { 266df49d1bbSPeter Collingbourne 267df49d1bbSPeter Collingbourne // A virtual call site. VTable is the loaded virtual table pointer, and CS is 268df49d1bbSPeter Collingbourne // the indirect virtual call. 
269df49d1bbSPeter Collingbourne struct VirtualCallSite { 270df49d1bbSPeter Collingbourne Value *VTable; 271df49d1bbSPeter Collingbourne CallSite CS; 272df49d1bbSPeter Collingbourne 2730312f614SPeter Collingbourne // If non-null, this field points to the associated unsafe use count stored in 2740312f614SPeter Collingbourne // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description 2750312f614SPeter Collingbourne // of that field for details. 2760312f614SPeter Collingbourne unsigned *NumUnsafeUses; 2770312f614SPeter Collingbourne 278f3403fd2SIvan Krasin void emitRemark(const Twine &OptName, const Twine &TargetName) { 2795474645dSIvan Krasin Function *F = CS.getCaller(); 280f3403fd2SIvan Krasin emitOptimizationRemark( 281f3403fd2SIvan Krasin F->getContext(), DEBUG_TYPE, *F, 2825474645dSIvan Krasin CS.getInstruction()->getDebugLoc(), 283f3403fd2SIvan Krasin OptName + ": devirtualized a call to " + TargetName); 2845474645dSIvan Krasin } 2855474645dSIvan Krasin 286f3403fd2SIvan Krasin void replaceAndErase(const Twine &OptName, const Twine &TargetName, 287f3403fd2SIvan Krasin bool RemarksEnabled, Value *New) { 288f3403fd2SIvan Krasin if (RemarksEnabled) 289f3403fd2SIvan Krasin emitRemark(OptName, TargetName); 290df49d1bbSPeter Collingbourne CS->replaceAllUsesWith(New); 291df49d1bbSPeter Collingbourne if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { 292df49d1bbSPeter Collingbourne BranchInst::Create(II->getNormalDest(), CS.getInstruction()); 293df49d1bbSPeter Collingbourne II->getUnwindDest()->removePredecessor(II->getParent()); 294df49d1bbSPeter Collingbourne } 295df49d1bbSPeter Collingbourne CS->eraseFromParent(); 2960312f614SPeter Collingbourne // This use is no longer unsafe. 
2970312f614SPeter Collingbourne if (NumUnsafeUses) 2980312f614SPeter Collingbourne --*NumUnsafeUses; 299df49d1bbSPeter Collingbourne } 300df49d1bbSPeter Collingbourne }; 301df49d1bbSPeter Collingbourne 30250cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot and possibly a list 30350cbd7ccSPeter Collingbourne // of constant integer arguments. The grouping by arguments is handled by the 30450cbd7ccSPeter Collingbourne // VTableSlotInfo class. 30550cbd7ccSPeter Collingbourne struct CallSiteInfo { 306b406baaeSPeter Collingbourne /// The set of call sites for this slot. Used during regular LTO and the 307b406baaeSPeter Collingbourne /// import phase of ThinLTO (as well as the export phase of ThinLTO for any 308b406baaeSPeter Collingbourne /// call sites that appear in the merged module itself); in each of these 309b406baaeSPeter Collingbourne /// cases we are directly operating on the call sites at the IR level. 31050cbd7ccSPeter Collingbourne std::vector<VirtualCallSite> CallSites; 311b406baaeSPeter Collingbourne 312b406baaeSPeter Collingbourne // These fields are used during the export phase of ThinLTO and reflect 313b406baaeSPeter Collingbourne // information collected from function summaries. 314b406baaeSPeter Collingbourne 3152325bb34SPeter Collingbourne /// Whether any function summary contains an llvm.assume(llvm.type.test) for 3162325bb34SPeter Collingbourne /// this slot. 3172325bb34SPeter Collingbourne bool SummaryHasTypeTestAssumeUsers; 3182325bb34SPeter Collingbourne 319b406baaeSPeter Collingbourne /// CFI-specific: a vector containing the list of function summaries that use 320b406baaeSPeter Collingbourne /// the llvm.type.checked.load intrinsic and therefore will require 321b406baaeSPeter Collingbourne /// resolutions for llvm.type.test in order to implement CFI checks if 322b406baaeSPeter Collingbourne /// devirtualization was unsuccessful. 
If devirtualization was successful, the 32359675ba0SPeter Collingbourne /// pass will clear this vector by calling markDevirt(). If at the end of the 32459675ba0SPeter Collingbourne /// pass the vector is non-empty, we will need to add a use of llvm.type.test 32559675ba0SPeter Collingbourne /// to each of the function summaries in the vector. 326b406baaeSPeter Collingbourne std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; 3272325bb34SPeter Collingbourne 3282325bb34SPeter Collingbourne bool isExported() const { 3292325bb34SPeter Collingbourne return SummaryHasTypeTestAssumeUsers || 3302325bb34SPeter Collingbourne !SummaryTypeCheckedLoadUsers.empty(); 3312325bb34SPeter Collingbourne } 33259675ba0SPeter Collingbourne 33359675ba0SPeter Collingbourne /// As explained in the comment for SummaryTypeCheckedLoadUsers. 33459675ba0SPeter Collingbourne void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } 33550cbd7ccSPeter Collingbourne }; 33650cbd7ccSPeter Collingbourne 33750cbd7ccSPeter Collingbourne // Call site information collected for a specific VTableSlot. 33850cbd7ccSPeter Collingbourne struct VTableSlotInfo { 33950cbd7ccSPeter Collingbourne // The set of call sites which do not have all constant integer arguments 34050cbd7ccSPeter Collingbourne // (excluding "this"). 34150cbd7ccSPeter Collingbourne CallSiteInfo CSInfo; 34250cbd7ccSPeter Collingbourne 34350cbd7ccSPeter Collingbourne // The set of call sites with all constant integer arguments (excluding 34450cbd7ccSPeter Collingbourne // "this"), grouped by argument list. 
34550cbd7ccSPeter Collingbourne std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; 34650cbd7ccSPeter Collingbourne 34750cbd7ccSPeter Collingbourne void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); 34850cbd7ccSPeter Collingbourne 34950cbd7ccSPeter Collingbourne private: 35050cbd7ccSPeter Collingbourne CallSiteInfo &findCallSiteInfo(CallSite CS); 35150cbd7ccSPeter Collingbourne }; 35250cbd7ccSPeter Collingbourne 35350cbd7ccSPeter Collingbourne CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { 35450cbd7ccSPeter Collingbourne std::vector<uint64_t> Args; 35550cbd7ccSPeter Collingbourne auto *CI = dyn_cast<IntegerType>(CS.getType()); 35650cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) 35750cbd7ccSPeter Collingbourne return CSInfo; 35850cbd7ccSPeter Collingbourne for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { 35950cbd7ccSPeter Collingbourne auto *CI = dyn_cast<ConstantInt>(Arg); 36050cbd7ccSPeter Collingbourne if (!CI || CI->getBitWidth() > 64) 36150cbd7ccSPeter Collingbourne return CSInfo; 36250cbd7ccSPeter Collingbourne Args.push_back(CI->getZExtValue()); 36350cbd7ccSPeter Collingbourne } 36450cbd7ccSPeter Collingbourne return ConstCSInfo[Args]; 36550cbd7ccSPeter Collingbourne } 36650cbd7ccSPeter Collingbourne 36750cbd7ccSPeter Collingbourne void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, 36850cbd7ccSPeter Collingbourne unsigned *NumUnsafeUses) { 36950cbd7ccSPeter Collingbourne findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); 37050cbd7ccSPeter Collingbourne } 37150cbd7ccSPeter Collingbourne 372df49d1bbSPeter Collingbourne struct DevirtModule { 373df49d1bbSPeter Collingbourne Module &M; 37437317f12SPeter Collingbourne function_ref<AAResults &(Function &)> AARGetter; 3752b33f653SPeter Collingbourne 376f7691d8bSPeter Collingbourne ModuleSummaryIndex *ExportSummary; 377f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary; 
3782b33f653SPeter Collingbourne 379df49d1bbSPeter Collingbourne IntegerType *Int8Ty; 380df49d1bbSPeter Collingbourne PointerType *Int8PtrTy; 381df49d1bbSPeter Collingbourne IntegerType *Int32Ty; 38250cbd7ccSPeter Collingbourne IntegerType *Int64Ty; 38314dcf02fSPeter Collingbourne IntegerType *IntPtrTy; 384df49d1bbSPeter Collingbourne 385f3403fd2SIvan Krasin bool RemarksEnabled; 386f3403fd2SIvan Krasin 38750cbd7ccSPeter Collingbourne MapVector<VTableSlot, VTableSlotInfo> CallSlots; 388df49d1bbSPeter Collingbourne 3890312f614SPeter Collingbourne // This map keeps track of the number of "unsafe" uses of a loaded function 3900312f614SPeter Collingbourne // pointer. The key is the associated llvm.type.test intrinsic call generated 3910312f614SPeter Collingbourne // by this pass. An unsafe use is one that calls the loaded function pointer 3920312f614SPeter Collingbourne // directly. Every time we eliminate an unsafe use (for example, by 3930312f614SPeter Collingbourne // devirtualizing it or by applying virtual constant propagation), we 3940312f614SPeter Collingbourne // decrement the value stored in this map. If a value reaches zero, we can 3950312f614SPeter Collingbourne // eliminate the type check by RAUWing the associated llvm.type.test call with 3960312f614SPeter Collingbourne // true. 
3970312f614SPeter Collingbourne std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; 3980312f614SPeter Collingbourne 39937317f12SPeter Collingbourne DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, 400f7691d8bSPeter Collingbourne ModuleSummaryIndex *ExportSummary, 401f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) 402f7691d8bSPeter Collingbourne : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), 403f7691d8bSPeter Collingbourne ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), 404df49d1bbSPeter Collingbourne Int8PtrTy(Type::getInt8PtrTy(M.getContext())), 405f3403fd2SIvan Krasin Int32Ty(Type::getInt32Ty(M.getContext())), 40650cbd7ccSPeter Collingbourne Int64Ty(Type::getInt64Ty(M.getContext())), 40714dcf02fSPeter Collingbourne IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), 408f7691d8bSPeter Collingbourne RemarksEnabled(areRemarksEnabled()) { 409f7691d8bSPeter Collingbourne assert(!(ExportSummary && ImportSummary)); 410f7691d8bSPeter Collingbourne } 411f3403fd2SIvan Krasin 412f3403fd2SIvan Krasin bool areRemarksEnabled(); 413df49d1bbSPeter Collingbourne 4140312f614SPeter Collingbourne void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); 4150312f614SPeter Collingbourne void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); 4160312f614SPeter Collingbourne 4177efd7506SPeter Collingbourne void buildTypeIdentifierMap( 4187efd7506SPeter Collingbourne std::vector<VTableBits> &Bits, 4197efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); 4208786754cSPeter Collingbourne Constant *getPointerAtOffset(Constant *I, uint64_t Offset); 4217efd7506SPeter Collingbourne bool 4227efd7506SPeter Collingbourne tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, 4237efd7506SPeter Collingbourne const std::set<TypeMemberInfo> &TypeMemberInfos, 424df49d1bbSPeter Collingbourne uint64_t ByteOffset); 42550cbd7ccSPeter 
Collingbourne 4262325bb34SPeter Collingbourne void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, 4272325bb34SPeter Collingbourne bool &IsExported); 428f3403fd2SIvan Krasin bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 4292325bb34SPeter Collingbourne VTableSlotInfo &SlotInfo, 4302325bb34SPeter Collingbourne WholeProgramDevirtResolution *Res); 43150cbd7ccSPeter Collingbourne 432df49d1bbSPeter Collingbourne bool tryEvaluateFunctionsWithArgs( 433df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 43450cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args); 43550cbd7ccSPeter Collingbourne 43650cbd7ccSPeter Collingbourne void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 43750cbd7ccSPeter Collingbourne uint64_t TheRetVal); 43850cbd7ccSPeter Collingbourne bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 43977a8d563SPeter Collingbourne CallSiteInfo &CSInfo, 44077a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res); 44150cbd7ccSPeter Collingbourne 44259675ba0SPeter Collingbourne // Returns the global symbol name that is used to export information about the 44359675ba0SPeter Collingbourne // given vtable slot and list of arguments. 44459675ba0SPeter Collingbourne std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, 44559675ba0SPeter Collingbourne StringRef Name); 44659675ba0SPeter Collingbourne 44759675ba0SPeter Collingbourne // This function is called during the export phase to create a symbol 44859675ba0SPeter Collingbourne // definition containing information about the given vtable slot and list of 44959675ba0SPeter Collingbourne // arguments. 
45059675ba0SPeter Collingbourne void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, 45159675ba0SPeter Collingbourne Constant *C); 45259675ba0SPeter Collingbourne 45359675ba0SPeter Collingbourne // This function is called during the import phase to create a reference to 45459675ba0SPeter Collingbourne // the symbol definition created during the export phase. 45559675ba0SPeter Collingbourne Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 45614dcf02fSPeter Collingbourne StringRef Name, unsigned AbsWidth = 0); 45759675ba0SPeter Collingbourne 45850cbd7ccSPeter Collingbourne void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, 45950cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr); 460df49d1bbSPeter Collingbourne bool tryUniqueRetValOpt(unsigned BitWidth, 461f3403fd2SIvan Krasin MutableArrayRef<VirtualCallTarget> TargetsForSlot, 46259675ba0SPeter Collingbourne CallSiteInfo &CSInfo, 46359675ba0SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res, 46459675ba0SPeter Collingbourne VTableSlot Slot, ArrayRef<uint64_t> Args); 46550cbd7ccSPeter Collingbourne 46650cbd7ccSPeter Collingbourne void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 46750cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit); 468df49d1bbSPeter Collingbourne bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 46977a8d563SPeter Collingbourne VTableSlotInfo &SlotInfo, 47059675ba0SPeter Collingbourne WholeProgramDevirtResolution *Res, VTableSlot Slot); 471df49d1bbSPeter Collingbourne 472df49d1bbSPeter Collingbourne void rebuildGlobal(VTableBits &B); 473df49d1bbSPeter Collingbourne 4746d284fabSPeter Collingbourne // Apply the summary resolution for Slot to all virtual calls in SlotInfo. 
4756d284fabSPeter Collingbourne void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); 4766d284fabSPeter Collingbourne 4776d284fabSPeter Collingbourne // If we were able to eliminate all unsafe uses for a type checked load, 4786d284fabSPeter Collingbourne // eliminate the associated type tests by replacing them with true. 4796d284fabSPeter Collingbourne void removeRedundantTypeTests(); 4806d284fabSPeter Collingbourne 481df49d1bbSPeter Collingbourne bool run(); 4822b33f653SPeter Collingbourne 4832b33f653SPeter Collingbourne // Lower the module using the action and summary passed as command line 4842b33f653SPeter Collingbourne // arguments. For testing purposes only. 48537317f12SPeter Collingbourne static bool runForTesting(Module &M, 48637317f12SPeter Collingbourne function_ref<AAResults &(Function &)> AARGetter); 487df49d1bbSPeter Collingbourne }; 488df49d1bbSPeter Collingbourne 489df49d1bbSPeter Collingbourne struct WholeProgramDevirt : public ModulePass { 490df49d1bbSPeter Collingbourne static char ID; 491cdc71612SEugene Zelenko 4922b33f653SPeter Collingbourne bool UseCommandLine = false; 4932b33f653SPeter Collingbourne 494f7691d8bSPeter Collingbourne ModuleSummaryIndex *ExportSummary; 495f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary; 4962b33f653SPeter Collingbourne 4972b33f653SPeter Collingbourne WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { 4982b33f653SPeter Collingbourne initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 4992b33f653SPeter Collingbourne } 5002b33f653SPeter Collingbourne 501f7691d8bSPeter Collingbourne WholeProgramDevirt(ModuleSummaryIndex *ExportSummary, 502f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) 503f7691d8bSPeter Collingbourne : ModulePass(ID), ExportSummary(ExportSummary), 504f7691d8bSPeter Collingbourne ImportSummary(ImportSummary) { 505df49d1bbSPeter Collingbourne initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); 
506df49d1bbSPeter Collingbourne } 507cdc71612SEugene Zelenko 508cdc71612SEugene Zelenko bool runOnModule(Module &M) override { 509aa641a51SAndrew Kaylor if (skipModule(M)) 510aa641a51SAndrew Kaylor return false; 5112b33f653SPeter Collingbourne if (UseCommandLine) 51237317f12SPeter Collingbourne return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); 513f7691d8bSPeter Collingbourne return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary) 514f7691d8bSPeter Collingbourne .run(); 51537317f12SPeter Collingbourne } 51637317f12SPeter Collingbourne 51737317f12SPeter Collingbourne void getAnalysisUsage(AnalysisUsage &AU) const override { 51837317f12SPeter Collingbourne AU.addRequired<AssumptionCacheTracker>(); 51937317f12SPeter Collingbourne AU.addRequired<TargetLibraryInfoWrapperPass>(); 520aa641a51SAndrew Kaylor } 521df49d1bbSPeter Collingbourne }; 522df49d1bbSPeter Collingbourne 523cdc71612SEugene Zelenko } // end anonymous namespace 524df49d1bbSPeter Collingbourne 52537317f12SPeter Collingbourne INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", 52637317f12SPeter Collingbourne "Whole program devirtualization", false, false) 52737317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 52837317f12SPeter Collingbourne INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 52937317f12SPeter Collingbourne INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", 530df49d1bbSPeter Collingbourne "Whole program devirtualization", false, false) 531df49d1bbSPeter Collingbourne char WholeProgramDevirt::ID = 0; 532df49d1bbSPeter Collingbourne 533f7691d8bSPeter Collingbourne ModulePass * 534f7691d8bSPeter Collingbourne llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary, 535f7691d8bSPeter Collingbourne const ModuleSummaryIndex *ImportSummary) { 536f7691d8bSPeter Collingbourne return new WholeProgramDevirt(ExportSummary, ImportSummary); 537df49d1bbSPeter Collingbourne } 538df49d1bbSPeter 
Collingbourne 539164a2aa6SChandler Carruth PreservedAnalyses WholeProgramDevirtPass::run(Module &M, 54037317f12SPeter Collingbourne ModuleAnalysisManager &AM) { 54137317f12SPeter Collingbourne auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 54237317f12SPeter Collingbourne auto AARGetter = [&](Function &F) -> AAResults & { 54337317f12SPeter Collingbourne return FAM.getResult<AAManager>(F); 54437317f12SPeter Collingbourne }; 545f7691d8bSPeter Collingbourne if (!DevirtModule(M, AARGetter, nullptr, nullptr).run()) 546d737dd2eSDavide Italiano return PreservedAnalyses::all(); 547d737dd2eSDavide Italiano return PreservedAnalyses::none(); 548d737dd2eSDavide Italiano } 549d737dd2eSDavide Italiano 55037317f12SPeter Collingbourne bool DevirtModule::runForTesting( 55137317f12SPeter Collingbourne Module &M, function_ref<AAResults &(Function &)> AARGetter) { 5522b33f653SPeter Collingbourne ModuleSummaryIndex Summary; 5532b33f653SPeter Collingbourne 5542b33f653SPeter Collingbourne // Handle the command-line summary arguments. This code is for testing 5552b33f653SPeter Collingbourne // purposes only, so we handle errors directly. 
5562b33f653SPeter Collingbourne if (!ClReadSummary.empty()) { 5572b33f653SPeter Collingbourne ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + 5582b33f653SPeter Collingbourne ": "); 5592b33f653SPeter Collingbourne auto ReadSummaryFile = 5602b33f653SPeter Collingbourne ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); 5612b33f653SPeter Collingbourne 5622b33f653SPeter Collingbourne yaml::Input In(ReadSummaryFile->getBuffer()); 5632b33f653SPeter Collingbourne In >> Summary; 5642b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(In.error())); 5652b33f653SPeter Collingbourne } 5662b33f653SPeter Collingbourne 567f7691d8bSPeter Collingbourne bool Changed = 568f7691d8bSPeter Collingbourne DevirtModule( 569f7691d8bSPeter Collingbourne M, AARGetter, 570f7691d8bSPeter Collingbourne ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, 571f7691d8bSPeter Collingbourne ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr) 572f7691d8bSPeter Collingbourne .run(); 5732b33f653SPeter Collingbourne 5742b33f653SPeter Collingbourne if (!ClWriteSummary.empty()) { 5752b33f653SPeter Collingbourne ExitOnError ExitOnErr( 5762b33f653SPeter Collingbourne "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); 5772b33f653SPeter Collingbourne std::error_code EC; 5782b33f653SPeter Collingbourne raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); 5792b33f653SPeter Collingbourne ExitOnErr(errorCodeToError(EC)); 5802b33f653SPeter Collingbourne 5812b33f653SPeter Collingbourne yaml::Output Out(OS); 5822b33f653SPeter Collingbourne Out << Summary; 5832b33f653SPeter Collingbourne } 5842b33f653SPeter Collingbourne 5852b33f653SPeter Collingbourne return Changed; 5862b33f653SPeter Collingbourne } 5872b33f653SPeter Collingbourne 5887efd7506SPeter Collingbourne void DevirtModule::buildTypeIdentifierMap( 589df49d1bbSPeter Collingbourne std::vector<VTableBits> &Bits, 5907efd7506SPeter Collingbourne DenseMap<Metadata *, 
std::set<TypeMemberInfo>> &TypeIdMap) { 591df49d1bbSPeter Collingbourne DenseMap<GlobalVariable *, VTableBits *> GVToBits; 5927efd7506SPeter Collingbourne Bits.reserve(M.getGlobalList().size()); 5937efd7506SPeter Collingbourne SmallVector<MDNode *, 2> Types; 5947efd7506SPeter Collingbourne for (GlobalVariable &GV : M.globals()) { 5957efd7506SPeter Collingbourne Types.clear(); 5967efd7506SPeter Collingbourne GV.getMetadata(LLVMContext::MD_type, Types); 5977efd7506SPeter Collingbourne if (Types.empty()) 598df49d1bbSPeter Collingbourne continue; 599df49d1bbSPeter Collingbourne 6007efd7506SPeter Collingbourne VTableBits *&BitsPtr = GVToBits[&GV]; 6017efd7506SPeter Collingbourne if (!BitsPtr) { 6027efd7506SPeter Collingbourne Bits.emplace_back(); 6037efd7506SPeter Collingbourne Bits.back().GV = &GV; 6047efd7506SPeter Collingbourne Bits.back().ObjectSize = 6057efd7506SPeter Collingbourne M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); 6067efd7506SPeter Collingbourne BitsPtr = &Bits.back(); 6077efd7506SPeter Collingbourne } 6087efd7506SPeter Collingbourne 6097efd7506SPeter Collingbourne for (MDNode *Type : Types) { 6107efd7506SPeter Collingbourne auto TypeID = Type->getOperand(1).get(); 611df49d1bbSPeter Collingbourne 612df49d1bbSPeter Collingbourne uint64_t Offset = 613df49d1bbSPeter Collingbourne cast<ConstantInt>( 6147efd7506SPeter Collingbourne cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) 615df49d1bbSPeter Collingbourne ->getZExtValue(); 616df49d1bbSPeter Collingbourne 6177efd7506SPeter Collingbourne TypeIdMap[TypeID].insert({BitsPtr, Offset}); 618df49d1bbSPeter Collingbourne } 619df49d1bbSPeter Collingbourne } 620df49d1bbSPeter Collingbourne } 621df49d1bbSPeter Collingbourne 6228786754cSPeter Collingbourne Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) { 6238786754cSPeter Collingbourne if (I->getType()->isPointerTy()) { 6248786754cSPeter Collingbourne if (Offset == 0) 6258786754cSPeter Collingbourne return 
I; 6268786754cSPeter Collingbourne return nullptr; 6278786754cSPeter Collingbourne } 6288786754cSPeter Collingbourne 6297a1e5bbeSPeter Collingbourne const DataLayout &DL = M.getDataLayout(); 6307a1e5bbeSPeter Collingbourne 6317a1e5bbeSPeter Collingbourne if (auto *C = dyn_cast<ConstantStruct>(I)) { 6327a1e5bbeSPeter Collingbourne const StructLayout *SL = DL.getStructLayout(C->getType()); 6337a1e5bbeSPeter Collingbourne if (Offset >= SL->getSizeInBytes()) 6347a1e5bbeSPeter Collingbourne return nullptr; 6357a1e5bbeSPeter Collingbourne 6368786754cSPeter Collingbourne unsigned Op = SL->getElementContainingOffset(Offset); 6378786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 6388786754cSPeter Collingbourne Offset - SL->getElementOffset(Op)); 6398786754cSPeter Collingbourne } 6408786754cSPeter Collingbourne if (auto *C = dyn_cast<ConstantArray>(I)) { 6417a1e5bbeSPeter Collingbourne ArrayType *VTableTy = C->getType(); 6427a1e5bbeSPeter Collingbourne uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); 6437a1e5bbeSPeter Collingbourne 6448786754cSPeter Collingbourne unsigned Op = Offset / ElemSize; 6457a1e5bbeSPeter Collingbourne if (Op >= C->getNumOperands()) 6467a1e5bbeSPeter Collingbourne return nullptr; 6477a1e5bbeSPeter Collingbourne 6488786754cSPeter Collingbourne return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), 6498786754cSPeter Collingbourne Offset % ElemSize); 6508786754cSPeter Collingbourne } 6518786754cSPeter Collingbourne return nullptr; 6527a1e5bbeSPeter Collingbourne } 6537a1e5bbeSPeter Collingbourne 654df49d1bbSPeter Collingbourne bool DevirtModule::tryFindVirtualCallTargets( 655df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> &TargetsForSlot, 6567efd7506SPeter Collingbourne const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { 6577efd7506SPeter Collingbourne for (const TypeMemberInfo &TM : TypeMemberInfos) { 6587efd7506SPeter Collingbourne if 
(!TM.Bits->GV->isConstant()) 659df49d1bbSPeter Collingbourne return false; 660df49d1bbSPeter Collingbourne 6618786754cSPeter Collingbourne Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), 6628786754cSPeter Collingbourne TM.Offset + ByteOffset); 6638786754cSPeter Collingbourne if (!Ptr) 664df49d1bbSPeter Collingbourne return false; 665df49d1bbSPeter Collingbourne 6668786754cSPeter Collingbourne auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); 667df49d1bbSPeter Collingbourne if (!Fn) 668df49d1bbSPeter Collingbourne return false; 669df49d1bbSPeter Collingbourne 670df49d1bbSPeter Collingbourne // We can disregard __cxa_pure_virtual as a possible call target, as 671df49d1bbSPeter Collingbourne // calls to pure virtuals are UB. 672df49d1bbSPeter Collingbourne if (Fn->getName() == "__cxa_pure_virtual") 673df49d1bbSPeter Collingbourne continue; 674df49d1bbSPeter Collingbourne 6757efd7506SPeter Collingbourne TargetsForSlot.push_back({Fn, &TM}); 676df49d1bbSPeter Collingbourne } 677df49d1bbSPeter Collingbourne 678df49d1bbSPeter Collingbourne // Give up if we couldn't find any targets. 679df49d1bbSPeter Collingbourne return !TargetsForSlot.empty(); 680df49d1bbSPeter Collingbourne } 681df49d1bbSPeter Collingbourne 68250cbd7ccSPeter Collingbourne void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, 6832325bb34SPeter Collingbourne Constant *TheFn, bool &IsExported) { 68450cbd7ccSPeter Collingbourne auto Apply = [&](CallSiteInfo &CSInfo) { 68550cbd7ccSPeter Collingbourne for (auto &&VCallSite : CSInfo.CallSites) { 686f3403fd2SIvan Krasin if (RemarksEnabled) 687f3403fd2SIvan Krasin VCallSite.emitRemark("single-impl", TheFn->getName()); 688df49d1bbSPeter Collingbourne VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( 689df49d1bbSPeter Collingbourne TheFn, VCallSite.CS.getCalledValue()->getType())); 6900312f614SPeter Collingbourne // This use is no longer unsafe. 
6910312f614SPeter Collingbourne if (VCallSite.NumUnsafeUses) 6920312f614SPeter Collingbourne --*VCallSite.NumUnsafeUses; 693df49d1bbSPeter Collingbourne } 6942325bb34SPeter Collingbourne if (CSInfo.isExported()) { 6952325bb34SPeter Collingbourne IsExported = true; 69659675ba0SPeter Collingbourne CSInfo.markDevirt(); 6972325bb34SPeter Collingbourne } 69850cbd7ccSPeter Collingbourne }; 69950cbd7ccSPeter Collingbourne Apply(SlotInfo.CSInfo); 70050cbd7ccSPeter Collingbourne for (auto &P : SlotInfo.ConstCSInfo) 70150cbd7ccSPeter Collingbourne Apply(P.second); 70250cbd7ccSPeter Collingbourne } 70350cbd7ccSPeter Collingbourne 70450cbd7ccSPeter Collingbourne bool DevirtModule::trySingleImplDevirt( 70550cbd7ccSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 7062325bb34SPeter Collingbourne VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) { 70750cbd7ccSPeter Collingbourne // See if the program contains a single implementation of this virtual 70850cbd7ccSPeter Collingbourne // function. 70950cbd7ccSPeter Collingbourne Function *TheFn = TargetsForSlot[0].Fn; 71050cbd7ccSPeter Collingbourne for (auto &&Target : TargetsForSlot) 71150cbd7ccSPeter Collingbourne if (TheFn != Target.Fn) 71250cbd7ccSPeter Collingbourne return false; 71350cbd7ccSPeter Collingbourne 71450cbd7ccSPeter Collingbourne // If so, update each call site to call that implementation directly. 
71550cbd7ccSPeter Collingbourne if (RemarksEnabled) 71650cbd7ccSPeter Collingbourne TargetsForSlot[0].WasDevirt = true; 7172325bb34SPeter Collingbourne 7182325bb34SPeter Collingbourne bool IsExported = false; 7192325bb34SPeter Collingbourne applySingleImplDevirt(SlotInfo, TheFn, IsExported); 7202325bb34SPeter Collingbourne if (!IsExported) 7212325bb34SPeter Collingbourne return false; 7222325bb34SPeter Collingbourne 7232325bb34SPeter Collingbourne // If the only implementation has local linkage, we must promote to external 7242325bb34SPeter Collingbourne // to make it visible to thin LTO objects. We can only get here during the 7252325bb34SPeter Collingbourne // ThinLTO export phase. 7262325bb34SPeter Collingbourne if (TheFn->hasLocalLinkage()) { 7272325bb34SPeter Collingbourne TheFn->setLinkage(GlobalValue::ExternalLinkage); 7282325bb34SPeter Collingbourne TheFn->setVisibility(GlobalValue::HiddenVisibility); 7292325bb34SPeter Collingbourne TheFn->setName(TheFn->getName() + "$merged"); 7302325bb34SPeter Collingbourne } 7312325bb34SPeter Collingbourne 7322325bb34SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::SingleImpl; 7332325bb34SPeter Collingbourne Res->SingleImplName = TheFn->getName(); 7342325bb34SPeter Collingbourne 735df49d1bbSPeter Collingbourne return true; 736df49d1bbSPeter Collingbourne } 737df49d1bbSPeter Collingbourne 738df49d1bbSPeter Collingbourne bool DevirtModule::tryEvaluateFunctionsWithArgs( 739df49d1bbSPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, 74050cbd7ccSPeter Collingbourne ArrayRef<uint64_t> Args) { 741df49d1bbSPeter Collingbourne // Evaluate each function and store the result in each target's RetVal 742df49d1bbSPeter Collingbourne // field. 
743df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 744df49d1bbSPeter Collingbourne if (Target.Fn->arg_size() != Args.size() + 1) 745df49d1bbSPeter Collingbourne return false; 746df49d1bbSPeter Collingbourne 747df49d1bbSPeter Collingbourne Evaluator Eval(M.getDataLayout(), nullptr); 748df49d1bbSPeter Collingbourne SmallVector<Constant *, 2> EvalArgs; 749df49d1bbSPeter Collingbourne EvalArgs.push_back( 750df49d1bbSPeter Collingbourne Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); 75150cbd7ccSPeter Collingbourne for (unsigned I = 0; I != Args.size(); ++I) { 75250cbd7ccSPeter Collingbourne auto *ArgTy = dyn_cast<IntegerType>( 75350cbd7ccSPeter Collingbourne Target.Fn->getFunctionType()->getParamType(I + 1)); 75450cbd7ccSPeter Collingbourne if (!ArgTy) 75550cbd7ccSPeter Collingbourne return false; 75650cbd7ccSPeter Collingbourne EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); 75750cbd7ccSPeter Collingbourne } 75850cbd7ccSPeter Collingbourne 759df49d1bbSPeter Collingbourne Constant *RetVal; 760df49d1bbSPeter Collingbourne if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || 761df49d1bbSPeter Collingbourne !isa<ConstantInt>(RetVal)) 762df49d1bbSPeter Collingbourne return false; 763df49d1bbSPeter Collingbourne Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); 764df49d1bbSPeter Collingbourne } 765df49d1bbSPeter Collingbourne return true; 766df49d1bbSPeter Collingbourne } 767df49d1bbSPeter Collingbourne 76850cbd7ccSPeter Collingbourne void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 76950cbd7ccSPeter Collingbourne uint64_t TheRetVal) { 77050cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) 77150cbd7ccSPeter Collingbourne Call.replaceAndErase( 77250cbd7ccSPeter Collingbourne "uniform-ret-val", FnName, RemarksEnabled, 77350cbd7ccSPeter Collingbourne ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); 77459675ba0SPeter Collingbourne 
CSInfo.markDevirt(); 77550cbd7ccSPeter Collingbourne } 77650cbd7ccSPeter Collingbourne 777df49d1bbSPeter Collingbourne bool DevirtModule::tryUniformRetValOpt( 77877a8d563SPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, 77977a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *Res) { 780df49d1bbSPeter Collingbourne // Uniform return value optimization. If all functions return the same 781df49d1bbSPeter Collingbourne // constant, replace all calls with that constant. 782df49d1bbSPeter Collingbourne uint64_t TheRetVal = TargetsForSlot[0].RetVal; 783df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) 784df49d1bbSPeter Collingbourne if (Target.RetVal != TheRetVal) 785df49d1bbSPeter Collingbourne return false; 786df49d1bbSPeter Collingbourne 78777a8d563SPeter Collingbourne if (CSInfo.isExported()) { 78877a8d563SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; 78977a8d563SPeter Collingbourne Res->Info = TheRetVal; 79077a8d563SPeter Collingbourne } 79177a8d563SPeter Collingbourne 79250cbd7ccSPeter Collingbourne applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); 793f3403fd2SIvan Krasin if (RemarksEnabled) 794f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 795f3403fd2SIvan Krasin Target.WasDevirt = true; 796df49d1bbSPeter Collingbourne return true; 797df49d1bbSPeter Collingbourne } 798df49d1bbSPeter Collingbourne 79959675ba0SPeter Collingbourne std::string DevirtModule::getGlobalName(VTableSlot Slot, 80059675ba0SPeter Collingbourne ArrayRef<uint64_t> Args, 80159675ba0SPeter Collingbourne StringRef Name) { 80259675ba0SPeter Collingbourne std::string FullName = "__typeid_"; 80359675ba0SPeter Collingbourne raw_string_ostream OS(FullName); 80459675ba0SPeter Collingbourne OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; 80559675ba0SPeter Collingbourne for (uint64_t Arg : Args) 80659675ba0SPeter 
Collingbourne OS << '_' << Arg; 80759675ba0SPeter Collingbourne OS << '_' << Name; 80859675ba0SPeter Collingbourne return OS.str(); 80959675ba0SPeter Collingbourne } 81059675ba0SPeter Collingbourne 81159675ba0SPeter Collingbourne void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 81259675ba0SPeter Collingbourne StringRef Name, Constant *C) { 81359675ba0SPeter Collingbourne GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, 81459675ba0SPeter Collingbourne getGlobalName(Slot, Args, Name), C, &M); 81559675ba0SPeter Collingbourne GA->setVisibility(GlobalValue::HiddenVisibility); 81659675ba0SPeter Collingbourne } 81759675ba0SPeter Collingbourne 81859675ba0SPeter Collingbourne Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 81914dcf02fSPeter Collingbourne StringRef Name, unsigned AbsWidth) { 82059675ba0SPeter Collingbourne Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); 82159675ba0SPeter Collingbourne auto *GV = dyn_cast<GlobalVariable>(C); 82214dcf02fSPeter Collingbourne // We only need to set metadata if the global is newly created, in which 82314dcf02fSPeter Collingbourne // case it would not have hidden visibility. 
82414dcf02fSPeter Collingbourne if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) 82559675ba0SPeter Collingbourne return C; 82614dcf02fSPeter Collingbourne 82759675ba0SPeter Collingbourne GV->setVisibility(GlobalValue::HiddenVisibility); 82814dcf02fSPeter Collingbourne auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { 82914dcf02fSPeter Collingbourne auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); 83014dcf02fSPeter Collingbourne auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); 83114dcf02fSPeter Collingbourne GV->setMetadata(LLVMContext::MD_absolute_symbol, 83214dcf02fSPeter Collingbourne MDNode::get(M.getContext(), {MinC, MaxC})); 83314dcf02fSPeter Collingbourne }; 83414dcf02fSPeter Collingbourne if (AbsWidth == IntPtrTy->getBitWidth()) 83514dcf02fSPeter Collingbourne SetAbsRange(~0ull, ~0ull); // Full set. 83614dcf02fSPeter Collingbourne else if (AbsWidth) 83714dcf02fSPeter Collingbourne SetAbsRange(0, 1ull << AbsWidth); 83859675ba0SPeter Collingbourne return GV; 83959675ba0SPeter Collingbourne } 84059675ba0SPeter Collingbourne 84150cbd7ccSPeter Collingbourne void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 84250cbd7ccSPeter Collingbourne bool IsOne, 84350cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr) { 84450cbd7ccSPeter Collingbourne for (auto &&Call : CSInfo.CallSites) { 84550cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 84650cbd7ccSPeter Collingbourne Value *Cmp = B.CreateICmp(IsOne ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, 84750cbd7ccSPeter Collingbourne Call.VTable, UniqueMemberAddr); 84850cbd7ccSPeter Collingbourne Cmp = B.CreateZExt(Cmp, Call.CS->getType()); 84950cbd7ccSPeter Collingbourne Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); 85050cbd7ccSPeter Collingbourne } 85159675ba0SPeter Collingbourne CSInfo.markDevirt(); 85250cbd7ccSPeter Collingbourne } 85350cbd7ccSPeter Collingbourne 854df49d1bbSPeter Collingbourne bool DevirtModule::tryUniqueRetValOpt( 855f3403fd2SIvan Krasin unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, 85659675ba0SPeter Collingbourne CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, 85759675ba0SPeter Collingbourne VTableSlot Slot, ArrayRef<uint64_t> Args) { 858df49d1bbSPeter Collingbourne // IsOne controls whether we look for a 0 or a 1. 859df49d1bbSPeter Collingbourne auto tryUniqueRetValOptFor = [&](bool IsOne) { 860cdc71612SEugene Zelenko const TypeMemberInfo *UniqueMember = nullptr; 861df49d1bbSPeter Collingbourne for (const VirtualCallTarget &Target : TargetsForSlot) { 8623866cc5fSPeter Collingbourne if (Target.RetVal == (IsOne ? 1 : 0)) { 8637efd7506SPeter Collingbourne if (UniqueMember) 864df49d1bbSPeter Collingbourne return false; 8657efd7506SPeter Collingbourne UniqueMember = Target.TM; 866df49d1bbSPeter Collingbourne } 867df49d1bbSPeter Collingbourne } 868df49d1bbSPeter Collingbourne 8697efd7506SPeter Collingbourne // We should have found a unique member or bailed out by now. We already 870df49d1bbSPeter Collingbourne // checked for a uniform return value in tryUniformRetValOpt. 
8717efd7506SPeter Collingbourne assert(UniqueMember); 872df49d1bbSPeter Collingbourne 87350cbd7ccSPeter Collingbourne Constant *UniqueMemberAddr = 87450cbd7ccSPeter Collingbourne ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); 87550cbd7ccSPeter Collingbourne UniqueMemberAddr = ConstantExpr::getGetElementPtr( 87650cbd7ccSPeter Collingbourne Int8Ty, UniqueMemberAddr, 87750cbd7ccSPeter Collingbourne ConstantInt::get(Int64Ty, UniqueMember->Offset)); 87850cbd7ccSPeter Collingbourne 87959675ba0SPeter Collingbourne if (CSInfo.isExported()) { 88059675ba0SPeter Collingbourne Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; 88159675ba0SPeter Collingbourne Res->Info = IsOne; 88259675ba0SPeter Collingbourne 88359675ba0SPeter Collingbourne exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr); 88459675ba0SPeter Collingbourne } 88559675ba0SPeter Collingbourne 88659675ba0SPeter Collingbourne // Replace each call with the comparison. 88750cbd7ccSPeter Collingbourne applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, 88850cbd7ccSPeter Collingbourne UniqueMemberAddr); 88950cbd7ccSPeter Collingbourne 890f3403fd2SIvan Krasin // Update devirtualization statistics for targets. 
891f3403fd2SIvan Krasin if (RemarksEnabled) 892f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 893f3403fd2SIvan Krasin Target.WasDevirt = true; 894f3403fd2SIvan Krasin 895df49d1bbSPeter Collingbourne return true; 896df49d1bbSPeter Collingbourne }; 897df49d1bbSPeter Collingbourne 898df49d1bbSPeter Collingbourne if (BitWidth == 1) { 899df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(true)) 900df49d1bbSPeter Collingbourne return true; 901df49d1bbSPeter Collingbourne if (tryUniqueRetValOptFor(false)) 902df49d1bbSPeter Collingbourne return true; 903df49d1bbSPeter Collingbourne } 904df49d1bbSPeter Collingbourne return false; 905df49d1bbSPeter Collingbourne } 906df49d1bbSPeter Collingbourne 90750cbd7ccSPeter Collingbourne void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 90850cbd7ccSPeter Collingbourne Constant *Byte, Constant *Bit) { 90950cbd7ccSPeter Collingbourne for (auto Call : CSInfo.CallSites) { 91050cbd7ccSPeter Collingbourne auto *RetType = cast<IntegerType>(Call.CS.getType()); 91150cbd7ccSPeter Collingbourne IRBuilder<> B(Call.CS.getInstruction()); 91250cbd7ccSPeter Collingbourne Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); 91350cbd7ccSPeter Collingbourne if (RetType->getBitWidth() == 1) { 91450cbd7ccSPeter Collingbourne Value *Bits = B.CreateLoad(Addr); 91550cbd7ccSPeter Collingbourne Value *BitsAndBit = B.CreateAnd(Bits, Bit); 91650cbd7ccSPeter Collingbourne auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); 91750cbd7ccSPeter Collingbourne Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, 91850cbd7ccSPeter Collingbourne IsBitSet); 91950cbd7ccSPeter Collingbourne } else { 92050cbd7ccSPeter Collingbourne Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); 92150cbd7ccSPeter Collingbourne Value *Val = B.CreateLoad(RetType, ValAddr); 92250cbd7ccSPeter Collingbourne Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); 
92350cbd7ccSPeter Collingbourne } 92450cbd7ccSPeter Collingbourne } 92514dcf02fSPeter Collingbourne CSInfo.markDevirt(); 92650cbd7ccSPeter Collingbourne } 92750cbd7ccSPeter Collingbourne 928df49d1bbSPeter Collingbourne bool DevirtModule::tryVirtualConstProp( 92959675ba0SPeter Collingbourne MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 93059675ba0SPeter Collingbourne WholeProgramDevirtResolution *Res, VTableSlot Slot) { 931df49d1bbSPeter Collingbourne // This only works if the function returns an integer. 932df49d1bbSPeter Collingbourne auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); 933df49d1bbSPeter Collingbourne if (!RetType) 934df49d1bbSPeter Collingbourne return false; 935df49d1bbSPeter Collingbourne unsigned BitWidth = RetType->getBitWidth(); 936df49d1bbSPeter Collingbourne if (BitWidth > 64) 937df49d1bbSPeter Collingbourne return false; 938df49d1bbSPeter Collingbourne 93917febdbbSPeter Collingbourne // Make sure that each function is defined, does not access memory, takes at 94017febdbbSPeter Collingbourne // least one argument, does not use its first argument (which we assume is 94117febdbbSPeter Collingbourne // 'this'), and has the same return type. 94237317f12SPeter Collingbourne // 94337317f12SPeter Collingbourne // Note that we test whether this copy of the function is readnone, rather 94437317f12SPeter Collingbourne // than testing function attributes, which must hold for any copy of the 94537317f12SPeter Collingbourne // function, even a less optimized version substituted at link time. This is 94637317f12SPeter Collingbourne // sound because the virtual constant propagation optimizations effectively 94737317f12SPeter Collingbourne // inline all implementations of the virtual function into each call site, 94837317f12SPeter Collingbourne // rather than using function attributes to perform local optimization. 
949df49d1bbSPeter Collingbourne for (VirtualCallTarget &Target : TargetsForSlot) { 95037317f12SPeter Collingbourne if (Target.Fn->isDeclaration() || 95137317f12SPeter Collingbourne computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != 95237317f12SPeter Collingbourne MAK_ReadNone || 95317febdbbSPeter Collingbourne Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || 954df49d1bbSPeter Collingbourne Target.Fn->getReturnType() != RetType) 955df49d1bbSPeter Collingbourne return false; 956df49d1bbSPeter Collingbourne } 957df49d1bbSPeter Collingbourne 95850cbd7ccSPeter Collingbourne for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { 959df49d1bbSPeter Collingbourne if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) 960df49d1bbSPeter Collingbourne continue; 961df49d1bbSPeter Collingbourne 96277a8d563SPeter Collingbourne WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; 96377a8d563SPeter Collingbourne if (Res) 96477a8d563SPeter Collingbourne ResByArg = &Res->ResByArg[CSByConstantArg.first]; 96577a8d563SPeter Collingbourne 96677a8d563SPeter Collingbourne if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg)) 967df49d1bbSPeter Collingbourne continue; 968df49d1bbSPeter Collingbourne 96959675ba0SPeter Collingbourne if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, 97059675ba0SPeter Collingbourne ResByArg, Slot, CSByConstantArg.first)) 971df49d1bbSPeter Collingbourne continue; 972df49d1bbSPeter Collingbourne 9737efd7506SPeter Collingbourne // Find an allocation offset in bits in all vtables associated with the 9747efd7506SPeter Collingbourne // type. 
975df49d1bbSPeter Collingbourne uint64_t AllocBefore = 976df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); 977df49d1bbSPeter Collingbourne uint64_t AllocAfter = 978df49d1bbSPeter Collingbourne findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth); 979df49d1bbSPeter Collingbourne 980df49d1bbSPeter Collingbourne // Calculate the total amount of padding needed to store a value at both 981df49d1bbSPeter Collingbourne // ends of the object. 982df49d1bbSPeter Collingbourne uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0; 983df49d1bbSPeter Collingbourne for (auto &&Target : TargetsForSlot) { 984df49d1bbSPeter Collingbourne TotalPaddingBefore += std::max<int64_t>( 985df49d1bbSPeter Collingbourne (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0); 986df49d1bbSPeter Collingbourne TotalPaddingAfter += std::max<int64_t>( 987df49d1bbSPeter Collingbourne (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0); 988df49d1bbSPeter Collingbourne } 989df49d1bbSPeter Collingbourne 990df49d1bbSPeter Collingbourne // If the amount of padding is too large, give up. 991df49d1bbSPeter Collingbourne // FIXME: do something smarter here. 992df49d1bbSPeter Collingbourne if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128) 993df49d1bbSPeter Collingbourne continue; 994df49d1bbSPeter Collingbourne 995df49d1bbSPeter Collingbourne // Calculate the offset to the value as a (possibly negative) byte offset 996df49d1bbSPeter Collingbourne // and (if applicable) a bit offset, and store the values in the targets. 
997df49d1bbSPeter Collingbourne int64_t OffsetByte; 998df49d1bbSPeter Collingbourne uint64_t OffsetBit; 999df49d1bbSPeter Collingbourne if (TotalPaddingBefore <= TotalPaddingAfter) 1000df49d1bbSPeter Collingbourne setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte, 1001df49d1bbSPeter Collingbourne OffsetBit); 1002df49d1bbSPeter Collingbourne else 1003df49d1bbSPeter Collingbourne setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, 1004df49d1bbSPeter Collingbourne OffsetBit); 1005df49d1bbSPeter Collingbourne 1006f3403fd2SIvan Krasin if (RemarksEnabled) 1007f3403fd2SIvan Krasin for (auto &&Target : TargetsForSlot) 1008f3403fd2SIvan Krasin Target.WasDevirt = true; 1009f3403fd2SIvan Krasin 1010184773d8SPeter Collingbourne Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); 101150cbd7ccSPeter Collingbourne Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); 101214dcf02fSPeter Collingbourne 101314dcf02fSPeter Collingbourne if (CSByConstantArg.second.isExported()) { 101414dcf02fSPeter Collingbourne ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; 101514dcf02fSPeter Collingbourne exportGlobal(Slot, CSByConstantArg.first, "byte", 101614dcf02fSPeter Collingbourne ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy)); 101714dcf02fSPeter Collingbourne exportGlobal(Slot, CSByConstantArg.first, "bit", 101814dcf02fSPeter Collingbourne ConstantExpr::getIntToPtr(BitConst, Int8PtrTy)); 101914dcf02fSPeter Collingbourne } 102014dcf02fSPeter Collingbourne 102114dcf02fSPeter Collingbourne // Rewrite each call to a load from OffsetByte/OffsetBit. 
102250cbd7ccSPeter Collingbourne applyVirtualConstProp(CSByConstantArg.second, 102350cbd7ccSPeter Collingbourne TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); 1024df49d1bbSPeter Collingbourne } 1025df49d1bbSPeter Collingbourne return true; 1026df49d1bbSPeter Collingbourne } 1027df49d1bbSPeter Collingbourne 1028df49d1bbSPeter Collingbourne void DevirtModule::rebuildGlobal(VTableBits &B) { 1029df49d1bbSPeter Collingbourne if (B.Before.Bytes.empty() && B.After.Bytes.empty()) 1030df49d1bbSPeter Collingbourne return; 1031df49d1bbSPeter Collingbourne 1032df49d1bbSPeter Collingbourne // Align each byte array to pointer width. 1033df49d1bbSPeter Collingbourne unsigned PointerSize = M.getDataLayout().getPointerSize(); 1034df49d1bbSPeter Collingbourne B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize)); 1035df49d1bbSPeter Collingbourne B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize)); 1036df49d1bbSPeter Collingbourne 1037df49d1bbSPeter Collingbourne // Before was stored in reverse order; flip it now. 1038df49d1bbSPeter Collingbourne for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) 1039df49d1bbSPeter Collingbourne std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]); 1040df49d1bbSPeter Collingbourne 1041df49d1bbSPeter Collingbourne // Build an anonymous global containing the before bytes, followed by the 1042df49d1bbSPeter Collingbourne // original initializer, followed by the after bytes. 
1043df49d1bbSPeter Collingbourne auto NewInit = ConstantStruct::getAnon( 1044df49d1bbSPeter Collingbourne {ConstantDataArray::get(M.getContext(), B.Before.Bytes), 1045df49d1bbSPeter Collingbourne B.GV->getInitializer(), 1046df49d1bbSPeter Collingbourne ConstantDataArray::get(M.getContext(), B.After.Bytes)}); 1047df49d1bbSPeter Collingbourne auto NewGV = 1048df49d1bbSPeter Collingbourne new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(), 1049df49d1bbSPeter Collingbourne GlobalVariable::PrivateLinkage, NewInit, "", B.GV); 1050df49d1bbSPeter Collingbourne NewGV->setSection(B.GV->getSection()); 1051df49d1bbSPeter Collingbourne NewGV->setComdat(B.GV->getComdat()); 1052df49d1bbSPeter Collingbourne 10530312f614SPeter Collingbourne // Copy the original vtable's metadata to the anonymous global, adjusting 10540312f614SPeter Collingbourne // offsets as required. 10550312f614SPeter Collingbourne NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); 10560312f614SPeter Collingbourne 1057df49d1bbSPeter Collingbourne // Build an alias named after the original global, pointing at the second 1058df49d1bbSPeter Collingbourne // element (the original initializer). 
1059df49d1bbSPeter Collingbourne auto Alias = GlobalAlias::create( 1060df49d1bbSPeter Collingbourne B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", 1061df49d1bbSPeter Collingbourne ConstantExpr::getGetElementPtr( 1062df49d1bbSPeter Collingbourne NewInit->getType(), NewGV, 1063df49d1bbSPeter Collingbourne ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), 1064df49d1bbSPeter Collingbourne ConstantInt::get(Int32Ty, 1)}), 1065df49d1bbSPeter Collingbourne &M); 1066df49d1bbSPeter Collingbourne Alias->setVisibility(B.GV->getVisibility()); 1067df49d1bbSPeter Collingbourne Alias->takeName(B.GV); 1068df49d1bbSPeter Collingbourne 1069df49d1bbSPeter Collingbourne B.GV->replaceAllUsesWith(Alias); 1070df49d1bbSPeter Collingbourne B.GV->eraseFromParent(); 1071df49d1bbSPeter Collingbourne } 1072df49d1bbSPeter Collingbourne 1073f3403fd2SIvan Krasin bool DevirtModule::areRemarksEnabled() { 1074f3403fd2SIvan Krasin const auto &FL = M.getFunctionList(); 1075f3403fd2SIvan Krasin if (FL.empty()) 1076f3403fd2SIvan Krasin return false; 1077f3403fd2SIvan Krasin const Function &Fn = FL.front(); 1078de53bfb9SAdam Nemet 1079de53bfb9SAdam Nemet const auto &BBL = Fn.getBasicBlockList(); 1080de53bfb9SAdam Nemet if (BBL.empty()) 1081de53bfb9SAdam Nemet return false; 1082de53bfb9SAdam Nemet auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); 1083f3403fd2SIvan Krasin return DI.isEnabled(); 1084f3403fd2SIvan Krasin } 1085f3403fd2SIvan Krasin 10860312f614SPeter Collingbourne void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, 10870312f614SPeter Collingbourne Function *AssumeFunc) { 1088df49d1bbSPeter Collingbourne // Find all virtual calls via a virtual table pointer %p under an assumption 10897efd7506SPeter Collingbourne // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p 10907efd7506SPeter Collingbourne // points to a member of the type identifier %md. 
Group calls by (type ID, 10917efd7506SPeter Collingbourne // offset) pair (effectively the identity of the virtual function) and store 10927efd7506SPeter Collingbourne // to CallSlots. 1093df49d1bbSPeter Collingbourne DenseSet<Value *> SeenPtrs; 10947efd7506SPeter Collingbourne for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); 1095df49d1bbSPeter Collingbourne I != E;) { 1096df49d1bbSPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 1097df49d1bbSPeter Collingbourne ++I; 1098df49d1bbSPeter Collingbourne if (!CI) 1099df49d1bbSPeter Collingbourne continue; 1100df49d1bbSPeter Collingbourne 1101ccdc225cSPeter Collingbourne // Search for virtual calls based on %p and add them to DevirtCalls. 1102ccdc225cSPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 1103df49d1bbSPeter Collingbourne SmallVector<CallInst *, 1> Assumes; 11040312f614SPeter Collingbourne findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); 1105df49d1bbSPeter Collingbourne 1106ccdc225cSPeter Collingbourne // If we found any, add them to CallSlots. Only do this if we haven't seen 1107ccdc225cSPeter Collingbourne // the vtable pointer before, as it may have been CSE'd with pointers from 1108ccdc225cSPeter Collingbourne // other call sites, and we don't want to process call sites multiple times. 
1109df49d1bbSPeter Collingbourne if (!Assumes.empty()) { 11107efd7506SPeter Collingbourne Metadata *TypeId = 1111df49d1bbSPeter Collingbourne cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); 1112df49d1bbSPeter Collingbourne Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); 1113ccdc225cSPeter Collingbourne if (SeenPtrs.insert(Ptr).second) { 1114ccdc225cSPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 111550cbd7ccSPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), 111650cbd7ccSPeter Collingbourne Call.CS, nullptr); 1117ccdc225cSPeter Collingbourne } 1118ccdc225cSPeter Collingbourne } 1119df49d1bbSPeter Collingbourne } 1120df49d1bbSPeter Collingbourne 11217efd7506SPeter Collingbourne // We no longer need the assumes or the type test. 1122df49d1bbSPeter Collingbourne for (auto Assume : Assumes) 1123df49d1bbSPeter Collingbourne Assume->eraseFromParent(); 1124df49d1bbSPeter Collingbourne // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we 1125df49d1bbSPeter Collingbourne // may use the vtable argument later. 
1126df49d1bbSPeter Collingbourne if (CI->use_empty()) 1127df49d1bbSPeter Collingbourne CI->eraseFromParent(); 1128df49d1bbSPeter Collingbourne } 11290312f614SPeter Collingbourne } 11300312f614SPeter Collingbourne 11310312f614SPeter Collingbourne void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { 11320312f614SPeter Collingbourne Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); 11330312f614SPeter Collingbourne 11340312f614SPeter Collingbourne for (auto I = TypeCheckedLoadFunc->use_begin(), 11350312f614SPeter Collingbourne E = TypeCheckedLoadFunc->use_end(); 11360312f614SPeter Collingbourne I != E;) { 11370312f614SPeter Collingbourne auto CI = dyn_cast<CallInst>(I->getUser()); 11380312f614SPeter Collingbourne ++I; 11390312f614SPeter Collingbourne if (!CI) 11400312f614SPeter Collingbourne continue; 11410312f614SPeter Collingbourne 11420312f614SPeter Collingbourne Value *Ptr = CI->getArgOperand(0); 11430312f614SPeter Collingbourne Value *Offset = CI->getArgOperand(1); 11440312f614SPeter Collingbourne Value *TypeIdValue = CI->getArgOperand(2); 11450312f614SPeter Collingbourne Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); 11460312f614SPeter Collingbourne 11470312f614SPeter Collingbourne SmallVector<DevirtCallSite, 1> DevirtCalls; 11480312f614SPeter Collingbourne SmallVector<Instruction *, 1> LoadedPtrs; 11490312f614SPeter Collingbourne SmallVector<Instruction *, 1> Preds; 11500312f614SPeter Collingbourne bool HasNonCallUses = false; 11510312f614SPeter Collingbourne findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, 11520312f614SPeter Collingbourne HasNonCallUses, CI); 11530312f614SPeter Collingbourne 11540312f614SPeter Collingbourne // Start by generating "pessimistic" code that explicitly loads the function 11550312f614SPeter Collingbourne // pointer from the vtable and performs the type check. 
If possible, we will 11560312f614SPeter Collingbourne // eliminate the load and the type check later. 11570312f614SPeter Collingbourne 11580312f614SPeter Collingbourne // If possible, only generate the load at the point where it is used. 11590312f614SPeter Collingbourne // This helps avoid unnecessary spills. 11600312f614SPeter Collingbourne IRBuilder<> LoadB( 11610312f614SPeter Collingbourne (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); 11620312f614SPeter Collingbourne Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); 11630312f614SPeter Collingbourne Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); 11640312f614SPeter Collingbourne Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); 11650312f614SPeter Collingbourne 11660312f614SPeter Collingbourne for (Instruction *LoadedPtr : LoadedPtrs) { 11670312f614SPeter Collingbourne LoadedPtr->replaceAllUsesWith(LoadedValue); 11680312f614SPeter Collingbourne LoadedPtr->eraseFromParent(); 11690312f614SPeter Collingbourne } 11700312f614SPeter Collingbourne 11710312f614SPeter Collingbourne // Likewise for the type test. 11720312f614SPeter Collingbourne IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); 11730312f614SPeter Collingbourne CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); 11740312f614SPeter Collingbourne 11750312f614SPeter Collingbourne for (Instruction *Pred : Preds) { 11760312f614SPeter Collingbourne Pred->replaceAllUsesWith(TypeTestCall); 11770312f614SPeter Collingbourne Pred->eraseFromParent(); 11780312f614SPeter Collingbourne } 11790312f614SPeter Collingbourne 11800312f614SPeter Collingbourne // We have already erased any extractvalue instructions that refer to the 11810312f614SPeter Collingbourne // intrinsic call, but the intrinsic may have other non-extractvalue uses 11820312f614SPeter Collingbourne // (although this is unlikely). 
In that case, explicitly build a pair and 11830312f614SPeter Collingbourne // RAUW it. 11840312f614SPeter Collingbourne if (!CI->use_empty()) { 11850312f614SPeter Collingbourne Value *Pair = UndefValue::get(CI->getType()); 11860312f614SPeter Collingbourne IRBuilder<> B(CI); 11870312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); 11880312f614SPeter Collingbourne Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); 11890312f614SPeter Collingbourne CI->replaceAllUsesWith(Pair); 11900312f614SPeter Collingbourne } 11910312f614SPeter Collingbourne 11920312f614SPeter Collingbourne // The number of unsafe uses is initially the number of uses. 11930312f614SPeter Collingbourne auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; 11940312f614SPeter Collingbourne NumUnsafeUses = DevirtCalls.size(); 11950312f614SPeter Collingbourne 11960312f614SPeter Collingbourne // If the function pointer has a non-call user, we cannot eliminate the type 11970312f614SPeter Collingbourne // check, as one of those users may eventually call the pointer. Increment 11980312f614SPeter Collingbourne // the unsafe use count to make sure it cannot reach zero. 
11990312f614SPeter Collingbourne if (HasNonCallUses) 12000312f614SPeter Collingbourne ++NumUnsafeUses; 12010312f614SPeter Collingbourne for (DevirtCallSite Call : DevirtCalls) { 120250cbd7ccSPeter Collingbourne CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, 120350cbd7ccSPeter Collingbourne &NumUnsafeUses); 12040312f614SPeter Collingbourne } 12050312f614SPeter Collingbourne 12060312f614SPeter Collingbourne CI->eraseFromParent(); 12070312f614SPeter Collingbourne } 12080312f614SPeter Collingbourne } 12090312f614SPeter Collingbourne 12106d284fabSPeter Collingbourne void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { 12119a3f9797SPeter Collingbourne const TypeIdSummary *TidSummary = 1212f7691d8bSPeter Collingbourne ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); 12139a3f9797SPeter Collingbourne if (!TidSummary) 12149a3f9797SPeter Collingbourne return; 12159a3f9797SPeter Collingbourne auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); 12169a3f9797SPeter Collingbourne if (ResI == TidSummary->WPDRes.end()) 12179a3f9797SPeter Collingbourne return; 12189a3f9797SPeter Collingbourne const WholeProgramDevirtResolution &Res = ResI->second; 12196d284fabSPeter Collingbourne 12206d284fabSPeter Collingbourne if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { 12216d284fabSPeter Collingbourne // The type of the function in the declaration is irrelevant because every 12226d284fabSPeter Collingbourne // call site will cast it to the correct type. 1223*db11fdfdSMehdi Amini auto *SingleImpl = M.getOrInsertFunction( 1224*db11fdfdSMehdi Amini Res.SingleImplName, Type::getVoidTy(M.getContext()), nullptr); 12256d284fabSPeter Collingbourne 12266d284fabSPeter Collingbourne // This is the import phase so we should not be exporting anything. 
12276d284fabSPeter Collingbourne bool IsExported = false; 12286d284fabSPeter Collingbourne applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); 12296d284fabSPeter Collingbourne assert(!IsExported); 12306d284fabSPeter Collingbourne } 12310152c815SPeter Collingbourne 12320152c815SPeter Collingbourne for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { 12330152c815SPeter Collingbourne auto I = Res.ResByArg.find(CSByConstantArg.first); 12340152c815SPeter Collingbourne if (I == Res.ResByArg.end()) 12350152c815SPeter Collingbourne continue; 12360152c815SPeter Collingbourne auto &ResByArg = I->second; 12370152c815SPeter Collingbourne // FIXME: We should figure out what to do about the "function name" argument 12380152c815SPeter Collingbourne // to the apply* functions, as the function names are unavailable during the 12390152c815SPeter Collingbourne // importing phase. For now we just pass the empty string. This does not 12400152c815SPeter Collingbourne // impact correctness because the function names are just used for remarks. 
12410152c815SPeter Collingbourne switch (ResByArg.TheKind) { 12420152c815SPeter Collingbourne case WholeProgramDevirtResolution::ByArg::UniformRetVal: 12430152c815SPeter Collingbourne applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); 12440152c815SPeter Collingbourne break; 124559675ba0SPeter Collingbourne case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { 124659675ba0SPeter Collingbourne Constant *UniqueMemberAddr = 124759675ba0SPeter Collingbourne importGlobal(Slot, CSByConstantArg.first, "unique_member"); 124859675ba0SPeter Collingbourne applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, 124959675ba0SPeter Collingbourne UniqueMemberAddr); 125059675ba0SPeter Collingbourne break; 125159675ba0SPeter Collingbourne } 125214dcf02fSPeter Collingbourne case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { 125314dcf02fSPeter Collingbourne Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32); 125414dcf02fSPeter Collingbourne Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty); 125514dcf02fSPeter Collingbourne Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8); 125614dcf02fSPeter Collingbourne Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty); 125714dcf02fSPeter Collingbourne applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); 125814dcf02fSPeter Collingbourne } 12590152c815SPeter Collingbourne default: 12600152c815SPeter Collingbourne break; 12610152c815SPeter Collingbourne } 12620152c815SPeter Collingbourne } 12636d284fabSPeter Collingbourne } 12646d284fabSPeter Collingbourne 12656d284fabSPeter Collingbourne void DevirtModule::removeRedundantTypeTests() { 12666d284fabSPeter Collingbourne auto True = ConstantInt::getTrue(M.getContext()); 12676d284fabSPeter Collingbourne for (auto &&U : NumUnsafeUsesForTypeTest) { 12686d284fabSPeter Collingbourne if (U.second == 0) { 12696d284fabSPeter Collingbourne U.first->replaceAllUsesWith(True); 12706d284fabSPeter Collingbourne U.first->eraseFromParent(); 
12716d284fabSPeter Collingbourne } 12726d284fabSPeter Collingbourne } 12736d284fabSPeter Collingbourne } 12746d284fabSPeter Collingbourne 12750312f614SPeter Collingbourne bool DevirtModule::run() { 12760312f614SPeter Collingbourne Function *TypeTestFunc = 12770312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_test)); 12780312f614SPeter Collingbourne Function *TypeCheckedLoadFunc = 12790312f614SPeter Collingbourne M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); 12800312f614SPeter Collingbourne Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); 12810312f614SPeter Collingbourne 1282b406baaeSPeter Collingbourne // Normally if there are no users of the devirtualization intrinsics in the 1283b406baaeSPeter Collingbourne // module, this pass has nothing to do. But if we are exporting, we also need 1284b406baaeSPeter Collingbourne // to handle any users that appear only in the function summaries. 1285f7691d8bSPeter Collingbourne if (!ExportSummary && 1286b406baaeSPeter Collingbourne (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || 12870312f614SPeter Collingbourne AssumeFunc->use_empty()) && 12880312f614SPeter Collingbourne (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) 12890312f614SPeter Collingbourne return false; 12900312f614SPeter Collingbourne 12910312f614SPeter Collingbourne if (TypeTestFunc && AssumeFunc) 12920312f614SPeter Collingbourne scanTypeTestUsers(TypeTestFunc, AssumeFunc); 12930312f614SPeter Collingbourne 12940312f614SPeter Collingbourne if (TypeCheckedLoadFunc) 12950312f614SPeter Collingbourne scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); 1296df49d1bbSPeter Collingbourne 1297f7691d8bSPeter Collingbourne if (ImportSummary) { 12986d284fabSPeter Collingbourne for (auto &S : CallSlots) 12996d284fabSPeter Collingbourne importResolution(S.first, S.second); 13006d284fabSPeter Collingbourne 13016d284fabSPeter Collingbourne removeRedundantTypeTests(); 13026d284fabSPeter 
Collingbourne 13036d284fabSPeter Collingbourne // The rest of the code is only necessary when exporting or during regular 13046d284fabSPeter Collingbourne // LTO, so we are done. 13056d284fabSPeter Collingbourne return true; 13066d284fabSPeter Collingbourne } 13076d284fabSPeter Collingbourne 13087efd7506SPeter Collingbourne // Rebuild type metadata into a map for easy lookup. 1309df49d1bbSPeter Collingbourne std::vector<VTableBits> Bits; 13107efd7506SPeter Collingbourne DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; 13117efd7506SPeter Collingbourne buildTypeIdentifierMap(Bits, TypeIdMap); 13127efd7506SPeter Collingbourne if (TypeIdMap.empty()) 1313df49d1bbSPeter Collingbourne return true; 1314df49d1bbSPeter Collingbourne 1315b406baaeSPeter Collingbourne // Collect information from summary about which calls to try to devirtualize. 1316f7691d8bSPeter Collingbourne if (ExportSummary) { 1317b406baaeSPeter Collingbourne DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; 1318b406baaeSPeter Collingbourne for (auto &P : TypeIdMap) { 1319b406baaeSPeter Collingbourne if (auto *TypeId = dyn_cast<MDString>(P.first)) 1320b406baaeSPeter Collingbourne MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( 1321b406baaeSPeter Collingbourne TypeId); 1322b406baaeSPeter Collingbourne } 1323b406baaeSPeter Collingbourne 1324f7691d8bSPeter Collingbourne for (auto &P : *ExportSummary) { 1325b406baaeSPeter Collingbourne for (auto &S : P.second) { 1326b406baaeSPeter Collingbourne auto *FS = dyn_cast<FunctionSummary>(S.get()); 1327b406baaeSPeter Collingbourne if (!FS) 1328b406baaeSPeter Collingbourne continue; 1329b406baaeSPeter Collingbourne // FIXME: Only add live functions. 
13305d8aea10SGeorge Rimar for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { 13315d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VF.GUID]) { 13322325bb34SPeter Collingbourne CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = 13332325bb34SPeter Collingbourne true; 13345d8aea10SGeorge Rimar } 13355d8aea10SGeorge Rimar } 13365d8aea10SGeorge Rimar for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { 13375d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VF.GUID]) { 1338b406baaeSPeter Collingbourne CallSlots[{MD, VF.Offset}] 1339b406baaeSPeter Collingbourne .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); 13405d8aea10SGeorge Rimar } 13415d8aea10SGeorge Rimar } 1342b406baaeSPeter Collingbourne for (const FunctionSummary::ConstVCall &VC : 13435d8aea10SGeorge Rimar FS->type_test_assume_const_vcalls()) { 13445d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 13452325bb34SPeter Collingbourne CallSlots[{MD, VC.VFunc.Offset}] 13465d8aea10SGeorge Rimar .ConstCSInfo[VC.Args] 13475d8aea10SGeorge Rimar .SummaryHasTypeTestAssumeUsers = true; 13485d8aea10SGeorge Rimar } 13495d8aea10SGeorge Rimar } 13502325bb34SPeter Collingbourne for (const FunctionSummary::ConstVCall &VC : 13515d8aea10SGeorge Rimar FS->type_checked_load_const_vcalls()) { 13525d8aea10SGeorge Rimar for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 1353b406baaeSPeter Collingbourne CallSlots[{MD, VC.VFunc.Offset}] 1354b406baaeSPeter Collingbourne .ConstCSInfo[VC.Args] 1355b406baaeSPeter Collingbourne .SummaryTypeCheckedLoadUsers.push_back(FS); 1356b406baaeSPeter Collingbourne } 1357b406baaeSPeter Collingbourne } 1358b406baaeSPeter Collingbourne } 13595d8aea10SGeorge Rimar } 13605d8aea10SGeorge Rimar } 1361b406baaeSPeter Collingbourne 13627efd7506SPeter Collingbourne // For each (type, offset) pair: 1363df49d1bbSPeter Collingbourne bool DidVirtualConstProp = false; 1364f3403fd2SIvan Krasin std::map<std::string, Function*> 
DevirtTargets; 1365df49d1bbSPeter Collingbourne for (auto &S : CallSlots) { 13667efd7506SPeter Collingbourne // Search each of the members of the type identifier for the virtual 13677efd7506SPeter Collingbourne // function implementation at offset S.first.ByteOffset, and add to 13687efd7506SPeter Collingbourne // TargetsForSlot. 1369df49d1bbSPeter Collingbourne std::vector<VirtualCallTarget> TargetsForSlot; 1370b406baaeSPeter Collingbourne if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], 1371b406baaeSPeter Collingbourne S.first.ByteOffset)) { 13722325bb34SPeter Collingbourne WholeProgramDevirtResolution *Res = nullptr; 1373f7691d8bSPeter Collingbourne if (ExportSummary && isa<MDString>(S.first.TypeID)) 1374f7691d8bSPeter Collingbourne Res = &ExportSummary 13759a3f9797SPeter Collingbourne ->getOrInsertTypeIdSummary( 13769a3f9797SPeter Collingbourne cast<MDString>(S.first.TypeID)->getString()) 13772325bb34SPeter Collingbourne .WPDRes[S.first.ByteOffset]; 13782325bb34SPeter Collingbourne 13792325bb34SPeter Collingbourne if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && 138059675ba0SPeter Collingbourne tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first)) 1381f3403fd2SIvan Krasin DidVirtualConstProp = true; 1382f3403fd2SIvan Krasin 1383f3403fd2SIvan Krasin // Collect functions devirtualized at least for one call site for stats. 
1384f3403fd2SIvan Krasin if (RemarksEnabled) 1385f3403fd2SIvan Krasin for (const auto &T : TargetsForSlot) 1386f3403fd2SIvan Krasin if (T.WasDevirt) 1387f3403fd2SIvan Krasin DevirtTargets[T.Fn->getName()] = T.Fn; 1388b05e06e4SIvan Krasin } 1389df49d1bbSPeter Collingbourne 1390b406baaeSPeter Collingbourne // CFI-specific: if we are exporting and any llvm.type.checked.load 1391b406baaeSPeter Collingbourne // intrinsics were *not* devirtualized, we need to add the resulting 1392b406baaeSPeter Collingbourne // llvm.type.test intrinsics to the function summaries so that the 1393b406baaeSPeter Collingbourne // LowerTypeTests pass will export them. 1394f7691d8bSPeter Collingbourne if (ExportSummary && isa<MDString>(S.first.TypeID)) { 1395b406baaeSPeter Collingbourne auto GUID = 1396b406baaeSPeter Collingbourne GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); 1397b406baaeSPeter Collingbourne for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) 1398b406baaeSPeter Collingbourne FS->addTypeTest(GUID); 1399b406baaeSPeter Collingbourne for (auto &CCS : S.second.ConstCSInfo) 1400b406baaeSPeter Collingbourne for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) 1401b406baaeSPeter Collingbourne FS->addTypeTest(GUID); 1402b406baaeSPeter Collingbourne } 1403b406baaeSPeter Collingbourne } 1404b406baaeSPeter Collingbourne 1405f3403fd2SIvan Krasin if (RemarksEnabled) { 1406f3403fd2SIvan Krasin // Generate remarks for each devirtualized function. 
1407f3403fd2SIvan Krasin for (const auto &DT : DevirtTargets) { 1408f3403fd2SIvan Krasin Function *F = DT.second; 1409f3403fd2SIvan Krasin DISubprogram *SP = F->getSubprogram(); 14107bc978b5SJustin Bogner emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP, 1411f3403fd2SIvan Krasin Twine("devirtualized ") + F->getName()); 1412b05e06e4SIvan Krasin } 1413df49d1bbSPeter Collingbourne } 1414df49d1bbSPeter Collingbourne 14156d284fabSPeter Collingbourne removeRedundantTypeTests(); 14160312f614SPeter Collingbourne 1417df49d1bbSPeter Collingbourne // Rebuild each global we touched as part of virtual constant propagation to 1418df49d1bbSPeter Collingbourne // include the before and after bytes. 1419df49d1bbSPeter Collingbourne if (DidVirtualConstProp) 1420df49d1bbSPeter Collingbourne for (VTableBits &B : Bits) 1421df49d1bbSPeter Collingbourne rebuildGlobal(B); 1422df49d1bbSPeter Collingbourne 1423df49d1bbSPeter Collingbourne return true; 1424df49d1bbSPeter Collingbourne } 1425