//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via !type metadata) that the list of callees is fixed. This
// includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class and
//   each list of constant arguments: evaluate the function, store the return
//   value alongside the virtual table, and rewrite each virtual call as a load
//   from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value, replace
//   each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
// This pass is intended to be used during the regular and thin LTO pipelines:
//
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
// eligible for virtual call optimization (i.e. calls that use either of the
// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
//
// During hybrid Regular/ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over a single merged module
//   that contains all vtables with !type metadata that participate in the link.
//   The pass computes a resolution for each virtual call and stores it in the
//   type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
//
// During ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over the index which
//   contains a summary of all vtables with !type metadata that participate in
//   the link. It computes a resolution for each virtual call and stores it in
//   the type identifier summary. Only single implementation devirtualization
//   is supported.
// - Import phase: (same as with hybrid case above).
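//
// As an illustrative sketch (the class names below are hypothetical, not taken
// from this pass or its tests), consider:
//
//   struct A { virtual int f(); };
//   struct B : A { int f() override { return 7; } };
//
// If B::f is the only possible callee of a virtual call p->f() guarded by
// llvm.assume(llvm.type.test), single implementation devirtualization replaces
// the indirect call with a direct call to B::f. If instead every possible
// callee were readnone and returned the same constant 7, the uniform return
// value optimization could replace the call with the constant 7 outright.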
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include <algorithm>
#include <cstddef>
#include <map>
#include <set>
#include <string>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets");
STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations");
STATISTIC(NumBranchFunnel, "Number of branch funnels");
STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations");
STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations");
STATISTIC(NumVirtConstProp1Bit,
          "Number of 1 bit virtual constant propagations");
STATISTIC(NumVirtConstProp, "Number of virtual constant propagations");

static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc(
        "Read summary from given bitcode or YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given bitcode or YAML file after running pass. "
             "Output file format is deduced from extension: *.bc means writing "
             "bitcode, otherwise YAML"),
    cl::Hidden);

static cl::opt<unsigned>
    ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
                cl::init(10), cl::ZeroOrMore,
                cl::desc("Maximum number of call targets per "
                         "call site to enable branch funnels"));

static cl::opt<bool>
    PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
                       cl::init(false), cl::ZeroOrMore,
                       cl::desc("Print index-based devirtualization messages"));

/// Provide a way to force enable whole program visibility in tests.
/// This is needed to support legacy tests that don't contain
/// !vcall_visibility metadata (the mere presence of type tests
/// previously implied hidden visibility).
static cl::opt<bool>
    WholeProgramVisibility("whole-program-visibility", cl::init(false),
                           cl::Hidden, cl::ZeroOrMore,
                           cl::desc("Enable whole program visibility"));

/// Provide a way to force disable whole program visibility for debugging or
/// workarounds, when it was enabled via the linker.
static cl::opt<bool> DisableWholeProgramVisibility(
    "disable-whole-program-visibility", cl::init(false), cl::Hidden,
    cl::ZeroOrMore,
    cl::desc("Disable whole program visibility (overrides enabling options)"));

/// Provide a way to prevent certain functions from being devirtualized.
static cl::list<std::string>
    SkipFunctionNames("wholeprogramdevirt-skip",
                      cl::desc("Prevent function(s) from being devirtualized"),
                      cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated);

/// Mechanism to add runtime checking of devirtualization decisions, optionally
/// trapping or falling back to indirect call on any that are not correct.
/// Trapping mode is useful for debugging undefined behavior leading to failures
/// with WPD. Fallback mode is useful for ensuring safety when whole program
/// visibility may be compromised.
enum WPDCheckMode { None, Trap, Fallback };
static cl::opt<WPDCheckMode> DevirtCheckMode(
    "wholeprogramdevirt-check", cl::Hidden, cl::ZeroOrMore,
    cl::desc("Type of checking for incorrect devirtualizations"),
    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
               clEnumValN(WPDCheckMode::Fallback, "fallback",
                          "Fallback to indirect when incorrect")));

namespace {
struct PatternList {
  std::vector<GlobPattern> Patterns;
  template <class T> void init(const T &StringList) {
    for (const auto &S : StringList)
      if (Expected<GlobPattern> Pat = GlobPattern::create(S))
        Patterns.push_back(std::move(*Pat));
  }
  bool match(StringRef S) {
    for (const GlobPattern &P : Patterns)
      if (P.match(S))
        return true;
    return false;
  }
};
} // namespace

// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
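  // A value can only be placed past every target's already-reserved bytes on
  // the chosen side, so take the maximum over all targets.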
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of
  // the used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte.
  // Effectively, this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
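    // Scan candidate byte offsets one at a time: for offset I, every used
    // region must have a fully free (Size / 8)-byte window starting at I; any
    // already-allocated byte in the window sends us to NextI to try the next
    // offset.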
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
      WasDevirt(false) {}

namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

template <> struct DenseMapInfo<VTableSlotSummary> {
  static VTableSlotSummary getEmptyKey() {
    return {DenseMapInfo<StringRef>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlotSummary getTombstoneKey() {
    return {DenseMapInfo<StringRef>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlotSummary &I) {
    return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlotSummary &LHS,
                      const VTableSlotSummary &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // end namespace llvm

namespace {

// Returns true if the function must be unreachable based on ValueInfo.
//
// In particular, identifies a function as unreachable when the following
// conditions hold:
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable.
bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
  if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
    // Returns false if ValueInfo is absent, or the summary list is empty
    // (e.g., function declarations).
    return false;
  }

  for (auto &Summary : TheFnVI.getSummaryList()) {
    // Conservatively returns false if any non-live functions are seen.
    // In general either all summaries should be live or all should be dead.
    if (!Summary->isLive())
      return false;
    if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) {
      if (!FS->fflags().MustBeUnreachable)
        return false;
    }
    // Do nothing if a non-function has the same GUID (which is rare).
    // This is correct since non-function summaries are not relevant.
  }
  // All function summaries are live and all of them agree that the function is
  // unreachable.
  return true;
}

// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable = nullptr;
  CallBase &CB;

  // If non-null, this field points to the associated unsafe use count stored
  // in the DevirtModule::NumUnsafeUsesForTypeTest map below. See the
  // description of that field for details.
  unsigned *NumUnsafeUses = nullptr;

  void
  emitRemark(const StringRef OptName, const StringRef TargetName,
             function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
    Function *F = CB.getCaller();
    DebugLoc DLoc = CB.getDebugLoc();
    BasicBlock *Block = CB.getParent();

    using namespace ore;
    OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
                      << NV("Optimization", OptName)
                      << ": devirtualized a call to "
                      << NV("FunctionName", TargetName));
  }

  void replaceAndErase(
      const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
      function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
      Value *New) {
    if (RemarksEnabled)
      emitRemark(OptName, TargetName, OREGetter);
    CB.replaceAllUsesWith(New);
    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
      BranchInst::Create(II->getNormalDest(), &CB);
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CB.eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};

// Call site information collected for a specific VTableSlot and possibly a
// list of constant integer arguments. The grouping by arguments is handled by
// the VTableSlotInfo class.
struct CallSiteInfo {
  /// The set of call sites for this slot. Used during regular LTO and the
  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
  /// call sites that appear in the merged module itself); in each of these
  /// cases we are directly operating on the call sites at the IR level.
  std::vector<VirtualCallSite> CallSites;

  /// Whether all call sites represented by this CallSiteInfo, including those
  /// in summaries, have been devirtualized. This starts off as true because a
  /// default constructed CallSiteInfo represents no call sites.
  bool AllCallSitesDevirted = true;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
  /// this slot.
  bool SummaryHasTypeTestAssumeUsers = false;

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful. If devirtualization was successful, the
  /// pass will clear this vector by calling markDevirt(). If at the end of the
  /// pass the vector is non-empty, we will need to add a use of llvm.type.test
  /// to each of the function summaries in the vector.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;

  bool isExported() const {
    return SummaryHasTypeTestAssumeUsers ||
           !SummaryTypeCheckedLoadUsers.empty();
  }

  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
    SummaryTypeCheckedLoadUsers.push_back(FS);
    AllCallSitesDevirted = false;
  }

  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
    SummaryTypeTestAssumeUsers.push_back(FS);
    SummaryHasTypeTestAssumeUsers = true;
    AllCallSitesDevirted = false;
  }

  void markDevirt() {
    AllCallSitesDevirted = true;

    // As explained in the comment for SummaryTypeCheckedLoadUsers.
    SummaryTypeCheckedLoadUsers.clear();
  }
};

// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);

private:
  CallSiteInfo &findCallSiteInfo(CallBase &CB);
};

CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
  std::vector<uint64_t> Args;
  auto *CBType = dyn_cast<IntegerType>(CB.getType());
  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
    return CSInfo;
  for (auto &&Arg : drop_begin(CB.args())) {
    auto *CI = dyn_cast<ConstantInt>(Arg);
    if (!CI || CI->getBitWidth() > 64)
      return CSInfo;
    Args.push_back(CI->getZExtValue());
  }
  return ConstCSInfo[Args];
}

void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
                                 unsigned *NumUnsafeUses) {
  auto &CSI = findCallSiteInfo(CB);
  CSI.AllCallSitesDevirted = false;
  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}

struct DevirtModule {
  Module &M;
  function_ref<AAResults &(Function &)> AARGetter;
  function_ref<DominatorTree &(Function &)> LookupDomTree;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;

  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;
  /// Sizeless array type, used for imported vtables. This provides a signal
  /// to analyzers that these imports may alias, as they do for example
  /// when multiple unique return values occur in the same vtable.
  ArrayType *Int8Arr0Ty;

  bool RemarksEnabled;
  function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;

  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // Calls that have already been optimized. We may add a call to multiple
  // VTableSlotInfos if vtable loads are coalesced and need to make sure not to
  // optimize a call more than once.
  SmallPtrSet<CallBase *, 8> OptimizedCalls;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call
  // with true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
  PatternList FunctionsToSkip;

  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
               function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
               function_ref<DominatorTree &(Function &)> LookupDomTree,
               ModuleSummaryIndex *ExportSummary,
               const ModuleSummaryIndex *ImportSummary)
      : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
        Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())),
        Int64Ty(Type::getInt64Ty(M.getContext())),
        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
        Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
        RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
    assert(!(ExportSummary && ImportSummary));
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool areRemarksEnabled();

  void
  scanTypeTestUsers(Function *TypeTestFunc,
                    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);

  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset,
                            ModuleSummaryIndex *ExportSummary);

  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
                             bool &IsExported);
  bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res);

  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
                              bool &IsExported);
  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res, VTableSlot Slot);

  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo,
                           WholeProgramDevirtResolution::ByArg *Res);

  // Returns the global symbol name that is used to export information about
  // the given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  bool shouldExportConstantsAsAbsoluteSymbols();

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);
  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                      uint32_t Const, uint32_t &Storage);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name);
  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                           StringRef Name, IntegerType *IntTy,
                           uint32_t Storage);

  Constant *getMemberAddr(const TypeMemberInfo *M);

  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  bool run();

  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
  //
  // Caller guarantees that `ExportSummary` is not nullptr.
  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
                                           ModuleSummaryIndex *ExportSummary);

  // Returns true if the function definition must be unreachable.
  //
  // Note if this helper function returns true, `F` is guaranteed
  // to be unreachable; if it returns false, `F` might still
  // be unreachable but not covered by this helper function.
  //
  // Implementation-wise, if the function definition is present, the IR is
  // analyzed; if not, the function flags from ExportSummary are used as a
  // fallback.
  static bool mustBeUnreachableFunction(Function *const F,
                                        ModuleSummaryIndex *ExportSummary);

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool
  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
                function_ref<DominatorTree &(Function &)> LookupDomTree);
};

struct DevirtIndex {
  ModuleSummaryIndex &ExportSummary;
  // The set in which to record GUIDs exported from their module by
  // devirtualization, used by the client to ensure they are not internalized.
  std::set<GlobalValue::GUID> &ExportedGUIDs;
  // A map in which to record the information necessary to locate the WPD
  // resolution for local targets in case they are exported by cross module
  // importing.
  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;

  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;

  PatternList FunctionsToSkip;

  DevirtIndex(
      ModuleSummaryIndex &ExportSummary,
      std::set<GlobalValue::GUID> &ExportedGUIDs,
      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
        LocalWPDTargetsMap(LocalWPDTargetsMap) {
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
                                 const TypeIdCompatibleVtableInfo TIdInfo,
                                 uint64_t ByteOffset);

  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                           VTableSlotSummary &SlotSummary,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res,
                           std::set<ValueInfo> &DevirtTargets);

  void run();
};

struct WholeProgramDevirt : public ModulePass {
  static char ID;

  bool UseCommandLine = false;

  ModuleSummaryIndex *ExportSummary = nullptr;
  const ModuleSummaryIndex *ImportSummary = nullptr;

  WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
                     const ModuleSummaryIndex *ImportSummary)
      : ModulePass(ID), ExportSummary(ExportSummary),
        ImportSummary(ImportSummary) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override {
    if (skipModule(M))
      return false;

    // In the new pass manager, we can request the optimization
    // remark emitter pass on a per-function basis, which the
    // OREGetter will do for us.
    // In the old pass manager, this is harder, so we just build
    // an optimization remark emitter on the fly, when we need it.
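    // Note: each call to OREGetter below replaces the cached emitter, so a
    // reference returned for one function is invalidated by the next call;
    // callers must not hold it across requests for different functions.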
    std::unique_ptr<OptimizationRemarkEmitter> ORE;
    auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
      ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    auto LookupDomTree = [this](Function &F) -> DominatorTree & {
      return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
    };

    if (UseCommandLine)
      return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter,
                                         LookupDomTree);

    return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree,
                        ExportSummary, ImportSummary)
        .run();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
                      "Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                    "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;

ModulePass *
llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
                                   const ModuleSummaryIndex *ImportSummary) {
  return new WholeProgramDevirt(ExportSummary, ImportSummary);
}

PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
                                              ModuleAnalysisManager &AM) {
  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto AARGetter = [&](Function &F) -> AAResults & {
    return FAM.getResult<AAManager>(F);
  };
  auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };
  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
    return FAM.getResult<DominatorTreeAnalysis>(F);
  };
  if (UseCommandLine) {
    if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
      return PreservedAnalyses::all();
    return PreservedAnalyses::none();
  }
  if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
                    ImportSummary)
           .run())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

// Enable whole program visibility if enabled by the client (e.g. linker) or
// internal option, and not force disabled.
static bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
  return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
         !DisableWholeProgramVisibility;
}

namespace llvm {

/// If whole program visibility is asserted, then upgrade all public vcall
/// visibility metadata on vtable definitions to linkage unit visibility in
/// Module IR (for regular or hybrid LTO).
void updateVCallVisibilityInModule(
    Module &M, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (GlobalVariable &GV : M.globals())
    // Add linkage unit visibility to any variable with type metadata, which
    // are the vtable definitions. We won't have an existing vcall_visibility
    // metadata on vtable definitions with public visibility.
    if (GV.hasMetadata(LLVMContext::MD_type) &&
        GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
        // Don't upgrade the visibility for symbols exported to the dynamic
        // linker, as we have no information on their eventual use.
        !DynamicExportSymbols.count(GV.getGUID()))
      GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
}

/// If whole program visibility is asserted, then upgrade all public vcall
/// visibility metadata on vtable definition summaries to linkage unit
/// visibility in the Module summary index (for ThinLTO).
void updateVCallVisibilityInIndex(
    ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (auto &P : Index) {
    // Don't upgrade the visibility for symbols exported to the dynamic
    // linker, as we have no information on their eventual use.
    if (DynamicExportSymbols.count(P.first))
      continue;
    for (auto &S : P.second.SummaryList) {
      auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
      if (!GVar ||
          GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
        continue;
      GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
    }
  }
}

void runWholeProgramDevirtOnIndex(
    ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
}

void updateIndexWPDForExports(
    ModuleSummaryIndex &Summary,
    function_ref<bool(StringRef, ValueInfo)> isExported,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  for (auto &T : LocalWPDTargetsMap) {
    auto &VI = T.first;
    // This was enforced earlier during trySingleImplDevirt.
    assert(VI.getSummaryList().size() == 1 &&
           "Devirt of local target has more than one copy");
    auto &S = VI.getSummaryList()[0];
    if (!isExported(S->modulePath(), VI))
      continue;

    // It's been exported by a cross module import.
    for (auto &SlotSummary : T.second) {
      auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
      assert(TIdSum);
      auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
      assert(WPDRes != TIdSum->WPDRes.end());
      WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          WPDRes->second.SingleImplName,
          Summary.getModuleHash(S->modulePath()));
    }
  }
}

} // end namespace llvm

static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
  // Check that the summary index contains a regular LTO module when performing
  // export to prevent occasional use of an index from a pure ThinLTO
  // compilation (-fno-split-lto-module). This kind of summary index is passed
  // to DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
  const auto &ModPaths = Summary->modulePaths();
  if (ClSummaryAction != PassSummaryAction::Import &&
      ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) ==
          ModPaths.end())
    return createStringError(
        errc::invalid_argument,
        "combined summary should contain Regular LTO module");
  return ErrorSuccess();
}

bool DevirtModule::runForTesting(
    Module &M, function_ref<AAResults &(Function &)> AARGetter,
    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
    function_ref<DominatorTree &(Function &)> LookupDomTree) {
  std::unique_ptr<ModuleSummaryIndex> Summary =
      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);

  // Handle the command-line summary arguments. This code is for testing
  // purposes only, so we handle errors directly.
  if (!ClReadSummary.empty()) {
    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
                          ": ");
    auto ReadSummaryFile =
        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
            getModuleSummaryIndex(*ReadSummaryFile)) {
      Summary = std::move(*SummaryOrErr);
      ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
    } else {
      // Try YAML if we've failed with bitcode.
      consumeError(SummaryOrErr.takeError());
      yaml::Input In(ReadSummaryFile->getBuffer());
      In >> *Summary;
      ExitOnErr(errorCodeToError(In.error()));
    }
  }

  bool Changed =
      DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
                   ClSummaryAction == PassSummaryAction::Export ? Summary.get()
                                                                : nullptr,
                   ClSummaryAction == PassSummaryAction::Import ? Summary.get()
                                                                : nullptr)
          .run();

  if (!ClWriteSummary.empty()) {
    ExitOnError ExitOnErr(
        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
    std::error_code EC;
    if (StringRef(ClWriteSummary).endswith(".bc")) {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
      ExitOnErr(errorCodeToError(EC));
      writeIndexToFile(*Summary, OS);
    } else {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
      ExitOnErr(errorCodeToError(EC));
      yaml::Output Out(OS);
      Out << *Summary;
    }
  }

  return Changed;
}

void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(M.getGlobalList().size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(LLVMContext::MD_type, Types);
    if (GV.isDeclaration() || Types.empty())
      continue;

    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto TypeID = Type->getOperand(1).get();

      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert({BitsPtr, Offset});
    }
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
    ModuleSummaryIndex *ExportSummary) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    if (!TM.Bits->GV->isConstant())
      return false;

    // We cannot perform whole program devirtualization analysis on a vtable
    // with public LTO visibility.
    if (TM.Bits->GV->getVCallVisibility() ==
        GlobalObject::VCallVisibilityPublic)
      return false;

    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
                                       TM.Offset + ByteOffset, M);
    if (!Ptr)
      return false;

    auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
    if (!Fn)
      return false;

    if (FunctionsToSkip.match(Fn->getName()))
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    // We can disregard unreachable functions as possible call targets, as
    // unreachable functions shouldn't be called.
    if (mustBeUnreachableFunction(Fn, ExportSummary))
      continue;

    TargetsForSlot.push_back({Fn, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtIndex::tryFindVirtualCallTargets(
    std::vector<ValueInfo> &TargetsForSlot,
    const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
  for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
    // Find a representative copy of the vtable initializer.
    // We can have multiple available_externally, linkonce_odr and weak_odr
    // vtable initializers. We can also have multiple external vtable
    // initializers in the case of comdats, which we cannot check here.
    // The linker should give an error in this case.
    //
    // Also, handle the case of same-named local Vtables with the same path
    // and therefore the same GUID. This can happen if there isn't enough
    // distinguishing path when compiling the source file. In that case we
    // conservatively return false early.
    const GlobalVarSummary *VS = nullptr;
    bool LocalFound = false;
    for (auto &S : P.VTableVI.getSummaryList()) {
      if (GlobalValue::isLocalLinkage(S->linkage())) {
        if (LocalFound)
          return false;
        LocalFound = true;
      }
      auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
      if (!CurVS->vTableFuncs().empty() ||
          // Previously clang did not attach the necessary type metadata to
          // available_externally vtables, in which case there would not
          // be any vtable functions listed in the summary and we need
          // to treat this case conservatively (in case the bitcode is old).
          // However, we will also not have any vtable functions in the
          // case of a pure virtual base class. In that case we do want
          // to set VS to avoid treating it conservatively.
          !GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
        VS = CurVS;
        // We cannot perform whole program devirtualization analysis on a
        // vtable with public LTO visibility.
        if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
          return false;
      }
    }
    // There will be no VS if all copies are available_externally having no
    // type metadata. In that case we can't safely perform WPD.
    if (!VS)
      return false;
    if (!VS->isLive())
      continue;
    for (auto VTP : VS->vTableFuncs()) {
      if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
        continue;

      if (mustBeUnreachableFunction(VTP.FuncVI))
        continue;

      TargetsForSlot.push_back(VTP.FuncVI);
    }
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName()))
    return;
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      if (!OptimizedCalls.insert(&VCallSite.CB).second)
        continue;

      if (RemarksEnabled)
        VCallSite.emitRemark("single-impl",
                             TheFn->stripPointerCasts()->getName(), OREGetter);
      NumSingleImpl++;
      auto &CB = VCallSite.CB;
      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
      IRBuilder<> Builder(&CB);
      Value *Callee =
          Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());

      // If trap checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // perform a debug trap.
      if (DevirtCheckMode == WPDCheckMode::Trap) {
        auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
        Instruction *ThenTerm =
            SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false);
        Builder.SetInsertPoint(ThenTerm);
        Function *TrapFn = Intrinsic::getDeclaration(&M, Intrinsic::debugtrap);
        auto *CallTrap = Builder.CreateCall(TrapFn);
        CallTrap->setDebugLoc(CB.getDebugLoc());
      }

      // If fallback checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // fall back to indirect call.
      if (DevirtCheckMode == WPDCheckMode::Fallback) {
        MDNode *Weights =
            MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
        // Version the indirect call site. If the called value is equal to the
        // given callee, 'NewInst' will be executed, otherwise the original
        // call site will be executed.
        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
        NewInst.setCalledOperand(Callee);
        // Since the new call site is direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
        // Additionally, we should remove them from the fallback indirect call,
        // so that we don't attempt to perform indirect call promotion later.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // In either trapping or non-checking mode, devirtualize original call.
      else {
        // Devirtualize unconditionally.
        CB.setCalledOperand(Callee);
        // Since the call site is now direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported())
      IsExported = true;
    CSInfo.markDevirt();
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
  // We can't add calls if we haven't seen a definition.
  if (Callee.getSummaryList().empty())
    return false;

  // Insert calls into the summary index so that the devirtualized targets
  // are eligible for import.
  // FIXME: Annotate type tests with hotness. For now, mark these as hot
  // to better ensure we have the opportunity to inline them.
  bool IsExported = false;
  auto &S = Callee.getSummaryList()[0];
  CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
  auto AddCalls = [&](CallSiteInfo &CSInfo) {
    for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
    for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
  };
  AddCalls(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    AddCalls(P.second);
  return IsExported;
}

bool DevirtModule::trySingleImplDevirt(
    ModuleSummaryIndex *ExportSummary,
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  if (RemarksEnabled || AreStatisticsEnabled())
    TargetsForSlot[0].WasDevirt = true;

  bool IsExported = false;
  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
  if (!IsExported)
    return false;

  // If the only implementation has local linkage, we must promote to external
  // to make it visible to thin LTO objects. We can only get here during the
  // ThinLTO export phase.
  if (TheFn->hasLocalLinkage()) {
    std::string NewName = (TheFn->getName() + ".llvm.merged").str();

    // Since we are renaming the function, any comdats with the same name must
    // also be renamed. This is required when targeting COFF, as the comdat
    // name must match one of the names of the symbols in the comdat.
    if (Comdat *C = TheFn->getComdat()) {
      if (C->getName() == TheFn->getName()) {
        Comdat *NewC = M.getOrInsertComdat(NewName);
        NewC->setSelectionKind(C->getSelectionKind());
        for (GlobalObject &GO : M.global_objects())
          if (GO.getComdat() == C)
            GO.setComdat(NewC);
      }
    }

    TheFn->setLinkage(GlobalValue::ExternalLinkage);
    TheFn->setVisibility(GlobalValue::HiddenVisibility);
    TheFn->setName(NewName);
  }
  if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
    // Any needed promotion of 'TheFn' has already been done during
    // LTO unit split, so we can ignore the return value of AddCalls.
    AddCalls(SlotInfo, TheFnVI);

  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  Res->SingleImplName = std::string(TheFn->getName());

  return true;
}

bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                                      VTableSlotSummary &SlotSummary,
                                      VTableSlotInfo &SlotInfo,
                                      WholeProgramDevirtResolution *Res,
                                      std::set<ValueInfo> &DevirtTargets) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto TheFn = TargetsForSlot[0];
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target)
      return false;

  // Don't devirtualize if we don't have a target definition.
  auto Size = TheFn.getSummaryList().size();
  if (!Size)
    return false;

  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn.name()))
    return false;

  // If the summary list contains multiple summaries where at least one is
  // a local, give up, as we won't know which (possibly promoted) name to use.
  for (auto &S : TheFn.getSummaryList())
    if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
      return false;

  // Collect functions devirtualized at least for one call site for stats.
  if (PrintSummaryDevirt || AreStatisticsEnabled())
    DevirtTargets.insert(TheFn);

  auto &S = TheFn.getSummaryList()[0];
  bool IsExported = AddCalls(SlotInfo, TheFn);
  if (IsExported)
    ExportedGUIDs.insert(TheFn.getGUID());

  // Record in summary for use in devirtualization during the ThinLTO import
  // step.
  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  if (GlobalValue::isLocalLinkage(S->linkage())) {
    if (IsExported)
      // If the target is a local function and we are exporting it by
      // devirtualizing a call in another module, we need to record the
      // promoted name.
      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
    else {
      LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
      Res->SingleImplName = std::string(TheFn.name());
    }
  } else
    Res->SingleImplName = std::string(TheFn.name());

  // Name will be empty if this thin link is driven off of a serialized
  // combined index (e.g. llvm-lto). However, WPD is not supported/invoked for
  // the legacy LTO API anyway.
  assert(!Res->SingleImplName.empty());

  return true;
}

void DevirtModule::tryICallBranchFunnel(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  Triple T(M.getTargetTriple());
  if (T.getArch() != Triple::x86_64)
    return;

  if (TargetsForSlot.size() > ClThreshold)
    return;

  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
  if (!HasNonDevirt)
    for (auto &P : SlotInfo.ConstCSInfo)
      if (!P.second.AllCallSitesDevirted) {
        HasNonDevirt = true;
        break;
      }

  if (!HasNonDevirt)
    return;

  FunctionType *FT =
      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
  Function *JT;
  if (isa<MDString>(Slot.TypeID)) {
    JT = Function::Create(FT, Function::ExternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          getGlobalName(Slot, {}, "branch_funnel"), &M);
    JT->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    JT = Function::Create(FT, Function::InternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          "branch_funnel", &M);
  }
  JT->addParamAttr(0, Attribute::Nest);

  std::vector<Value *> JTArgs;
  JTArgs.push_back(JT->arg_begin());
  for (auto &T : TargetsForSlot) {
    JTArgs.push_back(getMemberAddr(T.TM));
    JTArgs.push_back(T.Fn);
  }

  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
  Function *Intr =
      Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});

  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
  CI->setTailCallKind(CallInst::TCK_MustTail);
  ReturnInst::Create(M.getContext(), nullptr, BB);

  bool IsExported = false;
  applyICallBranchFunnel(SlotInfo, JT, IsExported);
  if (IsExported)
    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
}

void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
                                          Constant *JT, bool &IsExported) {
  auto Apply = [&](CallSiteInfo &CSInfo) {
    if (CSInfo.isExported())
      IsExported = true;
    if (CSInfo.AllCallSitesDevirted)
      return;
    for (auto &&VCallSite : CSInfo.CallSites) {
      CallBase &CB = VCallSite.CB;

      // Jump tables are only profitable if the retpoline mitigation is
      // enabled.
      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
      if (!FSAttr.isValid() ||
          !FSAttr.getValueAsString().contains("+retpoline"))
        continue;

      NumBranchFunnel++;
      if (RemarksEnabled)
        VCallSite.emitRemark("branch-funnel",
                             JT->stripPointerCasts()->getName(), OREGetter);

      // Pass the address of the vtable in the nest register, which is r10 on
      // x86_64.
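      // As an illustrative example (the value names are hypothetical), an
      // indirect virtual call
      //   %r = call i32 %vfn(i8* %obj, i32 %arg)
      // is rewritten into a call to the funnel with the vtable pointer
      // prepended as the nest parameter:
      //   %r = call i32 @branch_funnel(i8* nest %vtable, i8* %obj, i32 %arg)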
      std::vector<Type *> NewArgs;
      NewArgs.push_back(Int8PtrTy);
      append_range(NewArgs, CB.getFunctionType()->params());
      FunctionType *NewFT =
          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
                            CB.getFunctionType()->isVarArg());
      PointerType *NewFTPtr = PointerType::getUnqual(NewFT);

      IRBuilder<> IRB(&CB);
      std::vector<Value *> Args;
      Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
      llvm::append_range(Args, CB.args());

      CallBase *NewCS = nullptr;
      if (isa<CallInst>(CB))
        NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
      else
        NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
                                 cast<InvokeInst>(CB).getNormalDest(),
                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
      NewCS->setCallingConv(CB.getCallingConv());

      AttributeList Attrs = CB.getAttributes();
      std::vector<AttributeSet> NewArgAttrs;
      NewArgAttrs.push_back(AttributeSet::get(
          M.getContext(), ArrayRef<Attribute>{Attribute::get(
                              M.getContext(), Attribute::Nest)}));
      for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
        NewArgAttrs.push_back(Attrs.getParamAttrs(I));
      NewCS->setAttributes(
          AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
                             Attrs.getRetAttrs(), NewArgAttrs));

      CB.replaceAllUsesWith(NewCS);
      CB.eraseFromParent();

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    // Don't mark as devirtualized because there may be callers compiled
    // without retpoline mitigation, which would mean that they are lowered
    // to llvm.type.test and therefore require an llvm.type.test resolution
    // for the type identifier.
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<uint64_t> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
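  // Illustrative intent (hypothetical target, not from this file): for a
  // readnone implementation such as
  //   int A::f(int x) { return x + 1; }
  // evaluated with the constant argument 4 (and a null 'this'), this records
  // RetVal = 5 for that target.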
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    for (unsigned I = 0; I != Args.size(); ++I) {
      auto *ArgTy = dyn_cast<IntegerType>(
          Target.Fn->getFunctionType()->getParamType(I + 1));
      if (!ArgTy)
        return false;
      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
    }

    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         uint64_t TheRetVal) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    NumUniformRetVal++;
    Call.replaceAndErase(
        "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryUniformRetValOpt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
    WholeProgramDevirtResolution::ByArg *Res) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  if (CSInfo.isExported()) {
    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
    Res->Info = TheRetVal;
  }

  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
  if (RemarksEnabled || AreStatisticsEnabled())
    for (auto &&Target : TargetsForSlot)
      Target.WasDevirt = true;
  return true;
}

std::string DevirtModule::getGlobalName(VTableSlot Slot,
                                        ArrayRef<uint64_t> Args,
                                        StringRef Name) {
  std::string FullName = "__typeid_";
  raw_string_ostream OS(FullName);
  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
  for (uint64_t Arg : Args)
    OS << '_' << Arg;
  OS << '_' << Name;
  return OS.str();
}
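// Example (inputs assumed): for type id "_ZTS1A" at byte offset 0 with
// constant args {4} and Name "byte", getGlobalName produces
// "__typeid__ZTS1A_0_4_byte".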
bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
  Triple T(M.getTargetTriple());
  return T.isX86() && T.getObjectFormat() == Triple::ELF;
}

void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                StringRef Name, Constant *C) {
  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
                                        getGlobalName(Slot, Args, Name), C, &M);
  GA->setVisibility(GlobalValue::HiddenVisibility);
}

void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                  StringRef Name, uint32_t Const,
                                  uint32_t &Storage) {
  if (shouldExportConstantsAsAbsoluteSymbols()) {
    exportGlobal(
        Slot, Args, Name,
        ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
    return;
  }

  Storage = Const;
}

Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                     StringRef Name) {
  Constant *C =
      M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
  auto *GV = dyn_cast<GlobalVariable>(C);
  if (GV)
    GV->setVisibility(GlobalValue::HiddenVisibility);
  return C;
}

Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                       StringRef Name, IntegerType *IntTy,
                                       uint32_t Storage) {
  if (!shouldExportConstantsAsAbsoluteSymbols())
    return ConstantInt::get(IntTy, Storage);

  Constant *C = importGlobal(Slot, Args, Name);
  auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
  C = ConstantExpr::getPtrToInt(C, IntTy);

  // We only need to set metadata if the global is newly created, in which
  // case it would not have hidden visibility.
  if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
    return C;

  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
    GV->setMetadata(LLVMContext::MD_absolute_symbol,
                    MDNode::get(M.getContext(), {MinC, MaxC}));
  };
  unsigned AbsWidth = IntTy->getBitWidth();
  if (AbsWidth == IntPtrTy->getBitWidth())
    SetAbsRange(~0ull, ~0ull); // Full set.
  else
    SetAbsRange(0, 1ull << AbsWidth);
  return C;
}

void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                        bool IsOne,
                                        Constant *UniqueMemberAddr) {
  for (auto &&Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    IRBuilder<> B(&Call.CB);
    Value *Cmp =
        B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
                     B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
    NumUniqueRetVal++;
    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
                         Cmp);
  }
  CSInfo.markDevirt();
}

Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
  Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
  return ConstantExpr::getGetElementPtr(Int8Ty, C,
                                        ConstantInt::get(Int64Ty, M->Offset));
}

bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
    VTableSlot Slot, ArrayRef<uint64_t> Args) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
    if (CSInfo.isExported()) {
      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
      Res->Info = IsOne;

      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
    }

    // Replace each call with the comparison.
    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
                         UniqueMemberAddr);
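    // Illustrative result (operand names assumed): with IsOne == true and a
    // unique member in A's vtable, each call site now computes
    //   %cmp = icmp eq i8* %vtable, <address of A's unique member>
    //   %ret = zext i1 %cmp to <the call's return type>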
    // Update devirtualization statistics for targets.
    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    auto *RetType = cast<IntegerType>(Call.CB.getType());
    IRBuilder<> B(&Call.CB);
    Value *Addr =
        B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
    if (RetType->getBitWidth() == 1) {
      Value *Bits = B.CreateLoad(Int8Ty, Addr);
      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
      NumVirtConstProp1Bit++;
      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
                           OREGetter, IsBitSet);
    } else {
      Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
      Value *Val = B.CreateLoad(RetType, ValAddr);
      NumVirtConstProp++;
      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
                           OREGetter, Val);
    }
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->isDeclaration() ||
        computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
            FMRB_DoesNotAccessMemory ||
        Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
                           ResByArg, Slot, CSByConstantArg.first))
      continue;
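    // Sketch of the storage scheme (offsets illustrative): the evaluated
    // constant is placed in otherwise-unused padding either before the
    // vtable (yielding a negative OffsetByte) or after it, so each call can
    // later be rewritten as a fixed-offset load from the vtable pointer.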
    // Find an allocation offset in bits in all vtables associated with the
    // type.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
                     ResByArg->Byte);
      exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
                     ResByArg->Bit);
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
    applyVirtualConstProp(CSByConstantArg.second,
                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
  }
  return true;
}

void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align the before byte array to the global's minimum alignment so that we
  // don't break any alignment requirements on the global.
  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
      B.GV->getAlign(), B.GV->getValueType());
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());
  NewGV->setAlignment(B.GV->getAlign());
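  // Resulting layout (names and types illustrative, not from a test):
  //   @0 = private constant { [8 x i8], [3 x i8*], [0 x i8] } { ... }
  //   @_ZTV1A = alias [3 x i8*], getelementptr({...}, @0, i32 0, i32 1)
  // so existing references continue to point at the original vtable bytes.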
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

bool DevirtModule::areRemarksEnabled() {
  const auto &FL = M.getFunctionList();
  for (const Function &Fn : FL) {
    const auto &BBL = Fn.getBasicBlockList();
    if (BBL.empty())
      continue;
    auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
    return DI.isEnabled();
  }
  return false;
}

void DevirtModule::scanTypeTestUsers(
    Function *TypeTestFunc,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

    Metadata *TypeId =
        cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
    // If we found any, add them to CallSlots.
    if (!Assumes.empty()) {
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      for (DevirtCallSite Call : DevirtCalls)
        CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
    }

    auto RemoveTypeTestAssumes = [&]() {
      // We no longer need the assumes or the type test.
      for (auto Assume : Assumes)
        Assume->eraseFromParent();
      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because
      // we may use the vtable argument later.
      if (CI->use_empty())
        CI->eraseFromParent();
    };

    // At this point we could remove all type test assume sequences, as they
    // were originally inserted for WPD. However, we can keep these in the
    // code stream for later analysis (e.g. to help drive more efficient ICP
    // sequences). They will eventually be removed by a second LowerTypeTests
    // invocation that cleans them up. In order to do this correctly, the
    // first LowerTypeTests invocation needs to know that they have "Unknown"
    // type test resolution, so that they aren't treated as Unsat and lowered
    // to False, which would break any uses of the assumes. Below we remove
    // any type test assumes that will not be treated as Unknown by LTT.
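    // The sequences in question look like this (illustrative IR):
    //   %t = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
    //   call void @llvm.assume(i1 %t)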
    // The type test assumes will be treated by LTT as Unsat if the type id
    // is not used on a global (in which case it has no entry in the
    // TypeIdMap).
    if (!TypeIdMap.count(TypeId))
      RemoveTypeTestAssumes();

    // For ThinLTO importing, we need to remove the type test assumes if this
    // is an MDString type id without a corresponding TypeIdSummary. Any
    // non-MDString type ids are ignored and treated as Unknown by LTT, so
    // their type test assumes can be kept. If the MDString type id is missing
    // a TypeIdSummary (e.g. because there was no use on a vcall, preventing
    // the exporting phase of WPD from analyzing it), then it would be treated
    // as Unsat by LTT and we need to remove its type test assumes here. If
    // not used on a vcall we don't need them for later optimization use in
    // any case.
    else if (ImportSummary && isa<MDString>(TypeId)) {
      const TypeIdSummary *TidSummary =
          ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
      if (!TidSummary)
        RemoveTypeTestAssumes();
      else
        // If one was created it should not be Unsat, because if we reached
        // here the type id was used on a global.
        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
    }
  }
}

void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI, DT);

    // Start by generating "pessimistic" code that explicitly loads the
    // function pointer from the vtable and performs the type check. If
    // possible, we will eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }
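    // Net effect of the expansion above (operand names illustrative):
    //   %pair = call {i8*, i1} @llvm.type.checked.load(i8* %p, i32 %o, ...)
    // is split into
    //   %fptr = load i8*, i8** <%p + %o>
    //   %ok   = call i1 @llvm.type.test(i8* %p, metadata ...)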
    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = UndefValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the
    // type check, as one of those users may eventually call the pointer.
    // Increment the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}

void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
  auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
  if (!TypeId)
    return;
  const TypeIdSummary *TidSummary =
      ImportSummary->getTypeIdSummary(TypeId->getString());
  if (!TidSummary)
    return;
  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
  if (ResI == TidSummary->WPDRes.end())
    return;
  const WholeProgramDevirtResolution &Res = ResI->second;

  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
    assert(!Res.SingleImplName.empty());
    // The type of the function in the declaration is irrelevant because every
    // call site will cast it to the correct type.
    Constant *SingleImpl =
        cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
                                             Type::getVoidTy(M.getContext()))
                           .getCallee());

    // This is the import phase so we should not be exporting anything.
    bool IsExported = false;
    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
    assert(!IsExported);
  }

  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
    auto I = Res.ResByArg.find(CSByConstantArg.first);
    if (I == Res.ResByArg.end())
      continue;
    auto &ResByArg = I->second;
    // FIXME: We should figure out what to do about the "function name"
    // argument to the apply* functions, as the function names are unavailable
    // during the importing phase. For now we just pass the empty string. This
    // does not impact correctness because the function names are just used
    // for remarks.
    switch (ResByArg.TheKind) {
    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
      break;
    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
      Constant *UniqueMemberAddr =
          importGlobal(Slot, CSByConstantArg.first, "unique_member");
      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
                           UniqueMemberAddr);
      break;
    }
    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
      Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
                                      Int32Ty, ResByArg.Byte);
      Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit",
                                     Int8Ty, ResByArg.Bit);
      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
      break;
    }
    default:
      break;
    }
  }
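  // For example (values assumed), a VirtualConstProp resolution with
  // ResByArg.Byte == -4 on a non-ELF target rewrites each matching call into
  // a load of the four bytes just before the vtable address point; on ELF
  // x86 the offset is instead imported as an absolute symbol (see
  // importConstant).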
  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
    // The type of the function is irrelevant, because it's bitcast at calls
    // anyhow.
    Constant *JT = cast<Constant>(
        M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
                              Type::getVoidTy(M.getContext()))
            .getCallee());
    bool IsExported = false;
    applyICallBranchFunnel(SlotInfo, JT, IsExported);
    assert(!IsExported);
  }
}

void DevirtModule::removeRedundantTypeTests() {
  auto True = ConstantInt::getTrue(M.getContext());
  for (auto &&U : NumUnsafeUsesForTypeTest) {
    if (U.second == 0) {
      U.first->replaceAllUsesWith(True);
      U.first->eraseFromParent();
    }
  }
}

ValueInfo
DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
                                      ModuleSummaryIndex *ExportSummary) {
  assert((ExportSummary != nullptr) &&
         "Caller guarantees ExportSummary is not nullptr");

  const auto TheFnGUID = TheFn->getGUID();
  const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName());
  // Look up ValueInfo with the GUID in the current linkage.
  ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID);
  // If no entry is found and the GUID is different from the GUID computed
  // using the exported name, look up ValueInfo with the exported name
  // unconditionally. This is a fallback.
  //
  // The reason to have a fallback:
  // 1. LTO could enable global value internalization via
  //    `enable-lto-internalization`.
  // 2. The GUID in ExportSummary is computed using the exported name.
  if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
    TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName);
  }
  return TheFnVI;
}

bool DevirtModule::mustBeUnreachableFunction(
    Function *const F, ModuleSummaryIndex *ExportSummary) {
  // First, learn unreachability by analyzing function IR.
  if (!F->isDeclaration()) {
    // A function must be unreachable if its entry block ends with an
    // 'unreachable'.
    return isa<UnreachableInst>(F->getEntryBlock().getTerminator());
  }
  // Learn unreachability from ExportSummary if ExportSummary is present.
  return ExportSummary &&
         ::mustBeUnreachableFunction(
             DevirtModule::lookUpFunctionValueInfo(F, ExportSummary));
}

bool DevirtModule::run() {
  // If only some of the modules were split, we cannot correctly perform
  // this transformation. We already checked for the presence of type tests
  // with partially split modules during the thin link, and would have emitted
  // an error if any were found, so here we can simply return.
  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
    return false;

  Function *TypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
  Function *TypeCheckedLoadFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also
  // need to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
    return false;
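  // The two intrinsic patterns this pass rewrites are (illustrative IR):
  //   %t = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
  //   call void @llvm.assume(i1 %t)
  // and
  //   %pair = call {i8*, i1} @llvm.type.checked.load(i8* %vtable, i32 %off,
  //                                                  metadata !"_ZTS1A")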
  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, TypeIdMap);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (ImportSummary) {
    for (auto &S : CallSlots)
      importResolution(S.first, S.second);

    removeRedundantTypeTests();

    // We have lowered or deleted the type intrinsics, so we will no longer
    // have enough information to reason about the liveness of virtual
    // function pointers in GlobalDCE.
    for (GlobalVariable &GV : M.globals())
      GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(P.first))
        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
            TypeId);
    }

    for (auto &P : *ExportSummary) {
      for (auto &S : P.second.SummaryList) {
        auto *FS = dyn_cast<FunctionSummary>(S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeCheckedLoadUser(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, Function *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    WholeProgramDevirtResolution *Res = nullptr;
    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
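    // (Slot keys are (type id, byte offset) pairs; e.g. {!"_ZTS1A", 8} would
    // denote the second virtual function of A's vtable on a 64-bit target.
    // Illustrative only.)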
    if (ExportSummary && isa<MDString>(S.first.TypeID) &&
        TypeMemberInfos.size())
      // For any type id used on a global's type metadata, create the type id
      // summary resolution regardless of whether we can devirtualize, so that
      // LowerTypeTests knows the type id is not Unsat. If it was not used on
      // a global's type metadata, the TypeIdMap entry set will be empty, and
      // we don't want to create an entry (with the default Unknown type
      // resolution), which can prevent detection of the Unsat.
      Res = &ExportSummary
                 ->getOrInsertTypeIdSummary(
                     cast<MDString>(S.first.TypeID)->getString())
                 .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
                                  S.first.ByteOffset, ExportSummary)) {

      if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
        DidVirtualConstProp |=
            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);

        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
      }

      // Collect functions devirtualized at least for one call site for stats.
      if (RemarksEnabled || AreStatisticsEnabled())
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
      auto GUID =
          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
      for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
        FS->addTypeTest(GUID);
      for (auto &CCS : S.second.ConstCSInfo)
        for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
          FS->addTypeTest(GUID);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      Function *F = DT.second;

      using namespace ore;
      OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
                        << "devirtualized "
                        << NV("FunctionName", DT.first));
    }
  }

  NumDevirtTargets += DevirtTargets.size();

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  // We have lowered or deleted the type intrinsics, so we will no longer have
  // enough information to reason about the liveness of virtual function
  // pointers in GlobalDCE.
  for (GlobalVariable &GV : M.globals())
    GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

  return true;
}

void DevirtIndex::run() {
  if (ExportSummary.typeIdCompatibleVtableMap().empty())
    return;

  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
  for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
  }

  // Collect information from summary about which calls to try to devirtualize.
  for (auto &P : ExportSummary) {
    for (auto &S : P.second.SummaryList) {
      auto *FS = dyn_cast<FunctionSummary>(S.get());
      if (!FS)
        continue;
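      // A function summary records each virtual call it makes as a
      // (GUID, offset) pair, optionally paired with constant integer
      // arguments; e.g. (illustrative) a call through the !"_ZTS1A" slot at
      // offset 0 with constant argument 1 appears in
      // type_test_assume_const_vcalls as {{GUID("_ZTS1A"), 0}, {1}}.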
      // FIXME: Only add live functions.
      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_test_assume_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_checked_load_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeCheckedLoadUser(FS);
        }
      }
    }
  }

  std::set<ValueInfo> DevirtTargets;
  // For each (type, offset) pair:
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<ValueInfo> TargetsForSlot;
    auto TidSummary =
        ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
    assert(TidSummary);
    // Create the type id summary resolution regardless of whether we can
    // devirtualize, so that LowerTypeTests knows the type id is used on
    // a global and not Unsat.
    WholeProgramDevirtResolution *Res =
        &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID)
             .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
                                  S.first.ByteOffset)) {

      if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
                               DevirtTargets))
        continue;
    }
  }

  // Optionally have the thin link print a message for each devirtualized
  // function.
  if (PrintSummaryDevirt)
    for (const auto &DT : DevirtTargets)
      errs() << "Devirtualized call to " << DT << "\n";

  NumDevirtTargets += DevirtTargets.size();
}