//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via bitset information) that the list of callees is fixed.
// This includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class
//   and each list of constant arguments: evaluate the function, store the
//   return value alongside the virtual table, and rewrite each virtual call
//   as a load from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value,
//   replace each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
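// As a C++-level illustration (a sketch only, not code handled directly by
// this pass):
//
//   struct A { virtual int f() = 0; };
//   struct B : A { int f() override { return 42; } };
//
// If B::f is the only implementation of A::f in the whole program, every
// virtual call to f through an A* can be replaced with a direct call to B::f
// (single implementation devirtualization). Because B::f is also readnone and
// returns a constant, such calls could instead be folded to the constant 42
// (uniform return value optimization under virtual constant propagation).
//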
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include "llvm/Transforms/Utils/Local.h"

#include <set>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of
  // the used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte.
  // Effectively, this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated
  // for a virtual function pointer, AAAA... (etc.) are the used regions for
  // the vtables and Offset(X) is the value computed for the Offset variable
  // below for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.BS->Bits->After.BytesUsed
                                       : Target.BS->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(Function *Fn, const BitSetInfo *BS)
    : Fn(Fn), BS(BS),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {}

namespace {

// A slot in a set of virtual tables. The BitSetID identifies the set of
// virtual tables, and the ByteOffset is the offset in bytes from the address
// point to the virtual function pointer.
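// For example, on a 64-bit target with pointer-sized vtable entries, a call
// through the second slot of a vtable would typically correspond to a
// ByteOffset of 8.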
struct VTableSlot {
  Metadata *BitSetID;
  uint64_t ByteOffset;
};

}

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.BitSetID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS,
                      const VTableSlot &RHS) {
    return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

}

namespace {

// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable;
  CallSite CS;

  void replaceAndErase(Value *New) {
    CS->replaceAllUsesWith(New);
    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CS->eraseFromParent();
  }
};

struct DevirtModule {
  Module &M;
  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;

  MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;

  DevirtModule(Module &M)
      : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())) {}

  void findLoadCallsAtConstantOffset(Metadata *BitSet, Value *Ptr,
                                     uint64_t Offset, Value *VTable);
  void findCallsAtConstantOffset(Metadata *BitSet, Value *Ptr, uint64_t Offset,
                                 Value *VTable);

  void buildBitSets(std::vector<VTableBits> &Bits,
                    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets);
  bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                                 const std::set<BitSetInfo> &BitSetInfos,
                                 uint64_t ByteOffset);
  bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<ConstantInt *> Args);
  bool tryUniformRetValOpt(IntegerType *RetType,
                           ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          ArrayRef<VirtualCallTarget> TargetsForSlot,
                          MutableArrayRef<VirtualCallSite> CallSites);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           ArrayRef<VirtualCallSite> CallSites);

  void rebuildGlobal(VTableBits &B);

  bool run();
};

struct WholeProgramDevirt : public ModulePass {
  static char ID;
  WholeProgramDevirt() : ModulePass(ID) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }
  bool runOnModule(Module &M) { return DevirtModule(M).run(); }
};

} // anonymous namespace

INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt",
                "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;

ModulePass *llvm::createWholeProgramDevirtPass() {
  return new WholeProgramDevirt;
}

// Search for virtual calls that call FPtr and add them to CallSlots.
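// Together with findLoadCallsAtConstantOffset below, this matches code of
// roughly the following shape (illustrative IR only; the exact types depend
// on the frontend):
//   %slot = getelementptr void (%class.A*)*, void (%class.A*)** %vtable, i64 1
//   %fptr = load void (%class.A*)*, void (%class.A*)** %slot
//   call void %fptr(%class.A* %obj)
// The call (or invoke) through the loaded function pointer is recorded in
// CallSlots, keyed by the bitset and the accumulated byte offset (8 on a
// 64-bit target in this example).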
void DevirtModule::findCallsAtConstantOffset(Metadata *BitSet, Value *FPtr,
                                             uint64_t Offset, Value *VTable) {
  for (const Use &U : FPtr->uses()) {
    Value *User = U.getUser();
    if (isa<BitCastInst>(User)) {
      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
    } else if (auto CI = dyn_cast<CallInst>(User)) {
      CallSlots[{BitSet, Offset}].push_back({VTable, CI});
    } else if (auto II = dyn_cast<InvokeInst>(User)) {
      CallSlots[{BitSet, Offset}].push_back({VTable, II});
    }
  }
}

// Search for virtual calls that load from VPtr and add them to CallSlots.
void DevirtModule::findLoadCallsAtConstantOffset(Metadata *BitSet, Value *VPtr,
                                                 uint64_t Offset,
                                                 Value *VTable) {
  for (const Use &U : VPtr->uses()) {
    Value *User = U.getUser();
    if (isa<BitCastInst>(User)) {
      findLoadCallsAtConstantOffset(BitSet, User, Offset, VTable);
    } else if (isa<LoadInst>(User)) {
      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
    } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
      // Take into account the GEP offset.
      if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
        SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
        uint64_t GEPOffset = M.getDataLayout().getIndexedOffsetInType(
            GEP->getSourceElementType(), Indices);
        findLoadCallsAtConstantOffset(BitSet, User, Offset + GEPOffset, VTable);
      }
    }
  }
}

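// Collect, from the module's llvm.bitsets named metadata, the mapping from
// each bitset identifier to the vtable globals it covers and the byte offset
// of each address point. Each metadata entry is expected to look roughly
// like (illustrative):
//   !{!"_ZTS1A", [4 x i8*]* @_ZTV1A, i64 16}
// i.e. (bitset identifier, vtable global or alias, address point offset).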
void DevirtModule::buildBitSets(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) {
  NamedMDNode *BitSetNM = M.getNamedMetadata("llvm.bitsets");
  if (!BitSetNM)
    return;

  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(BitSetNM->getNumOperands());
  for (auto Op : BitSetNM->operands()) {
    auto OpConstMD = dyn_cast_or_null<ConstantAsMetadata>(Op->getOperand(1));
    if (!OpConstMD)
      continue;
    auto BitSetID = Op->getOperand(0).get();

    Constant *OpConst = OpConstMD->getValue();
    if (auto GA = dyn_cast<GlobalAlias>(OpConst))
      OpConst = GA->getAliasee();
    auto OpGlobal = dyn_cast<GlobalVariable>(OpConst);
    if (!OpGlobal)
      continue;

    uint64_t Offset =
        cast<ConstantInt>(
            cast<ConstantAsMetadata>(Op->getOperand(2))->getValue())
            ->getZExtValue();

    VTableBits *&BitsPtr = GVToBits[OpGlobal];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = OpGlobal;
      Bits.back().ObjectSize = M.getDataLayout().getTypeAllocSize(
          OpGlobal->getInitializer()->getType());
      BitsPtr = &Bits.back();
    }
    BitSets[BitSetID].insert({BitsPtr, Offset});
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset) {
  for (const BitSetInfo &BS : BitSetInfos) {
    if (!BS.Bits->GV->isConstant())
      return false;

    auto Init = dyn_cast<ConstantArray>(BS.Bits->GV->getInitializer());
    if (!Init)
      return false;
    ArrayType *VTableTy = Init->getType();

    uint64_t ElemSize =
        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
    uint64_t GlobalSlotOffset = BS.Offset + ByteOffset;
    if (GlobalSlotOffset % ElemSize != 0)
      return false;

    unsigned Op = GlobalSlotOffset / ElemSize;
    if (Op >= Init->getNumOperands())
      return false;

    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
    if (!Fn)
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    TargetsForSlot.push_back({Fn, &BS});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtModule::trySingleImplDevirt(
    ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  for (auto &&VCallSite : CallSites) {
    VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
        TheFn, VCallSite.CS.getCalledValue()->getType()));
  }
  return true;
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<ConstantInt *> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;
    for (unsigned I = 0; I != Args.size(); ++I)
      if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
          Args[I]->getType())
        return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

bool DevirtModule::tryUniformRetValOpt(
    IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
  for (auto Call : CallSites)
    Call.replaceAndErase(TheRetValConst);
  return true;
}

bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const BitSetInfo *UniqueBitSet = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueBitSet)
          return false;
        UniqueBitSet = Target.BS;
      }
    }

    // We should have found a unique bit set or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueBitSet);

    // Replace each call with the comparison.
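    // OneAddr is the address point of the unique vtable (its global plus the
    // bitset offset), which is exactly the value the vptr of a matching
    // object holds, so comparing the loaded vtable pointer against it tests
    // for that class.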
    for (auto &&Call : CallSites) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *OneAddr = B.CreateBitCast(UniqueBitSet->Bits->GV, Int8PtrTy);
      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueBitSet->Offset);
      Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                Call.VTable, OneAddr);
      Call.replaceAndErase(Cmp);
    }
    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<VirtualCallSite> CallSites) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function does not access memory, takes at least one
  // argument, does not use its first argument (which we assume is 'this'),
  // and has the same return type.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() ||
        !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  // Group call sites by the list of constant arguments they pass.
  // The comparator ensures deterministic ordering.
  struct ByAPIntValue {
    bool operator()(const std::vector<ConstantInt *> &A,
                    const std::vector<ConstantInt *> &B) const {
      return std::lexicographical_compare(
          A.begin(), A.end(), B.begin(), B.end(),
          [](ConstantInt *AI, ConstantInt *BI) {
            return AI->getValue().ult(BI->getValue());
          });
    }
  };
  std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
           ByAPIntValue>
      VCallSitesByConstantArg;
  for (auto &&VCallSite : CallSites) {
    std::vector<ConstantInt *> Args;
    if (VCallSite.CS.getType() != RetType)
      continue;
    for (auto &&Arg :
         make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
      if (!isa<ConstantInt>(Arg))
        break;
      Args.push_back(cast<ConstantInt>(&Arg));
    }
    if (Args.size() + 1 != VCallSite.CS.arg_size())
      continue;

    VCallSitesByConstantArg[Args].push_back(VCallSite);
  }

  for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
      continue;

    // Find an allocation offset in bits in all vtables in the bitset.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    // Rewrite each call to a load from OffsetByte/OffsetBit.
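    // For example, for a 32-bit return type the rewritten call becomes
    // roughly (illustrative IR only; OffsetByte may be negative):
    //   %valaddr = getelementptr i8, i8* %vtable, i64 OffsetByte
    //   %valptr = bitcast i8* %valaddr to i32*
    //   %val = load i32, i32* %valptr
    // For i1 return values, a byte is loaded and the bit at OffsetBit is
    // tested instead.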
    for (auto Call : CSByConstantArg.second) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
      if (BitWidth == 1) {
        Value *Bits = B.CreateLoad(Addr);
        Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
        Value *BitsAndBit = B.CreateAnd(Bits, Bit);
        auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
        Call.replaceAndErase(IsBitSet);
      } else {
        Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
        Value *Val = B.CreateLoad(RetType, ValAddr);
        Call.replaceAndErase(Val);
      }
    }
  }
  return true;
}

void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align each byte array to pointer width.
  unsigned PointerSize = M.getDataLayout().getPointerSize();
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
  B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

bool DevirtModule::run() {
  Function *BitSetTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::bitset_test));
  if (!BitSetTestFunc || BitSetTestFunc->use_empty())
    return false;

  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));
  if (!AssumeFunc || AssumeFunc->use_empty())
    return false;

  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.bitset.test(%p, %md)). This indicates that %p
  // points to a vtable in the bitset %md. Group calls by (bitset, offset) pair
  // (effectively the identity of the virtual function) and store to CallSlots.
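  // A typical frontend-emitted pattern looks roughly like this (illustrative
  // IR only):
  //   %vtable = load i8*, i8** %vtableptr
  //   %p = call i1 @llvm.bitset.test(i8* %vtable, metadata !"_ZTS1A")
  //   call void @llvm.assume(i1 %p)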
  DenseSet<Value *> SeenPtrs;
  for (auto I = BitSetTestFunc->use_begin(), E = BitSetTestFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    // Find llvm.assume intrinsics for this llvm.bitset.test call.
    SmallVector<CallInst *, 1> Assumes;
    for (const Use &CIU : CI->uses()) {
      auto AssumeCI = dyn_cast<CallInst>(CIU.getUser());
      if (AssumeCI && AssumeCI->getCalledValue() == AssumeFunc)
        Assumes.push_back(AssumeCI);
    }

    // If we found any, search for virtual calls based on %p and add them to
    // CallSlots.
    if (!Assumes.empty()) {
      Metadata *BitSet =
          cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      if (SeenPtrs.insert(Ptr).second)
        findLoadCallsAtConstantOffset(BitSet, Ptr, 0, CI->getArgOperand(0));
    }

    // We no longer need the assumes or the bitset test.
    for (auto Assume : Assumes)
      Assume->eraseFromParent();
    // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
    // may use the vtable argument later.
    if (CI->use_empty())
      CI->eraseFromParent();
  }

  // Rebuild llvm.bitsets metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<BitSetInfo>> BitSets;
  buildBitSets(Bits, BitSets);
  if (BitSets.empty())
    return true;

  // For each (bitset, offset) pair:
  bool DidVirtualConstProp = false;
  for (auto &S : CallSlots) {
    // Search each of the vtables in the bitset for the virtual function
    // implementation at offset S.first.ByteOffset, and add to TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    if (!tryFindVirtualCallTargets(TargetsForSlot, BitSets[S.first.BitSetID],
                                   S.first.ByteOffset))
      continue;

    if (trySingleImplDevirt(TargetsForSlot, S.second))
      continue;

    DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second);
  }

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  return true;
}