//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via bitset information) that the list of callees is fixed.
// This includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class and
//   each list of constant arguments: evaluate the function, store the return
//   value alongside the virtual table, and rewrite each virtual call as a load
//   from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value, replace
//   each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Analysis/BitSetUtils.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>
#include <set>
#include <vector>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

// Find the lowest offset at which we may store a value of size Size bits. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of the
  // used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
  // this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.BS->Bits->After.BytesUsed
                                       : Target.BS->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(Function *Fn, const BitSetInfo *BS)
    : Fn(Fn), BS(BS),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {}

namespace {

// A slot in a set of virtual tables. The BitSetID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
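// For example (purely illustrative), with 8-byte virtual function pointers the
// slot for the second virtual function of a class whose vtables belong to the
// bitset !"_ZTS1A" would be {!"_ZTS1A", 8}.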
struct VTableSlot {
  Metadata *BitSetID;
  uint64_t ByteOffset;
};

} // anonymous namespace

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.BitSetID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
    return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // namespace llvm

namespace {

// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable;
  CallSite CS;

  void replaceAndErase(Value *New) {
    CS->replaceAllUsesWith(New);
    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CS->eraseFromParent();
  }
};

struct DevirtModule {
  Module &M;
  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;

  MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;

  DevirtModule(Module &M)
      : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())) {}

  void buildBitSets(std::vector<VTableBits> &Bits,
                    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets);
  bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                                 const std::set<BitSetInfo> &BitSetInfos,
                                 uint64_t ByteOffset);
  bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<ConstantInt *> Args);
  bool tryUniformRetValOpt(IntegerType *RetType,
                           ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          ArrayRef<VirtualCallTarget> TargetsForSlot,
                          MutableArrayRef<VirtualCallSite> CallSites);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           ArrayRef<VirtualCallSite> CallSites);

  void rebuildGlobal(VTableBits &B);

  bool run();
};

struct WholeProgramDevirt : public ModulePass {
  static char ID;
  WholeProgramDevirt() : ModulePass(ID) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }
  bool runOnModule(Module &M) override {
    if (skipModule(M))
      return false;

    return DevirtModule(M).run();
  }
};

} // anonymous namespace

INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt",
                "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;

ModulePass *llvm::createWholeProgramDevirtPass() {
  return new WholeProgramDevirt;
}

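// Parse the llvm.bitsets named metadata into Bits (one entry per vtable
// global) and BitSets (bitset identifier -> set of (vtable, offset) pairs).
// Each metadata operand is expected to be a triple of (bitset identifier,
// vtable global or alias, byte offset of the address point within the global),
// e.g. (hypothetical entry):
//   !{!"_ZTS1A", [4 x i8*]* @_ZTV1A, i64 16}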
void DevirtModule::buildBitSets(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) {
  NamedMDNode *BitSetNM = M.getNamedMetadata("llvm.bitsets");
  if (!BitSetNM)
    return;

  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(BitSetNM->getNumOperands());
  for (auto Op : BitSetNM->operands()) {
    auto OpConstMD = dyn_cast_or_null<ConstantAsMetadata>(Op->getOperand(1));
    if (!OpConstMD)
      continue;
    auto BitSetID = Op->getOperand(0).get();

    Constant *OpConst = OpConstMD->getValue();
    if (auto GA = dyn_cast<GlobalAlias>(OpConst))
      OpConst = GA->getAliasee();
    auto OpGlobal = dyn_cast<GlobalVariable>(OpConst);
    if (!OpGlobal)
      continue;

    uint64_t Offset =
        cast<ConstantInt>(
            cast<ConstantAsMetadata>(Op->getOperand(2))->getValue())
            ->getZExtValue();

    VTableBits *&BitsPtr = GVToBits[OpGlobal];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = OpGlobal;
      Bits.back().ObjectSize = M.getDataLayout().getTypeAllocSize(
          OpGlobal->getInitializer()->getType());
      BitsPtr = &Bits.back();
    }
    BitSets[BitSetID].insert({BitsPtr, Offset});
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset) {
  for (const BitSetInfo &BS : BitSetInfos) {
    if (!BS.Bits->GV->isConstant())
      return false;

    auto Init = dyn_cast<ConstantArray>(BS.Bits->GV->getInitializer());
    if (!Init)
      return false;
    ArrayType *VTableTy = Init->getType();

    uint64_t ElemSize =
        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
    uint64_t GlobalSlotOffset = BS.Offset + ByteOffset;
    if (GlobalSlotOffset % ElemSize != 0)
      return false;

    unsigned Op = GlobalSlotOffset / ElemSize;
    if (Op >= Init->getNumOperands())
      return false;

    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
    if (!Fn)
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    TargetsForSlot.push_back({Fn, &BS});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtModule::trySingleImplDevirt(
    ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  for (auto &&VCallSite : CallSites) {
    VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
        TheFn, VCallSite.CS.getCalledValue()->getType()));
  }
  return true;
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<ConstantInt *> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
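  // For example (hypothetical), if the targets are readnone functions A::f and
  // B::f taking one i32 argument and Args is {i32 1}, this evaluates A::f and
  // B::f at compile time with a null 'this' pointer and the constant 1, and
  // records each constant result in the corresponding target's RetVal.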
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;
    for (unsigned I = 0; I != Args.size(); ++I)
      if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
          Args[I]->getType())
        return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

bool DevirtModule::tryUniformRetValOpt(
    IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
  for (auto Call : CallSites)
    Call.replaceAndErase(TheRetValConst);
  return true;
}

bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const BitSetInfo *UniqueBitSet = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueBitSet)
          return false;
        UniqueBitSet = Target.BS;
      }
    }

    // We should have found a unique bit set or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueBitSet);

    // Replace each call with the comparison.
    for (auto &&Call : CallSites) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *OneAddr = B.CreateBitCast(UniqueBitSet->Bits->GV, Int8PtrTy);
      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueBitSet->Offset);
      Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                Call.VTable, OneAddr);
      Call.replaceAndErase(Cmp);
    }
    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<VirtualCallSite> CallSites) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function does not access memory, takes at least one
  // argument, does not use its first argument (which we assume is 'this'),
  // and has the same return type.
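  // (The first argument must be unused because tryEvaluateFunctionsWithArgs
  // evaluates each target with a null 'this' pointer.)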
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() ||
        !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  // Group call sites by the list of constant arguments they pass.
  // The comparator ensures deterministic ordering.
  struct ByAPIntValue {
    bool operator()(const std::vector<ConstantInt *> &A,
                    const std::vector<ConstantInt *> &B) const {
      return std::lexicographical_compare(
          A.begin(), A.end(), B.begin(), B.end(),
          [](ConstantInt *AI, ConstantInt *BI) {
            return AI->getValue().ult(BI->getValue());
          });
    }
  };
  std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
           ByAPIntValue>
      VCallSitesByConstantArg;
  for (auto &&VCallSite : CallSites) {
    std::vector<ConstantInt *> Args;
    if (VCallSite.CS.getType() != RetType)
      continue;
    for (auto &&Arg :
         make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
      if (!isa<ConstantInt>(Arg))
        break;
      Args.push_back(cast<ConstantInt>(&Arg));
    }
    if (Args.size() + 1 != VCallSite.CS.arg_size())
      continue;

    VCallSitesByConstantArg[Args].push_back(VCallSite);
  }

  for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
      continue;

    // Find an allocation offset in bits in all vtables in the bitset.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    // Rewrite each call to a load from OffsetByte/OffsetBit.
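    // Roughly (a sketch only; the exact IR comes from the IRBuilder calls
    // below): a multi-bit return value becomes a load of RetType from
    // (vtable + OffsetByte), and an i1 return value becomes a test of bit
    // OffsetBit within the i8 loaded from (vtable + OffsetByte).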
    for (auto Call : CSByConstantArg.second) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
      if (BitWidth == 1) {
        Value *Bits = B.CreateLoad(Addr);
        Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
        Value *BitsAndBit = B.CreateAnd(Bits, Bit);
        auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
        Call.replaceAndErase(IsBitSet);
      } else {
        Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
        Value *Val = B.CreateLoad(RetType, ValAddr);
        Call.replaceAndErase(Val);
      }
    }
  }
  return true;
}

void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align each byte array to pointer width.
  unsigned PointerSize = M.getDataLayout().getPointerSize();
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
  B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

bool DevirtModule::run() {
  Function *BitSetTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::bitset_test));
  if (!BitSetTestFunc || BitSetTestFunc->use_empty())
    return false;

  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
  if (!AssumeFunc || AssumeFunc->use_empty())
    return false;

  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.bitset.test(%p, %md)). This indicates that %p
  // points to a vtable in the bitset %md. Group calls by (bitset, offset) pair
  // (effectively the identity of the virtual function) and store to CallSlots.
  DenseSet<Value *> SeenPtrs;
  for (auto I = BitSetTestFunc->use_begin(), E = BitSetTestFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    findDevirtualizableCalls(DevirtCalls, Assumes, CI);

    // If we found any, add them to CallSlots. Only do this if we haven't seen
    // the vtable pointer before, as it may have been CSE'd with pointers from
    // other call sites, and we don't want to process call sites multiple times.
    if (!Assumes.empty()) {
      Metadata *BitSet =
          cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      if (SeenPtrs.insert(Ptr).second) {
        for (DevirtCallSite Call : DevirtCalls) {
          CallSlots[{BitSet, Call.Offset}].push_back(
              {CI->getArgOperand(0), Call.CS});
        }
      }
    }

    // We no longer need the assumes or the bitset test.
    for (auto Assume : Assumes)
      Assume->eraseFromParent();
    // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
    // may use the vtable argument later.
    if (CI->use_empty())
      CI->eraseFromParent();
  }

  // Rebuild llvm.bitsets metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<BitSetInfo>> BitSets;
  buildBitSets(Bits, BitSets);
  if (BitSets.empty())
    return true;

  // For each (bitset, offset) pair:
  bool DidVirtualConstProp = false;
  for (auto &S : CallSlots) {
    // Search each of the vtables in the bitset for the virtual function
    // implementation at offset S.first.ByteOffset, and add to TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    if (!tryFindVirtualCallTargets(TargetsForSlot, BitSets[S.first.BitSetID],
                                   S.first.ByteOffset))
      continue;

    if (trySingleImplDevirt(TargetsForSlot, S.second))
      continue;

    DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second);
  }

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  return true;
}
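// Usage sketch (an assumption, not part of the pass logic): since the pass is
// registered above under the name "wholeprogramdevirt", it can be exercised on
// a module containing llvm.bitsets metadata and llvm.assume(llvm.bitset.test)
// calls with, for example:
//   opt -wholeprogramdevirt -S example.ll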