//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via !type metadata) that the list of callees is fixed. This
// includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class
//   and each list of constant arguments: evaluate the function, store the
//   return value alongside the virtual table, and rewrite each virtual call
//   as a load from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value,
//   replace each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
// This pass is intended to be used during the regular and thin LTO pipelines:
//
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
// eligible for virtual call optimization (i.e. calls that use either of the
// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
//
// During hybrid Regular/ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over a single merged module
//   that contains all vtables with !type metadata that participate in the
//   link. The pass computes a resolution for each virtual call and stores it
//   in the type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
//
// During ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over the index which
//   contains a summary of all vtables with !type metadata that participate in
//   the link. It computes a resolution for each virtual call and stores it in
//   the type identifier summary. Only single implementation devirtualization
//   is supported.
// - Import phase: (same as with hybrid case above).
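//
// As a concrete (illustrative, not from the test suite) example:
//
//   struct A { virtual int f() { return 1; } };
//   struct B : A { int f() override { return 1; } };
//
// If A and B are the only classes in the link, a call a->f() has two possible
// callees that both return 1, so the uniform return value optimization can
// replace the call with the constant 1. If B::f instead returned 2, virtual
// constant propagation could evaluate both callees, store each result next to
// the corresponding vtable, and turn the call into a load. If B did not
// override f at all, A::f would be the single implementation and each call
// could simply be made direct.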
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include <algorithm>
#include <cstddef>
#include <map>
#include <set>
#include <string>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc(
        "Read summary from given bitcode or YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given bitcode or YAML file after running pass. "
" 131 "Output file format is deduced from extension: *.bc means writing " 132 "bitcode, otherwise YAML"), 133 cl::Hidden); 134 135 static cl::opt<unsigned> 136 ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, 137 cl::init(10), cl::ZeroOrMore, 138 cl::desc("Maximum number of call targets per " 139 "call site to enable branch funnels")); 140 141 static cl::opt<bool> 142 PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden, 143 cl::init(false), cl::ZeroOrMore, 144 cl::desc("Print index-based devirtualization messages")); 145 146 /// Provide a way to force enable whole program visibility in tests. 147 /// This is needed to support legacy tests that don't contain 148 /// !vcall_visibility metadata (the mere presense of type tests 149 /// previously implied hidden visibility). 150 static cl::opt<bool> 151 WholeProgramVisibility("whole-program-visibility", cl::init(false), 152 cl::Hidden, cl::ZeroOrMore, 153 cl::desc("Enable whole program visibility")); 154 155 /// Provide a way to force disable whole program for debugging or workarounds, 156 /// when enabled via the linker. 157 static cl::opt<bool> DisableWholeProgramVisibility( 158 "disable-whole-program-visibility", cl::init(false), cl::Hidden, 159 cl::ZeroOrMore, 160 cl::desc("Disable whole program visibility (overrides enabling options)")); 161 162 /// Provide way to prevent certain function from being devirtualized 163 static cl::list<std::string> 164 SkipFunctionNames("wholeprogramdevirt-skip", 165 cl::desc("Prevent function(s) from being devirtualized"), 166 cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated); 167 168 /// Mechanism to add runtime checking of devirtualization decisions, optionally 169 /// trapping or falling back to indirect call on any that are not correct. 170 /// Trapping mode is useful for debugging undefined behavior leading to failures 171 /// with WPD. Fallback mode is useful for ensuring safety when whole program 172 /// visibility may be compromised. 173 enum WPDCheckMode { None, Trap, Fallback }; 174 static cl::opt<WPDCheckMode> DevirtCheckMode( 175 "wholeprogramdevirt-check", cl::Hidden, cl::ZeroOrMore, 176 cl::desc("Type of checking for incorrect devirtualizations"), 177 cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"), 178 clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"), 179 clEnumValN(WPDCheckMode::Fallback, "fallback", 180 "Fallback to indirect when incorrect"))); 181 182 namespace { 183 struct PatternList { 184 std::vector<GlobPattern> Patterns; 185 template <class T> void init(const T &StringList) { 186 for (const auto &S : StringList) 187 if (Expected<GlobPattern> Pat = GlobPattern::create(S)) 188 Patterns.push_back(std::move(*Pat)); 189 } 190 bool match(StringRef S) { 191 for (const GlobPattern &P : Patterns) 192 if (P.match(S)) 193 return true; 194 return false; 195 } 196 }; 197 } // namespace 198 199 // Find the minimum offset that we may store a value of size Size bits at. If 200 // IsAfter is set, look for an offset before the object, otherwise look for an 201 // offset after the object. 202 uint64_t 203 wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, 204 bool IsAfter, uint64_t Size) { 205 // Find a minimum offset taking into account only vtable sizes. 
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of
  // the used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte.
  // Effectively, this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
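    // A small worked example (illustrative): with Size == 16 (two bytes
    // needed) and Used == {[1,0,0], [0,1,0]}, I == 0 collides with the used
    // byte in the first region and I == 1 collides with the used byte in the
    // second, so the search settles on I == 2 (bytes past the end of a used
    // region are free) and returns bit offset (MinByte + 2) * 8.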
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
      WasDevirt(false) {}

namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

template <> struct DenseMapInfo<VTableSlotSummary> {
  static VTableSlotSummary getEmptyKey() {
    return {DenseMapInfo<StringRef>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlotSummary getTombstoneKey() {
    return {DenseMapInfo<StringRef>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlotSummary &I) {
    return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlotSummary &LHS,
                      const VTableSlotSummary &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // end namespace llvm

namespace {

// Returns true if the function must be unreachable based on ValueInfo.
371 // 372 // In particular, identifies a function as unreachable in the following 373 // conditions 374 // 1) All summaries are live. 375 // 2) All function summaries indicate it's unreachable 376 bool mustBeUnreachableFunction(ValueInfo TheFnVI) { 377 if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { 378 // Returns false if ValueInfo is absent, or the summary list is empty 379 // (e.g., function declarations). 380 return false; 381 } 382 383 for (auto &Summary : TheFnVI.getSummaryList()) { 384 // Conservatively returns false if any non-live functions are seen. 385 // In general either all summaries should be live or all should be dead. 386 if (!Summary->isLive()) 387 return false; 388 if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) { 389 if (!FS->fflags().MustBeUnreachable) 390 return false; 391 } 392 // Do nothing if a non-function has the same GUID (which is rare). 393 // This is correct since non-function summaries are not relevant. 394 } 395 // All function summaries are live and all of them agree that the function is 396 // unreachble. 397 return true; 398 } 399 400 // A virtual call site. VTable is the loaded virtual table pointer, and CS is 401 // the indirect virtual call. 402 struct VirtualCallSite { 403 Value *VTable = nullptr; 404 CallBase &CB; 405 406 // If non-null, this field points to the associated unsafe use count stored in 407 // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description 408 // of that field for details. 409 unsigned *NumUnsafeUses = nullptr; 410 411 void 412 emitRemark(const StringRef OptName, const StringRef TargetName, 413 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { 414 Function *F = CB.getCaller(); 415 DebugLoc DLoc = CB.getDebugLoc(); 416 BasicBlock *Block = CB.getParent(); 417 418 using namespace ore; 419 OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) 420 << NV("Optimization", OptName) 421 << ": devirtualized a call to " 422 << NV("FunctionName", TargetName)); 423 } 424 425 void replaceAndErase( 426 const StringRef OptName, const StringRef TargetName, bool RemarksEnabled, 427 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 428 Value *New) { 429 if (RemarksEnabled) 430 emitRemark(OptName, TargetName, OREGetter); 431 CB.replaceAllUsesWith(New); 432 if (auto *II = dyn_cast<InvokeInst>(&CB)) { 433 BranchInst::Create(II->getNormalDest(), &CB); 434 II->getUnwindDest()->removePredecessor(II->getParent()); 435 } 436 CB.eraseFromParent(); 437 // This use is no longer unsafe. 438 if (NumUnsafeUses) 439 --*NumUnsafeUses; 440 } 441 }; 442 443 // Call site information collected for a specific VTableSlot and possibly a list 444 // of constant integer arguments. The grouping by arguments is handled by the 445 // VTableSlotInfo class. 446 struct CallSiteInfo { 447 /// The set of call sites for this slot. Used during regular LTO and the 448 /// import phase of ThinLTO (as well as the export phase of ThinLTO for any 449 /// call sites that appear in the merged module itself); in each of these 450 /// cases we are directly operating on the call sites at the IR level. 451 std::vector<VirtualCallSite> CallSites; 452 453 /// Whether all call sites represented by this CallSiteInfo, including those 454 /// in summaries, have been devirtualized. This starts off as true because a 455 /// default constructed CallSiteInfo represents no call sites. 
  bool AllCallSitesDevirted = true;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
  /// this slot.
  bool SummaryHasTypeTestAssumeUsers = false;

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful. If devirtualization was successful,
  /// the pass will clear this vector by calling markDevirt(). If at the end of
  /// the pass the vector is non-empty, we will need to add a use of
  /// llvm.type.test to each of the function summaries in the vector.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;

  bool isExported() const {
    return SummaryHasTypeTestAssumeUsers ||
           !SummaryTypeCheckedLoadUsers.empty();
  }

  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
    SummaryTypeCheckedLoadUsers.push_back(FS);
    AllCallSitesDevirted = false;
  }

  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
    SummaryTypeTestAssumeUsers.push_back(FS);
    SummaryHasTypeTestAssumeUsers = true;
    AllCallSitesDevirted = false;
  }

  void markDevirt() {
    AllCallSitesDevirted = true;

    // As explained in the comment for SummaryTypeCheckedLoadUsers.
    SummaryTypeCheckedLoadUsers.clear();
  }
};

// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);

private:
  CallSiteInfo &findCallSiteInfo(CallBase &CB);
};

CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
  std::vector<uint64_t> Args;
  auto *CBType = dyn_cast<IntegerType>(CB.getType());
  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
    return CSInfo;
  for (auto &&Arg : drop_begin(CB.args())) {
    auto *CI = dyn_cast<ConstantInt>(Arg);
    if (!CI || CI->getBitWidth() > 64)
      return CSInfo;
    Args.push_back(CI->getZExtValue());
  }
  return ConstCSInfo[Args];
}

void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
                                 unsigned *NumUnsafeUses) {
  auto &CSI = findCallSiteInfo(CB);
  CSI.AllCallSitesDevirted = false;
  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}

struct DevirtModule {
  Module &M;
  function_ref<AAResults &(Function &)> AARGetter;
  function_ref<DominatorTree &(Function &)> LookupDomTree;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;

  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;
  /// Sizeless array type, used for imported vtables. This provides a signal
  /// to analyzers that these imports may alias, as they do for example
  /// when multiple unique return values occur in the same vtable.
  ArrayType *Int8Arr0Ty;

  bool RemarksEnabled;
  function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;

  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // Calls that have already been optimized. We may add a call to multiple
  // VTableSlotInfos if vtable loads are coalesced and need to make sure not to
  // optimize a call more than once.
  SmallPtrSet<CallBase *, 8> OptimizedCalls;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call
  // with true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
  PatternList FunctionsToSkip;

  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
               function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
               function_ref<DominatorTree &(Function &)> LookupDomTree,
               ModuleSummaryIndex *ExportSummary,
               const ModuleSummaryIndex *ImportSummary)
      : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
        Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())),
        Int64Ty(Type::getInt64Ty(M.getContext())),
        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
        Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
        RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
    assert(!(ExportSummary && ImportSummary));
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool areRemarksEnabled();

  void
  scanTypeTestUsers(Function *TypeTestFunc,
                    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);

  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset,
                            ModuleSummaryIndex *ExportSummary);

  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
                             bool &IsExported);
  bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res);

  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
                              bool &IsExported);
  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res, VTableSlot Slot);

  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo,
                           WholeProgramDevirtResolution::ByArg *Res);
  // Returns the global symbol name that is used to export information about
  // the given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  bool shouldExportConstantsAsAbsoluteSymbols();

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);
  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                      uint32_t Const, uint32_t &Storage);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name);
  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                           StringRef Name, IntegerType *IntTy,
                           uint32_t Storage);

  Constant *getMemberAddr(const TypeMemberInfo *M);

  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  bool run();

  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
  //
  // Caller guarantees that `ExportSummary` is not nullptr.
  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
                                           ModuleSummaryIndex *ExportSummary);

  // Returns true if the function definition must be unreachable.
  //
  // Note if this helper function returns true, `F` is guaranteed
  // to be unreachable; if it returns false, `F` might still
  // be unreachable but not covered by this helper function.
  //
  // Implementation-wise, if the function definition is present, the IR is
  // analyzed; if not, the function flags from ExportSummary are used as a
  // fallback.
  static bool mustBeUnreachableFunction(Function *const F,
                                        ModuleSummaryIndex *ExportSummary);

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool
  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
                function_ref<DominatorTree &(Function &)> LookupDomTree);
};

struct DevirtIndex {
  ModuleSummaryIndex &ExportSummary;
  // The set in which to record GUIDs exported from their module by
  // devirtualization, used by the client to ensure they are not internalized.
  std::set<GlobalValue::GUID> &ExportedGUIDs;
  // A map in which to record the information necessary to locate the WPD
  // resolution for local targets in case they are exported by cross module
  // importing.
  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;

  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;

  PatternList FunctionsToSkip;

  DevirtIndex(
      ModuleSummaryIndex &ExportSummary,
      std::set<GlobalValue::GUID> &ExportedGUIDs,
      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
        LocalWPDTargetsMap(LocalWPDTargetsMap) {
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
                                 const TypeIdCompatibleVtableInfo TIdInfo,
                                 uint64_t ByteOffset);

  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                           VTableSlotSummary &SlotSummary,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res,
                           std::set<ValueInfo> &DevirtTargets);

  void run();
};

struct WholeProgramDevirt : public ModulePass {
  static char ID;

  bool UseCommandLine = false;

  ModuleSummaryIndex *ExportSummary = nullptr;
  const ModuleSummaryIndex *ImportSummary = nullptr;

  WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
                     const ModuleSummaryIndex *ImportSummary)
      : ModulePass(ID), ExportSummary(ExportSummary),
        ImportSummary(ImportSummary) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override {
    if (skipModule(M))
      return false;

    // In the new pass manager, we can request the optimization
    // remark emitter pass on a per-function-basis, which the
    // OREGetter will do for us.
    // In the old pass manager, this is harder, so we just build
    // an optimization remark emitter on the fly, when we need it.
    std::unique_ptr<OptimizationRemarkEmitter> ORE;
    auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
      ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    auto LookupDomTree = [this](Function &F) -> DominatorTree & {
      return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
    };

    if (UseCommandLine)
      return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter,
                                         LookupDomTree);

    return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree,
                        ExportSummary, ImportSummary)
        .run();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
                      "Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                    "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;

ModulePass *
llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
                                   const ModuleSummaryIndex *ImportSummary) {
  return new WholeProgramDevirt(ExportSummary, ImportSummary);
}

PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
                                              ModuleAnalysisManager &AM) {
  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto AARGetter = [&](Function &F) -> AAResults & {
    return FAM.getResult<AAManager>(F);
  };
  auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };
  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
    return FAM.getResult<DominatorTreeAnalysis>(F);
  };
  if (UseCommandLine) {
    if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
      return PreservedAnalyses::all();
    return PreservedAnalyses::none();
  }
  if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
                    ImportSummary)
           .run())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

// Enable whole program visibility if enabled by client (e.g. linker) or
// internal option, and not force disabled.
static bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
  return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
         !DisableWholeProgramVisibility;
}

namespace llvm {

/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definitions to linkage unit visibility in
/// Module IR (for regular or hybrid LTO).
void updateVCallVisibilityInModule(
    Module &M, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (GlobalVariable &GV : M.globals())
    // Add linkage unit visibility to any variable with type metadata, which
    // are the vtable definitions. We won't have an existing vcall_visibility
    // metadata on vtable definitions with public visibility.
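    // Illustrative IR shape (hypothetical vtable name): a global such as
    //   @_ZTV1A = constant { ... } { ... }, !type !0
    // gains
    //   !vcall_visibility !{i64 1}
    // where 1 corresponds to GlobalObject::VCallVisibilityLinkageUnit.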
    if (GV.hasMetadata(LLVMContext::MD_type) &&
        GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
        // Don't upgrade the visibility for symbols exported to the dynamic
        // linker, as we have no information on their eventual use.
        !DynamicExportSymbols.count(GV.getGUID()))
      GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
}

/// If whole program visibility asserted, then upgrade all public vcall
/// visibility metadata on vtable definition summaries to linkage unit
/// visibility in Module summary index (for ThinLTO).
void updateVCallVisibilityInIndex(
    ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (auto &P : Index) {
    // Don't upgrade the visibility for symbols exported to the dynamic
    // linker, as we have no information on their eventual use.
    if (DynamicExportSymbols.count(P.first))
      continue;
    for (auto &S : P.second.SummaryList) {
      auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
      if (!GVar ||
          GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
        continue;
      GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
    }
  }
}

void runWholeProgramDevirtOnIndex(
    ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
}

void updateIndexWPDForExports(
    ModuleSummaryIndex &Summary,
    function_ref<bool(StringRef, ValueInfo)> isExported,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  for (auto &T : LocalWPDTargetsMap) {
    auto &VI = T.first;
    // This was enforced earlier during trySingleImplDevirt.
    assert(VI.getSummaryList().size() == 1 &&
           "Devirt of local target has more than one copy");
    auto &S = VI.getSummaryList()[0];
    if (!isExported(S->modulePath(), VI))
      continue;

    // It's been exported by a cross module import.
    for (auto &SlotSummary : T.second) {
      auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
      assert(TIdSum);
      auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
      assert(WPDRes != TIdSum->WPDRes.end());
      WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          WPDRes->second.SingleImplName,
          Summary.getModuleHash(S->modulePath()));
    }
  }
}

} // end namespace llvm

static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
  // Check that the summary index contains a regular LTO module when performing
  // export, to prevent occasional use of an index from a pure ThinLTO
  // compilation (-fno-split-lto-module). This kind of summary index is passed
  // to DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
  const auto &ModPaths = Summary->modulePaths();
  if (ClSummaryAction != PassSummaryAction::Import &&
      ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
          ModPaths.end())
    return createStringError(
        errc::invalid_argument,
        "combined summary should contain Regular LTO module");
  return ErrorSuccess();
}

bool DevirtModule::runForTesting(
    Module &M, function_ref<AAResults &(Function &)> AARGetter,
    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
    function_ref<DominatorTree &(Function &)> LookupDomTree) {
  std::unique_ptr<ModuleSummaryIndex> Summary =
      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);

  // Handle the command-line summary arguments. This code is for testing
  // purposes only, so we handle errors directly.
  if (!ClReadSummary.empty()) {
    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
                          ": ");
    auto ReadSummaryFile =
        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
            getModuleSummaryIndex(*ReadSummaryFile)) {
      Summary = std::move(*SummaryOrErr);
      ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
    } else {
      // Try YAML if we've failed with bitcode.
      consumeError(SummaryOrErr.takeError());
      yaml::Input In(ReadSummaryFile->getBuffer());
      In >> *Summary;
      ExitOnErr(errorCodeToError(In.error()));
    }
  }

  bool Changed =
      DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
                   ClSummaryAction == PassSummaryAction::Export ? Summary.get()
                                                                : nullptr,
                   ClSummaryAction == PassSummaryAction::Import ? Summary.get()
                                                                : nullptr)
          .run();

  if (!ClWriteSummary.empty()) {
    ExitOnError ExitOnErr(
        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
    std::error_code EC;
    if (StringRef(ClWriteSummary).endswith(".bc")) {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
      ExitOnErr(errorCodeToError(EC));
      writeIndexToFile(*Summary, OS);
    } else {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
      ExitOnErr(errorCodeToError(EC));
      yaml::Output Out(OS);
      Out << *Summary;
    }
  }

  return Changed;
}

void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(M.getGlobalList().size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(LLVMContext::MD_type, Types);
    if (GV.isDeclaration() || Types.empty())
      continue;

    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto TypeID = Type->getOperand(1).get();

      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert({BitsPtr, Offset});
    }
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
    ModuleSummaryIndex *ExportSummary) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    if (!TM.Bits->GV->isConstant())
      return false;

    // We cannot perform whole program devirtualization analysis on a vtable
    // with public LTO visibility.
    if (TM.Bits->GV->getVCallVisibility() ==
        GlobalObject::VCallVisibilityPublic)
      return false;

    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
                                       TM.Offset + ByteOffset, M);
    if (!Ptr)
      return false;

    auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
    if (!Fn)
      return false;

    if (FunctionsToSkip.match(Fn->getName()))
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    // We can disregard unreachable functions as possible call targets, as
    // unreachable functions shouldn't be called.
    if (mustBeUnreachableFunction(Fn, ExportSummary))
      continue;

    TargetsForSlot.push_back({Fn, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtIndex::tryFindVirtualCallTargets(
    std::vector<ValueInfo> &TargetsForSlot,
    const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
  for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
    // Find a representative copy of the vtable initializer.
    // We can have multiple available_externally, linkonce_odr and weak_odr
    // vtable initializers. We can also have multiple external vtable
    // initializers in the case of comdats, which we cannot check here.
    // The linker should give an error in this case.
    //
    // Also, handle the case of same-named local vtables with the same path
    // and therefore the same GUID. This can happen if there isn't enough
    // distinguishing path when compiling the source file. In that case we
    // conservatively return false early.
    const GlobalVarSummary *VS = nullptr;
    bool LocalFound = false;
    for (auto &S : P.VTableVI.getSummaryList()) {
      if (GlobalValue::isLocalLinkage(S->linkage())) {
        if (LocalFound)
          return false;
        LocalFound = true;
      }
      auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
      if (!CurVS->vTableFuncs().empty() ||
          // Previously clang did not attach the necessary type metadata to
          // available_externally vtables, in which case there would not
          // be any vtable functions listed in the summary and we need
          // to treat this case conservatively (in case the bitcode is old).
          // However, we will also not have any vtable functions in the
          // case of a pure virtual base class. In that case we do want
          // to set VS to avoid treating it conservatively.
          !GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
        VS = CurVS;
        // We cannot perform whole program devirtualization analysis on a
        // vtable with public LTO visibility.
        if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
          return false;
      }
    }
    // There will be no VS if all copies are available_externally having no
    // type metadata. In that case we can't safely perform WPD.
    if (!VS)
      return false;
    if (!VS->isLive())
      continue;
    for (auto VTP : VS->vTableFuncs()) {
      if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
        continue;

      if (mustBeUnreachableFunction(VTP.FuncVI))
        continue;

      TargetsForSlot.push_back(VTP.FuncVI);
    }
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName()))
    return;
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      if (!OptimizedCalls.insert(&VCallSite.CB).second)
        continue;

      if (RemarksEnabled)
        VCallSite.emitRemark("single-impl",
                             TheFn->stripPointerCasts()->getName(), OREGetter);
      auto &CB = VCallSite.CB;
      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
      IRBuilder<> Builder(&CB);
      Value *Callee =
          Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());

      // If trap checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // perform a debug trap.
      if (DevirtCheckMode == WPDCheckMode::Trap) {
        auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
        Instruction *ThenTerm =
            SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false);
        Builder.SetInsertPoint(ThenTerm);
        Function *TrapFn = Intrinsic::getDeclaration(&M, Intrinsic::debugtrap);
        auto *CallTrap = Builder.CreateCall(TrapFn);
        CallTrap->setDebugLoc(CB.getDebugLoc());
      }

      // If fallback checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // fall back to indirect call.
      if (DevirtCheckMode == WPDCheckMode::Fallback) {
        MDNode *Weights =
            MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
        // Version the indirect call site. If the called value is equal to the
        // given callee, 'NewInst' will be executed, otherwise the original
        // call site will be executed.
        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
        NewInst.setCalledOperand(Callee);
        // Since the new call site is direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
        // Additionally, we should remove them from the fallback indirect call,
        // so that we don't attempt to perform indirect call promotion later.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // In either trapping or non-checking mode, devirtualize original call.
      else {
        // Devirtualize unconditionally.
        CB.setCalledOperand(Callee);
        // Since the call site is now direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported())
      IsExported = true;
    CSInfo.markDevirt();
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
  // We can't add calls if we haven't seen a definition.
  if (Callee.getSummaryList().empty())
    return false;

  // Insert calls into the summary index so that the devirtualized targets
  // are eligible for import.
  // FIXME: Annotate type tests with hotness. For now, mark these as hot
  // to better ensure we have the opportunity to inline them.
  bool IsExported = false;
  auto &S = Callee.getSummaryList()[0];
  CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
  auto AddCalls = [&](CallSiteInfo &CSInfo) {
    for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
    for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
  };
  AddCalls(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    AddCalls(P.second);
  return IsExported;
}

bool DevirtModule::trySingleImplDevirt(
    ModuleSummaryIndex *ExportSummary,
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  if (RemarksEnabled)
    TargetsForSlot[0].WasDevirt = true;

  bool IsExported = false;
  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
  if (!IsExported)
    return false;

  // If the only implementation has local linkage, we must promote to external
  // to make it visible to thin LTO objects. We can only get here during the
  // ThinLTO export phase.
  if (TheFn->hasLocalLinkage()) {
    std::string NewName = (TheFn->getName() + ".llvm.merged").str();

    // Since we are renaming the function, any comdats with the same name must
    // also be renamed. This is required when targeting COFF, as the comdat
    // name must match one of the names of the symbols in the comdat.
    if (Comdat *C = TheFn->getComdat()) {
      if (C->getName() == TheFn->getName()) {
        Comdat *NewC = M.getOrInsertComdat(NewName);
        NewC->setSelectionKind(C->getSelectionKind());
        for (GlobalObject &GO : M.global_objects())
          if (GO.getComdat() == C)
            GO.setComdat(NewC);
      }
    }

    TheFn->setLinkage(GlobalValue::ExternalLinkage);
    TheFn->setVisibility(GlobalValue::HiddenVisibility);
    TheFn->setName(NewName);
  }
  if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
    // Any needed promotion of 'TheFn' has already been done during
    // LTO unit split, so we can ignore the return value of AddCalls.
    AddCalls(SlotInfo, TheFnVI);

  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  Res->SingleImplName = std::string(TheFn->getName());

  return true;
}

bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                                      VTableSlotSummary &SlotSummary,
                                      VTableSlotInfo &SlotInfo,
                                      WholeProgramDevirtResolution *Res,
                                      std::set<ValueInfo> &DevirtTargets) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto TheFn = TargetsForSlot[0];
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target)
      return false;

  // Don't devirtualize if we don't have a target definition.
  auto Size = TheFn.getSummaryList().size();
  if (!Size)
    return false;

  // Don't devirtualize the function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn.name()))
    return false;

  // If the summary list contains multiple summaries where at least one is
  // a local, give up, as we won't know which (possibly promoted) name to use.
  for (auto &S : TheFn.getSummaryList())
    if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
      return false;

  // Collect functions devirtualized at least for one call site for stats.
  if (PrintSummaryDevirt)
    DevirtTargets.insert(TheFn);

  auto &S = TheFn.getSummaryList()[0];
  bool IsExported = AddCalls(SlotInfo, TheFn);
  if (IsExported)
    ExportedGUIDs.insert(TheFn.getGUID());

  // Record in summary for use in devirtualization during the ThinLTO import
  // step.
  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  if (GlobalValue::isLocalLinkage(S->linkage())) {
    if (IsExported)
      // If the target is a local function and we are exporting it by
      // devirtualizing a call in another module, we need to record the
      // promoted name.
      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
    else {
      LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
      Res->SingleImplName = std::string(TheFn.name());
    }
  } else
    Res->SingleImplName = std::string(TheFn.name());

  // Name will be empty if this thin link is driven off of a serialized
  // combined index (e.g. llvm-lto). However, WPD is not supported/invoked for
  // the legacy LTO API anyway.
  assert(!Res->SingleImplName.empty());

  return true;
}

void DevirtModule::tryICallBranchFunnel(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  Triple T(M.getTargetTriple());
  if (T.getArch() != Triple::x86_64)
    return;

  if (TargetsForSlot.size() > ClThreshold)
    return;

  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
  if (!HasNonDevirt)
    for (auto &P : SlotInfo.ConstCSInfo)
      if (!P.second.AllCallSitesDevirted) {
        HasNonDevirt = true;
        break;
      }

  if (!HasNonDevirt)
    return;

  FunctionType *FT =
      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
  Function *JT;
  if (isa<MDString>(Slot.TypeID)) {
    JT = Function::Create(FT, Function::ExternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          getGlobalName(Slot, {}, "branch_funnel"), &M);
    JT->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    JT = Function::Create(FT, Function::InternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          "branch_funnel", &M);
  }
  JT->addParamAttr(0, Attribute::Nest);

  std::vector<Value *> JTArgs;
  JTArgs.push_back(JT->arg_begin());
  for (auto &T : TargetsForSlot) {
    JTArgs.push_back(getMemberAddr(T.TM));
    JTArgs.push_back(T.Fn);
  }

  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
  Function *Intr =
      Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});

  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
  CI->setTailCallKind(CallInst::TCK_MustTail);
  ReturnInst::Create(M.getContext(), nullptr, BB);

  bool IsExported = false;
  applyICallBranchFunnel(SlotInfo, JT, IsExported);
  if (IsExported)
    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
}

void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
                                          Constant *JT, bool &IsExported) {
  auto Apply = [&](CallSiteInfo &CSInfo) {
    if (CSInfo.isExported())
      IsExported = true;
    if (CSInfo.AllCallSitesDevirted)
      return;
    for (auto &&VCallSite : CSInfo.CallSites) {
      CallBase &CB = VCallSite.CB;

      // Jump tables are only profitable if the retpoline mitigation is
      // enabled.
      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
      if (!FSAttr.isValid() ||
          !FSAttr.getValueAsString().contains("+retpoline"))
        continue;

      if (RemarksEnabled)
        VCallSite.emitRemark("branch-funnel",
                             JT->stripPointerCasts()->getName(), OREGetter);

      // Pass the address of the vtable in the nest register, which is r10 on
      // x86_64.
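      // Illustratively (hypothetical types), an indirect call
      //   %r = call i32 %fptr(%A* %obj)
      // becomes
      //   %r = call i32 @branch_funnel(i8* nest %vtable, %A* %obj)
      // where @branch_funnel musttail-calls llvm.icall.branch.funnel to
      // compare %vtable against each target's vtable address and dispatch to
      // the matching function.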
      std::vector<Type *> NewArgs;
      NewArgs.push_back(Int8PtrTy);
      append_range(NewArgs, CB.getFunctionType()->params());
      FunctionType *NewFT =
          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
                            CB.getFunctionType()->isVarArg());
      PointerType *NewFTPtr = PointerType::getUnqual(NewFT);

      IRBuilder<> IRB(&CB);
      std::vector<Value *> Args;
      Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
      llvm::append_range(Args, CB.args());

      CallBase *NewCS = nullptr;
      if (isa<CallInst>(CB))
        NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
      else
        NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
                                 cast<InvokeInst>(CB).getNormalDest(),
                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
      NewCS->setCallingConv(CB.getCallingConv());

      AttributeList Attrs = CB.getAttributes();
      std::vector<AttributeSet> NewArgAttrs;
      NewArgAttrs.push_back(AttributeSet::get(
          M.getContext(), ArrayRef<Attribute>{Attribute::get(
                              M.getContext(), Attribute::Nest)}));
      for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
        NewArgAttrs.push_back(Attrs.getParamAttrs(I));
      NewCS->setAttributes(
          AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
                             Attrs.getRetAttrs(), NewArgAttrs));

      CB.replaceAllUsesWith(NewCS);
      CB.eraseFromParent();

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    // Don't mark as devirtualized because there may be callers compiled
    // without retpoline mitigation, which would mean that they are lowered to
    // llvm.type.test and therefore require an llvm.type.test resolution for
    // the type identifier.
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<uint64_t> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
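  // For example (illustrative), a readnone target such as
  //   virtual unsigned kind() { return 3; }
  // evaluated with no extra constant arguments yields RetVal == 3.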
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    for (unsigned I = 0; I != Args.size(); ++I) {
      auto *ArgTy = dyn_cast<IntegerType>(
          Target.Fn->getFunctionType()->getParamType(I + 1));
      if (!ArgTy)
        return false;
      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
    }

    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         uint64_t TheRetVal) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    Call.replaceAndErase(
        "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryUniformRetValOpt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
    WholeProgramDevirtResolution::ByArg *Res) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  if (CSInfo.isExported()) {
    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
    Res->Info = TheRetVal;
  }

  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
  if (RemarksEnabled)
    for (auto &&Target : TargetsForSlot)
      Target.WasDevirt = true;
  return true;
}

std::string DevirtModule::getGlobalName(VTableSlot Slot,
                                        ArrayRef<uint64_t> Args,
                                        StringRef Name) {
  std::string FullName = "__typeid_";
  raw_string_ostream OS(FullName);
  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
  for (uint64_t Arg : Args)
    OS << '_' << Arg;
  OS << '_' << Name;
  return OS.str();
}

bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
  Triple T(M.getTargetTriple());
  return T.isX86() && T.getObjectFormat() == Triple::ELF;
}

void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                StringRef Name, Constant *C) {
  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
                                        getGlobalName(Slot, Args, Name), C, &M);
  GA->setVisibility(GlobalValue::HiddenVisibility);
}

void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                  StringRef Name, uint32_t Const,
                                  uint32_t &Storage) {
  if (shouldExportConstantsAsAbsoluteSymbols()) {
    exportGlobal(
        Slot, Args, Name,
        ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
    return;
  }

  Storage = Const;
}
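
// Naming sketch (hypothetical values): for type id "_ZTS1A", byte offset 0,
// constant argument list {1} and Name "byte", getGlobalName above produces
// "__typeid__ZTS1A_0_1_byte". On x86 ELF targets the constant is exported as
// an absolute symbol with that name; elsewhere it travels in the summary via
// Storage, and importConstant below reverses the encoding.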
Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                     StringRef Name) {
  Constant *C =
      M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
  auto *GV = dyn_cast<GlobalVariable>(C);
  if (GV)
    GV->setVisibility(GlobalValue::HiddenVisibility);
  return C;
}

Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                       StringRef Name, IntegerType *IntTy,
                                       uint32_t Storage) {
  if (!shouldExportConstantsAsAbsoluteSymbols())
    return ConstantInt::get(IntTy, Storage);

  Constant *C = importGlobal(Slot, Args, Name);
  auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
  C = ConstantExpr::getPtrToInt(C, IntTy);

  // We only need to set metadata if the global is newly created, in which
  // case it would not have hidden visibility.
  if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
    return C;

  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
    GV->setMetadata(LLVMContext::MD_absolute_symbol,
                    MDNode::get(M.getContext(), {MinC, MaxC}));
  };
  unsigned AbsWidth = IntTy->getBitWidth();
  if (AbsWidth == IntPtrTy->getBitWidth())
    SetAbsRange(~0ull, ~0ull); // Full set.
  else
    SetAbsRange(0, 1ull << AbsWidth);
  return C;
}

void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                        bool IsOne,
                                        Constant *UniqueMemberAddr) {
  for (auto &&Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    IRBuilder<> B(&Call.CB);
    Value *Cmp =
        B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
                     B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
                         Cmp);
  }
  CSInfo.markDevirt();
}

Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
  Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
  return ConstantExpr::getGetElementPtr(Int8Ty, C,
                                        ConstantInt::get(Int64Ty, M->Offset));
}

bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
    VTableSlot Slot, ArrayRef<uint64_t> Args) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
    if (CSInfo.isExported()) {
      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
      Res->Info = IsOne;

      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
    }

    // Replace each call with the comparison.
    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
                         UniqueMemberAddr);
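
    // The calls are now rewritten to compare the vtable pointer against the
    // unique member's address, e.g. (hypothetical sketch, IsOne == true):
    //   %ret = icmp eq i8* %vtable, <unique vtable address + offset>
    // so the i1 result is 1 exactly when the dynamic type is the single one
    // whose implementation returns 1.
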
    // Update devirtualization statistics for targets.
    if (RemarksEnabled)
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    auto *RetType = cast<IntegerType>(Call.CB.getType());
    IRBuilder<> B(&Call.CB);
    Value *Addr =
        B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
    if (RetType->getBitWidth() == 1) {
      Value *Bits = B.CreateLoad(Int8Ty, Addr);
      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
                           OREGetter, IsBitSet);
    } else {
      Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
      Value *Val = B.CreateLoad(RetType, ValAddr);
      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
                           OREGetter, Val);
    }
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->isDeclaration() ||
        computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
            FMRB_DoesNotAccessMemory ||
        Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
                           ResByArg, Slot, CSByConstantArg.first))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
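    // Worked example of the cost computation below (illustrative numbers
    // only): suppose findLowestOffset returns AllocBefore = 32 bits and two
    // targets already have 0 and 3 bytes allocated before their vtables.
    // Then (32 + 7) / 8 = 4, the targets need max(4 - 0 - 1, 0) = 3 and
    // max(4 - 3 - 1, 0) = 0 new padding bytes, so TotalPaddingBefore = 3;
    // the cheaper of the before/after placements wins.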
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    if (RemarksEnabled)
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
                     ResByArg->Byte);
      exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
                     ResByArg->Bit);
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
    applyVirtualConstProp(CSByConstantArg.second,
                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
  }
  return true;
}
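
// Layout sketch for rebuildGlobal below (hypothetical names): a vtable that
// received before/after bytes is replaced by
//   @anon = private constant { [n x i8], <original vtable type>, [m x i8] }
//   @vtable = alias ..., getelementptr (@anon, 0, 1)
// so the propagated constants sit at fixed offsets from the address point
// while all existing references keep working through the alias.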
void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align the before byte array to the global's minimum alignment so that we
  // don't break any alignment requirements on the global.
  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
      B.GV->getAlign(), B.GV->getValueType());
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());
  NewGV->setAlignment(B.GV->getAlign());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

bool DevirtModule::areRemarksEnabled() {
  const auto &FL = M.getFunctionList();
  for (const Function &Fn : FL) {
    const auto &BBL = Fn.getBasicBlockList();
    if (BBL.empty())
      continue;
    auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
    return DI.isEnabled();
  }
  return false;
}
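
// scanTypeTestUsers below matches the pattern the frontend emits for virtual
// calls. Sketch of the expected IR (hypothetical type id):
//   %vtable = load i8*, i8** %objptr
//   %p = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
//   call void @llvm.assume(i1 %p)
//   ; ...load of a function pointer from %vtable, then an indirect call...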
void DevirtModule::scanTypeTestUsers(
    Function *TypeTestFunc,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

    Metadata *TypeId =
        cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
    // If we found any, add them to CallSlots.
    if (!Assumes.empty()) {
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      for (DevirtCallSite Call : DevirtCalls)
        CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
    }

    auto RemoveTypeTestAssumes = [&]() {
      // We no longer need the assumes or the type test.
      for (auto Assume : Assumes)
        Assume->eraseFromParent();
      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because
      // we may use the vtable argument later.
      if (CI->use_empty())
        CI->eraseFromParent();
    };

    // At this point we could remove all type test assume sequences, as they
    // were originally inserted for WPD. However, we can keep these in the
    // code stream for later analysis (e.g. to help drive more efficient ICP
    // sequences). They will eventually be removed by a second LowerTypeTests
    // invocation that cleans them up. In order to do this correctly, the first
    // LowerTypeTests invocation needs to know that they have "Unknown" type
    // test resolution, so that they aren't treated as Unsat and lowered to
    // False, which will break any uses on the assumes. Below we remove any
    // type test assumes that will not be treated as Unknown by LTT.

    // The type test assumes will be treated by LTT as Unsat if the type id is
    // not used on a global (in which case it has no entry in the TypeIdMap).
    if (!TypeIdMap.count(TypeId))
      RemoveTypeTestAssumes();

    // For ThinLTO importing, we need to remove the type test assumes if this
    // is an MDString type id without a corresponding TypeIdSummary. Any
    // non-MDString type ids are ignored and treated as Unknown by LTT, so
    // their type test assumes can be kept. If the MDString type id is missing
    // a TypeIdSummary (e.g. because there was no use on a vcall, preventing
    // the exporting phase of WPD from analyzing it), then it would be treated
    // as Unsat by LTT and we need to remove its type test assumes here. If
    // not used on a vcall we don't need them for later optimization use in
    // any case.
    else if (ImportSummary && isa<MDString>(TypeId)) {
      const TypeIdSummary *TidSummary =
          ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
      if (!TidSummary)
        RemoveTypeTestAssumes();
      else
        // If one was created it should not be Unsat, because if we reached
        // here the type id was used on a global.
        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
    }
  }
}
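
// scanTypeCheckedLoadUsers below handles the CFI flavor of the check, where
// the load and the test are fused into one intrinsic. Sketch of the expected
// IR (hypothetical type id):
//   %pair = call {i8*, i1} @llvm.type.checked.load(i8* %vtable, i32 0,
//                                                  metadata !"_ZTS1A")
//   %fptr = extractvalue {i8*, i1} %pair, 0
//   %ok = extractvalue {i8*, i1} %pair, 1
//   br i1 %ok, label %cont, label %trap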
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI, DT);

    // Start by generating "pessimistic" code that explicitly loads the
    // function pointer from the vtable and performs the type check. If
    // possible, we will eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = UndefValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the
    // type check, as one of those users may eventually call the pointer.
    // Increment the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}
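
// importResolution below consumes resolutions computed during the export
// phase. In YAML form (as read by -wholeprogramdevirt-read-summary) a single
// implementation resolution looks roughly like this sketch (hypothetical
// names):
//   TypeIdMap:
//     _ZTS1A:
//       WPDRes:
//         0:
//           Kind: SingleImpl
//           SingleImplName: _ZN1A1fEv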
void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
  auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
  if (!TypeId)
    return;
  const TypeIdSummary *TidSummary =
      ImportSummary->getTypeIdSummary(TypeId->getString());
  if (!TidSummary)
    return;
  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
  if (ResI == TidSummary->WPDRes.end())
    return;
  const WholeProgramDevirtResolution &Res = ResI->second;

  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
    assert(!Res.SingleImplName.empty());
    // The type of the function in the declaration is irrelevant because every
    // call site will cast it to the correct type.
    Constant *SingleImpl =
        cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
                                             Type::getVoidTy(M.getContext()))
                           .getCallee());

    // This is the import phase so we should not be exporting anything.
    bool IsExported = false;
    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
    assert(!IsExported);
  }

  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
    auto I = Res.ResByArg.find(CSByConstantArg.first);
    if (I == Res.ResByArg.end())
      continue;
    auto &ResByArg = I->second;
    // FIXME: We should figure out what to do about the "function name"
    // argument to the apply* functions, as the function names are unavailable
    // during the importing phase. For now we just pass the empty string. This
    // does not impact correctness because the function names are just used
    // for remarks.
    switch (ResByArg.TheKind) {
    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
      break;
    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
      Constant *UniqueMemberAddr =
          importGlobal(Slot, CSByConstantArg.first, "unique_member");
      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
                           UniqueMemberAddr);
      break;
    }
    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
      Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
                                      Int32Ty, ResByArg.Byte);
      Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
                                     ResByArg.Bit);
      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
      break;
    }
    default:
      break;
    }
  }

  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
    // The type of the function is irrelevant, because it's bitcast at calls
    // anyhow.
    Constant *JT = cast<Constant>(
        M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
                              Type::getVoidTy(M.getContext()))
            .getCallee());
    bool IsExported = false;
    applyICallBranchFunnel(SlotInfo, JT, IsExported);
    assert(!IsExported);
  }
}

void DevirtModule::removeRedundantTypeTests() {
  auto True = ConstantInt::getTrue(M.getContext());
  for (auto &&U : NumUnsafeUsesForTypeTest) {
    if (U.second == 0) {
      U.first->replaceAllUsesWith(True);
      U.first->eraseFromParent();
    }
  }
}

ValueInfo
DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
                                      ModuleSummaryIndex *ExportSummary) {
  assert((ExportSummary != nullptr) &&
         "Caller guarantees ExportSummary is not nullptr");

  const auto TheFnGUID = TheFn->getGUID();
  const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName());
  // Look up ValueInfo with the GUID in the current linkage.
  ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID);
  // If no entry is found and GUID is different from GUID computed using
  // exported name, look up ValueInfo with the exported name unconditionally.
  // This is a fallback.
  //
  // The reason to have a fallback:
  // 1. LTO could enable global value internalization via
  //    `enable-lto-internalization`.
  // 2. The GUID in ExportedSummary is computed using exported name.
  if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
    TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName);
  }
  return TheFnVI;
}

bool DevirtModule::mustBeUnreachableFunction(
    Function *const F, ModuleSummaryIndex *ExportSummary) {
  // First, learn unreachability by analyzing function IR.
  if (!F->isDeclaration()) {
    // A function must be unreachable if its entry block ends with an
    // 'unreachable'.
    return isa<UnreachableInst>(F->getEntryBlock().getTerminator());
  }
  // Learn unreachability from ExportSummary if ExportSummary is present.
  return ExportSummary &&
         ::mustBeUnreachableFunction(
             DevirtModule::lookUpFunctionValueInfo(F, ExportSummary));
}

bool DevirtModule::run() {
  // If only some of the modules were split, we cannot correctly perform
  // this transformation. We already checked for the presence of type tests
  // with partially split modules during the thin link, and would have emitted
  // an error if any were found, so here we can simply return.
  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
    return false;

  Function *TypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
  Function *TypeCheckedLoadFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also need
  // to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
    return false;

  // Rebuild type metadata into a map for easy lookup.
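  // For reference, vtables carry type metadata of roughly this shape
  // (hypothetical names):
  //   @_ZTV1A = constant { ... }, !type !0
  //   !0 = !{i64 16, !"_ZTS1A"}
  // i.e. (address point offset, type identifier) pairs, which
  // buildTypeIdentifierMap expands into TypeIdMap below.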
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, TypeIdMap);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (ImportSummary) {
    for (auto &S : CallSlots)
      importResolution(S.first, S.second);

    removeRedundantTypeTests();

    // We have lowered or deleted the type intrinsics, so we will no
    // longer have enough information to reason about the liveness of virtual
    // function pointers in GlobalDCE.
    for (GlobalVariable &GV : M.globals())
      GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(P.first))
        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
            TypeId);
    }

    for (auto &P : *ExportSummary) {
      for (auto &S : P.second.SummaryList) {
        auto *FS = dyn_cast<FunctionSummary>(S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeCheckedLoadUser(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, Function *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    WholeProgramDevirtResolution *Res = nullptr;
    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
    if (ExportSummary && isa<MDString>(S.first.TypeID) &&
        TypeMemberInfos.size())
      // For any type id used on a global's type metadata, create the type id
      // summary resolution regardless of whether we can devirtualize, so that
      // lower type tests knows the type id is not Unsat. If it was not used
      // on a global's type metadata, the TypeIdMap entry set will be empty,
      // and we don't want to create an entry (with the default Unknown type
      // resolution), which can prevent detection of the Unsat.
      Res = &ExportSummary
                 ->getOrInsertTypeIdSummary(
                     cast<MDString>(S.first.TypeID)->getString())
                 .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
                                  S.first.ByteOffset, ExportSummary)) {
      if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
        DidVirtualConstProp |=
            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);

        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
      }

      // Collect functions devirtualized at least for one call site for stats.
      if (RemarksEnabled)
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
      auto GUID =
          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
      for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
        FS->addTypeTest(GUID);
      for (auto &CCS : S.second.ConstCSInfo)
        for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
          FS->addTypeTest(GUID);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      Function *F = DT.second;

      using namespace ore;
      OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
                        << "devirtualized "
                        << NV("FunctionName", DT.first));
    }
  }

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  // We have lowered or deleted the type intrinsics, so we will no
  // longer have enough information to reason about the liveness of virtual
  // function pointers in GlobalDCE.
  for (GlobalVariable &GV : M.globals())
    GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

  return true;
}

void DevirtIndex::run() {
  if (ExportSummary.typeIdCompatibleVtableMap().empty())
    return;

  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
  for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
  }

  // Collect information from summary about which calls to try to devirtualize.
  for (auto &P : ExportSummary) {
    for (auto &S : P.second.SummaryList) {
      auto *FS = dyn_cast<FunctionSummary>(S.get());
      if (!FS)
        continue;
      // FIXME: Only add live functions.
      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_test_assume_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_checked_load_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeCheckedLoadUser(FS);
        }
      }
    }
  }

  std::set<ValueInfo> DevirtTargets;
  // For each (type, offset) pair:
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<ValueInfo> TargetsForSlot;
    auto TidSummary =
        ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
    assert(TidSummary);
    // Create the type id summary resolution regardless of whether we can
    // devirtualize, so that lower type tests knows the type id is used on
    // a global and not Unsat.
    WholeProgramDevirtResolution *Res =
        &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
             .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
                                  S.first.ByteOffset)) {
      if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
                               DevirtTargets))
        continue;
    }
  }

  // Optionally have the thin link print a message for each devirtualized
  // function.
  if (PrintSummaryDevirt)
    for (const auto &DT : DevirtTargets)
      errs() << "Devirtualized call to " << DT << "\n";
}