//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions that make it easier to manipulate type metadata
// for devirtualization.
//
//===----------------------------------------------------------------------===//
137efd7506SPeter Collingbourne 
147efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
150312f614SPeter Collingbourne #include "llvm/IR/Constants.h"
16f24136f1STeresa Johnson #include "llvm/IR/Dominators.h"
17ea0880ddSSimon Pilgrim #include "llvm/IR/Instructions.h"
18908215b3SPhilip Reames #include "llvm/IR/IntrinsicInst.h"
197efd7506SPeter Collingbourne #include "llvm/IR/Module.h"
207efd7506SPeter Collingbourne 
217efd7506SPeter Collingbourne using namespace llvm;
227efd7506SPeter Collingbourne 
237efd7506SPeter Collingbourne // Search for virtual calls that call FPtr and add them to DevirtCalls.
247efd7506SPeter Collingbourne static void
findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> & DevirtCalls,bool * HasNonCallUses,Value * FPtr,uint64_t Offset,const CallInst * CI,DominatorTree & DT)257efd7506SPeter Collingbourne findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
26f24136f1STeresa Johnson                           bool *HasNonCallUses, Value *FPtr, uint64_t Offset,
27f24136f1STeresa Johnson                           const CallInst *CI, DominatorTree &DT) {
287efd7506SPeter Collingbourne   for (const Use &U : FPtr->uses()) {
29f24136f1STeresa Johnson     Instruction *User = cast<Instruction>(U.getUser());
30f24136f1STeresa Johnson     // Ignore this instruction if it is not dominated by the type intrinsic
31f24136f1STeresa Johnson     // being analyzed. Otherwise we may transform a call sharing the same
32f24136f1STeresa Johnson     // vtable pointer incorrectly. Specifically, this situation can arise
33f24136f1STeresa Johnson     // after indirect call promotion and inlining, where we may have uses
34f24136f1STeresa Johnson     // of the vtable pointer guarded by a function pointer check, and a fallback
35f24136f1STeresa Johnson     // indirect call.
36f24136f1STeresa Johnson     if (!DT.dominates(CI, User))
37f24136f1STeresa Johnson       continue;
387efd7506SPeter Collingbourne     if (isa<BitCastInst>(User)) {
39f24136f1STeresa Johnson       findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
40f24136f1STeresa Johnson                                 DT);
41cea6f4d5SMircea Trofin     } else if (auto *CI = dyn_cast<CallInst>(User)) {
42cea6f4d5SMircea Trofin       DevirtCalls.push_back({Offset, *CI});
43cea6f4d5SMircea Trofin     } else if (auto *II = dyn_cast<InvokeInst>(User)) {
44cea6f4d5SMircea Trofin       DevirtCalls.push_back({Offset, *II});
450312f614SPeter Collingbourne     } else if (HasNonCallUses) {
460312f614SPeter Collingbourne       *HasNonCallUses = true;
477efd7506SPeter Collingbourne     }
487efd7506SPeter Collingbourne   }
497efd7506SPeter Collingbourne }
507efd7506SPeter Collingbourne 
517efd7506SPeter Collingbourne // Search for virtual calls that load from VPtr and add them to DevirtCalls.
findLoadCallsAtConstantOffset(const Module * M,SmallVectorImpl<DevirtCallSite> & DevirtCalls,Value * VPtr,int64_t Offset,const CallInst * CI,DominatorTree & DT)52f24136f1STeresa Johnson static void findLoadCallsAtConstantOffset(
53f24136f1STeresa Johnson     const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr,
54f24136f1STeresa Johnson     int64_t Offset, const CallInst *CI, DominatorTree &DT) {
557efd7506SPeter Collingbourne   for (const Use &U : VPtr->uses()) {
567efd7506SPeter Collingbourne     Value *User = U.getUser();
577efd7506SPeter Collingbourne     if (isa<BitCastInst>(User)) {
58f24136f1STeresa Johnson       findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT);
597efd7506SPeter Collingbourne     } else if (isa<LoadInst>(User)) {
60f24136f1STeresa Johnson       findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT);
617efd7506SPeter Collingbourne     } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
627efd7506SPeter Collingbourne       // Take into account the GEP offset.
637efd7506SPeter Collingbourne       if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
647efd7506SPeter Collingbourne         SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
6517bdf445SDavid Majnemer         int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
667efd7506SPeter Collingbourne             GEP->getSourceElementType(), Indices);
67f24136f1STeresa Johnson         findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset,
68f24136f1STeresa Johnson                                       CI, DT);
697efd7506SPeter Collingbourne       }
707efd7506SPeter Collingbourne     }
717efd7506SPeter Collingbourne   }
727efd7506SPeter Collingbourne }
737efd7506SPeter Collingbourne 
findDevirtualizableCallsForTypeTest(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<CallInst * > & Assumes,const CallInst * CI,DominatorTree & DT)740312f614SPeter Collingbourne void llvm::findDevirtualizableCallsForTypeTest(
757efd7506SPeter Collingbourne     SmallVectorImpl<DevirtCallSite> &DevirtCalls,
76f24136f1STeresa Johnson     SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
77f24136f1STeresa Johnson     DominatorTree &DT) {
78*2eade1dbSArthur Eubanks   assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test ||
79*2eade1dbSArthur Eubanks          CI->getCalledFunction()->getIntrinsicID() ==
80*2eade1dbSArthur Eubanks              Intrinsic::public_type_test);
817efd7506SPeter Collingbourne 
825ad775f2SPeter Collingbourne   const Module *M = CI->getParent()->getParent()->getParent();
837efd7506SPeter Collingbourne 
847efd7506SPeter Collingbourne   // Find llvm.assume intrinsics for this llvm.type.test call.
85908215b3SPhilip Reames   for (const Use &CIU : CI->uses())
86908215b3SPhilip Reames     if (auto *Assume = dyn_cast<AssumeInst>(CIU.getUser()))
874824d876SPhilip Reames       Assumes.push_back(Assume);
887efd7506SPeter Collingbourne 
897efd7506SPeter Collingbourne   // If we found any, search for virtual calls based on %p and add them to
907efd7506SPeter Collingbourne   // DevirtCalls.
917efd7506SPeter Collingbourne   if (!Assumes.empty())
92f24136f1STeresa Johnson     findLoadCallsAtConstantOffset(
93f24136f1STeresa Johnson         M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT);
947efd7506SPeter Collingbourne }
950312f614SPeter Collingbourne 
findDevirtualizableCallsForTypeCheckedLoad(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<Instruction * > & LoadedPtrs,SmallVectorImpl<Instruction * > & Preds,bool & HasNonCallUses,const CallInst * CI,DominatorTree & DT)960312f614SPeter Collingbourne void llvm::findDevirtualizableCallsForTypeCheckedLoad(
970312f614SPeter Collingbourne     SmallVectorImpl<DevirtCallSite> &DevirtCalls,
980312f614SPeter Collingbourne     SmallVectorImpl<Instruction *> &LoadedPtrs,
995ad775f2SPeter Collingbourne     SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
100f24136f1STeresa Johnson     const CallInst *CI, DominatorTree &DT) {
1010312f614SPeter Collingbourne   assert(CI->getCalledFunction()->getIntrinsicID() ==
1020312f614SPeter Collingbourne          Intrinsic::type_checked_load);
1030312f614SPeter Collingbourne 
1040312f614SPeter Collingbourne   auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
1050312f614SPeter Collingbourne   if (!Offset) {
1060312f614SPeter Collingbourne     HasNonCallUses = true;
1070312f614SPeter Collingbourne     return;
1080312f614SPeter Collingbourne   }
1090312f614SPeter Collingbourne 
1105ad775f2SPeter Collingbourne   for (const Use &U : CI->uses()) {
1110312f614SPeter Collingbourne     auto CIU = U.getUser();
1120312f614SPeter Collingbourne     if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) {
1130312f614SPeter Collingbourne       if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) {
1140312f614SPeter Collingbourne         LoadedPtrs.push_back(EVI);
1150312f614SPeter Collingbourne         continue;
1160312f614SPeter Collingbourne       }
1170312f614SPeter Collingbourne       if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) {
1180312f614SPeter Collingbourne         Preds.push_back(EVI);
1190312f614SPeter Collingbourne         continue;
1200312f614SPeter Collingbourne       }
1210312f614SPeter Collingbourne     }
1220312f614SPeter Collingbourne     HasNonCallUses = true;
1230312f614SPeter Collingbourne   }
1240312f614SPeter Collingbourne 
1250312f614SPeter Collingbourne   for (Value *LoadedPtr : LoadedPtrs)
1260312f614SPeter Collingbourne     findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
127f24136f1STeresa Johnson                               Offset->getZExtValue(), CI, DT);
1280312f614SPeter Collingbourne }
1293b598b9cSOliver Stannard 
getPointerAtOffset(Constant * I,uint64_t Offset,Module & M,Constant * TopLevelGlobal)1304c066bd0SKuba Mracek Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
1314c066bd0SKuba Mracek                                    Constant *TopLevelGlobal) {
1323b598b9cSOliver Stannard   if (I->getType()->isPointerTy()) {
1333b598b9cSOliver Stannard     if (Offset == 0)
1343b598b9cSOliver Stannard       return I;
1353b598b9cSOliver Stannard     return nullptr;
1363b598b9cSOliver Stannard   }
1373b598b9cSOliver Stannard 
1383b598b9cSOliver Stannard   const DataLayout &DL = M.getDataLayout();
1393b598b9cSOliver Stannard 
1403b598b9cSOliver Stannard   if (auto *C = dyn_cast<ConstantStruct>(I)) {
1413b598b9cSOliver Stannard     const StructLayout *SL = DL.getStructLayout(C->getType());
1423b598b9cSOliver Stannard     if (Offset >= SL->getSizeInBytes())
1433b598b9cSOliver Stannard       return nullptr;
1443b598b9cSOliver Stannard 
1453b598b9cSOliver Stannard     unsigned Op = SL->getElementContainingOffset(Offset);
1463b598b9cSOliver Stannard     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
1474c066bd0SKuba Mracek                               Offset - SL->getElementOffset(Op), M,
1484c066bd0SKuba Mracek                               TopLevelGlobal);
1493b598b9cSOliver Stannard   }
1503b598b9cSOliver Stannard   if (auto *C = dyn_cast<ConstantArray>(I)) {
1513b598b9cSOliver Stannard     ArrayType *VTableTy = C->getType();
1523b598b9cSOliver Stannard     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
1533b598b9cSOliver Stannard 
1543b598b9cSOliver Stannard     unsigned Op = Offset / ElemSize;
1553b598b9cSOliver Stannard     if (Op >= C->getNumOperands())
1563b598b9cSOliver Stannard       return nullptr;
1573b598b9cSOliver Stannard 
1583b598b9cSOliver Stannard     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
1594c066bd0SKuba Mracek                               Offset % ElemSize, M, TopLevelGlobal);
1604c066bd0SKuba Mracek   }
1614c066bd0SKuba Mracek 
1624c066bd0SKuba Mracek   // (Swift-specific) relative-pointer support starts here.
1634c066bd0SKuba Mracek   if (auto *CI = dyn_cast<ConstantInt>(I)) {
1644c066bd0SKuba Mracek     if (Offset == 0 && CI->getZExtValue() == 0) {
1654c066bd0SKuba Mracek       return I;
1664c066bd0SKuba Mracek     }
1674c066bd0SKuba Mracek   }
1684c066bd0SKuba Mracek   if (auto *C = dyn_cast<ConstantExpr>(I)) {
1694c066bd0SKuba Mracek     switch (C->getOpcode()) {
1704c066bd0SKuba Mracek     case Instruction::Trunc:
1714c066bd0SKuba Mracek     case Instruction::PtrToInt:
1724c066bd0SKuba Mracek       return getPointerAtOffset(cast<Constant>(C->getOperand(0)), Offset, M,
1734c066bd0SKuba Mracek                                 TopLevelGlobal);
1744c066bd0SKuba Mracek     case Instruction::Sub: {
1754c066bd0SKuba Mracek       auto *Operand0 = cast<Constant>(C->getOperand(0));
1764c066bd0SKuba Mracek       auto *Operand1 = cast<Constant>(C->getOperand(1));
177e80ee4cbSKuba Mracek 
178e80ee4cbSKuba Mracek       auto StripGEP = [](Constant *C) {
179e80ee4cbSKuba Mracek         auto *CE = dyn_cast<ConstantExpr>(C);
180e80ee4cbSKuba Mracek         if (!CE)
181e80ee4cbSKuba Mracek           return C;
182e80ee4cbSKuba Mracek         if (CE->getOpcode() != Instruction::GetElementPtr)
183e80ee4cbSKuba Mracek           return C;
184e80ee4cbSKuba Mracek         return CE->getOperand(0);
185e80ee4cbSKuba Mracek       };
186e80ee4cbSKuba Mracek       auto *Operand1TargetGlobal = StripGEP(getPointerAtOffset(Operand1, 0, M));
1874c066bd0SKuba Mracek 
1884c066bd0SKuba Mracek       // Check that in the "sub (@a, @b)" expression, @b points back to the top
189e80ee4cbSKuba Mracek       // level global (or a GEP thereof) that we're processing. Otherwise bail.
1904c066bd0SKuba Mracek       if (Operand1TargetGlobal != TopLevelGlobal)
1914c066bd0SKuba Mracek         return nullptr;
1924c066bd0SKuba Mracek 
1934c066bd0SKuba Mracek       return getPointerAtOffset(Operand0, Offset, M, TopLevelGlobal);
1944c066bd0SKuba Mracek     }
1954c066bd0SKuba Mracek     default:
1964c066bd0SKuba Mracek       return nullptr;
1974c066bd0SKuba Mracek     }
1983b598b9cSOliver Stannard   }
1993b598b9cSOliver Stannard   return nullptr;
2003b598b9cSOliver Stannard }
2017329abf2SKuba Mracek 
replaceRelativePointerUsersWithZero(Function * F)2027329abf2SKuba Mracek void llvm::replaceRelativePointerUsersWithZero(Function *F) {
2037329abf2SKuba Mracek   for (auto *U : F->users()) {
2047329abf2SKuba Mracek     auto *PtrExpr = dyn_cast<ConstantExpr>(U);
2057329abf2SKuba Mracek     if (!PtrExpr || PtrExpr->getOpcode() != Instruction::PtrToInt)
2067329abf2SKuba Mracek       continue;
2077329abf2SKuba Mracek 
2087329abf2SKuba Mracek     for (auto *PtrToIntUser : PtrExpr->users()) {
2097329abf2SKuba Mracek       auto *SubExpr = dyn_cast<ConstantExpr>(PtrToIntUser);
2107329abf2SKuba Mracek       if (!SubExpr || SubExpr->getOpcode() != Instruction::Sub)
2117329abf2SKuba Mracek         continue;
2127329abf2SKuba Mracek 
2137329abf2SKuba Mracek       SubExpr->replaceNonMetadataUsesWith(
2147329abf2SKuba Mracek           ConstantInt::get(SubExpr->getType(), 0));
2157329abf2SKuba Mracek     }
2167329abf2SKuba Mracek   }
2177329abf2SKuba Mracek }
218