17efd7506SPeter Collingbourne //===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===//
27efd7506SPeter Collingbourne //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67efd7506SPeter Collingbourne //
77efd7506SPeter Collingbourne //===----------------------------------------------------------------------===//
87efd7506SPeter Collingbourne //
97efd7506SPeter Collingbourne // This file contains functions that make it easier to manipulate type metadata
107efd7506SPeter Collingbourne // for devirtualization.
117efd7506SPeter Collingbourne //
127efd7506SPeter Collingbourne //===----------------------------------------------------------------------===//
137efd7506SPeter Collingbourne
147efd7506SPeter Collingbourne #include "llvm/Analysis/TypeMetadataUtils.h"
150312f614SPeter Collingbourne #include "llvm/IR/Constants.h"
16f24136f1STeresa Johnson #include "llvm/IR/Dominators.h"
17ea0880ddSSimon Pilgrim #include "llvm/IR/Instructions.h"
18908215b3SPhilip Reames #include "llvm/IR/IntrinsicInst.h"
197efd7506SPeter Collingbourne #include "llvm/IR/Module.h"
207efd7506SPeter Collingbourne
217efd7506SPeter Collingbourne using namespace llvm;
227efd7506SPeter Collingbourne
237efd7506SPeter Collingbourne // Search for virtual calls that call FPtr and add them to DevirtCalls.
247efd7506SPeter Collingbourne static void
findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> & DevirtCalls,bool * HasNonCallUses,Value * FPtr,uint64_t Offset,const CallInst * CI,DominatorTree & DT)257efd7506SPeter Collingbourne findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
26f24136f1STeresa Johnson bool *HasNonCallUses, Value *FPtr, uint64_t Offset,
27f24136f1STeresa Johnson const CallInst *CI, DominatorTree &DT) {
287efd7506SPeter Collingbourne for (const Use &U : FPtr->uses()) {
29f24136f1STeresa Johnson Instruction *User = cast<Instruction>(U.getUser());
30f24136f1STeresa Johnson // Ignore this instruction if it is not dominated by the type intrinsic
31f24136f1STeresa Johnson // being analyzed. Otherwise we may transform a call sharing the same
32f24136f1STeresa Johnson // vtable pointer incorrectly. Specifically, this situation can arise
33f24136f1STeresa Johnson // after indirect call promotion and inlining, where we may have uses
34f24136f1STeresa Johnson // of the vtable pointer guarded by a function pointer check, and a fallback
35f24136f1STeresa Johnson // indirect call.
36f24136f1STeresa Johnson if (!DT.dominates(CI, User))
37f24136f1STeresa Johnson continue;
387efd7506SPeter Collingbourne if (isa<BitCastInst>(User)) {
39f24136f1STeresa Johnson findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
40f24136f1STeresa Johnson DT);
41cea6f4d5SMircea Trofin } else if (auto *CI = dyn_cast<CallInst>(User)) {
42cea6f4d5SMircea Trofin DevirtCalls.push_back({Offset, *CI});
43cea6f4d5SMircea Trofin } else if (auto *II = dyn_cast<InvokeInst>(User)) {
44cea6f4d5SMircea Trofin DevirtCalls.push_back({Offset, *II});
450312f614SPeter Collingbourne } else if (HasNonCallUses) {
460312f614SPeter Collingbourne *HasNonCallUses = true;
477efd7506SPeter Collingbourne }
487efd7506SPeter Collingbourne }
497efd7506SPeter Collingbourne }
507efd7506SPeter Collingbourne
517efd7506SPeter Collingbourne // Search for virtual calls that load from VPtr and add them to DevirtCalls.
findLoadCallsAtConstantOffset(const Module * M,SmallVectorImpl<DevirtCallSite> & DevirtCalls,Value * VPtr,int64_t Offset,const CallInst * CI,DominatorTree & DT)52f24136f1STeresa Johnson static void findLoadCallsAtConstantOffset(
53f24136f1STeresa Johnson const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr,
54f24136f1STeresa Johnson int64_t Offset, const CallInst *CI, DominatorTree &DT) {
557efd7506SPeter Collingbourne for (const Use &U : VPtr->uses()) {
567efd7506SPeter Collingbourne Value *User = U.getUser();
577efd7506SPeter Collingbourne if (isa<BitCastInst>(User)) {
58f24136f1STeresa Johnson findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT);
597efd7506SPeter Collingbourne } else if (isa<LoadInst>(User)) {
60f24136f1STeresa Johnson findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT);
617efd7506SPeter Collingbourne } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
627efd7506SPeter Collingbourne // Take into account the GEP offset.
637efd7506SPeter Collingbourne if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
647efd7506SPeter Collingbourne SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
6517bdf445SDavid Majnemer int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
667efd7506SPeter Collingbourne GEP->getSourceElementType(), Indices);
67f24136f1STeresa Johnson findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset,
68f24136f1STeresa Johnson CI, DT);
697efd7506SPeter Collingbourne }
707efd7506SPeter Collingbourne }
717efd7506SPeter Collingbourne }
727efd7506SPeter Collingbourne }
737efd7506SPeter Collingbourne
findDevirtualizableCallsForTypeTest(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<CallInst * > & Assumes,const CallInst * CI,DominatorTree & DT)740312f614SPeter Collingbourne void llvm::findDevirtualizableCallsForTypeTest(
757efd7506SPeter Collingbourne SmallVectorImpl<DevirtCallSite> &DevirtCalls,
76f24136f1STeresa Johnson SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
77f24136f1STeresa Johnson DominatorTree &DT) {
78*2eade1dbSArthur Eubanks assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test ||
79*2eade1dbSArthur Eubanks CI->getCalledFunction()->getIntrinsicID() ==
80*2eade1dbSArthur Eubanks Intrinsic::public_type_test);
817efd7506SPeter Collingbourne
825ad775f2SPeter Collingbourne const Module *M = CI->getParent()->getParent()->getParent();
837efd7506SPeter Collingbourne
847efd7506SPeter Collingbourne // Find llvm.assume intrinsics for this llvm.type.test call.
85908215b3SPhilip Reames for (const Use &CIU : CI->uses())
86908215b3SPhilip Reames if (auto *Assume = dyn_cast<AssumeInst>(CIU.getUser()))
874824d876SPhilip Reames Assumes.push_back(Assume);
887efd7506SPeter Collingbourne
897efd7506SPeter Collingbourne // If we found any, search for virtual calls based on %p and add them to
907efd7506SPeter Collingbourne // DevirtCalls.
917efd7506SPeter Collingbourne if (!Assumes.empty())
92f24136f1STeresa Johnson findLoadCallsAtConstantOffset(
93f24136f1STeresa Johnson M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT);
947efd7506SPeter Collingbourne }
950312f614SPeter Collingbourne
findDevirtualizableCallsForTypeCheckedLoad(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<Instruction * > & LoadedPtrs,SmallVectorImpl<Instruction * > & Preds,bool & HasNonCallUses,const CallInst * CI,DominatorTree & DT)960312f614SPeter Collingbourne void llvm::findDevirtualizableCallsForTypeCheckedLoad(
970312f614SPeter Collingbourne SmallVectorImpl<DevirtCallSite> &DevirtCalls,
980312f614SPeter Collingbourne SmallVectorImpl<Instruction *> &LoadedPtrs,
995ad775f2SPeter Collingbourne SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
100f24136f1STeresa Johnson const CallInst *CI, DominatorTree &DT) {
1010312f614SPeter Collingbourne assert(CI->getCalledFunction()->getIntrinsicID() ==
1020312f614SPeter Collingbourne Intrinsic::type_checked_load);
1030312f614SPeter Collingbourne
1040312f614SPeter Collingbourne auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
1050312f614SPeter Collingbourne if (!Offset) {
1060312f614SPeter Collingbourne HasNonCallUses = true;
1070312f614SPeter Collingbourne return;
1080312f614SPeter Collingbourne }
1090312f614SPeter Collingbourne
1105ad775f2SPeter Collingbourne for (const Use &U : CI->uses()) {
1110312f614SPeter Collingbourne auto CIU = U.getUser();
1120312f614SPeter Collingbourne if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) {
1130312f614SPeter Collingbourne if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) {
1140312f614SPeter Collingbourne LoadedPtrs.push_back(EVI);
1150312f614SPeter Collingbourne continue;
1160312f614SPeter Collingbourne }
1170312f614SPeter Collingbourne if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) {
1180312f614SPeter Collingbourne Preds.push_back(EVI);
1190312f614SPeter Collingbourne continue;
1200312f614SPeter Collingbourne }
1210312f614SPeter Collingbourne }
1220312f614SPeter Collingbourne HasNonCallUses = true;
1230312f614SPeter Collingbourne }
1240312f614SPeter Collingbourne
1250312f614SPeter Collingbourne for (Value *LoadedPtr : LoadedPtrs)
1260312f614SPeter Collingbourne findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
127f24136f1STeresa Johnson Offset->getZExtValue(), CI, DT);
1280312f614SPeter Collingbourne }
1293b598b9cSOliver Stannard
getPointerAtOffset(Constant * I,uint64_t Offset,Module & M,Constant * TopLevelGlobal)1304c066bd0SKuba Mracek Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
1314c066bd0SKuba Mracek Constant *TopLevelGlobal) {
1323b598b9cSOliver Stannard if (I->getType()->isPointerTy()) {
1333b598b9cSOliver Stannard if (Offset == 0)
1343b598b9cSOliver Stannard return I;
1353b598b9cSOliver Stannard return nullptr;
1363b598b9cSOliver Stannard }
1373b598b9cSOliver Stannard
1383b598b9cSOliver Stannard const DataLayout &DL = M.getDataLayout();
1393b598b9cSOliver Stannard
1403b598b9cSOliver Stannard if (auto *C = dyn_cast<ConstantStruct>(I)) {
1413b598b9cSOliver Stannard const StructLayout *SL = DL.getStructLayout(C->getType());
1423b598b9cSOliver Stannard if (Offset >= SL->getSizeInBytes())
1433b598b9cSOliver Stannard return nullptr;
1443b598b9cSOliver Stannard
1453b598b9cSOliver Stannard unsigned Op = SL->getElementContainingOffset(Offset);
1463b598b9cSOliver Stannard return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
1474c066bd0SKuba Mracek Offset - SL->getElementOffset(Op), M,
1484c066bd0SKuba Mracek TopLevelGlobal);
1493b598b9cSOliver Stannard }
1503b598b9cSOliver Stannard if (auto *C = dyn_cast<ConstantArray>(I)) {
1513b598b9cSOliver Stannard ArrayType *VTableTy = C->getType();
1523b598b9cSOliver Stannard uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
1533b598b9cSOliver Stannard
1543b598b9cSOliver Stannard unsigned Op = Offset / ElemSize;
1553b598b9cSOliver Stannard if (Op >= C->getNumOperands())
1563b598b9cSOliver Stannard return nullptr;
1573b598b9cSOliver Stannard
1583b598b9cSOliver Stannard return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
1594c066bd0SKuba Mracek Offset % ElemSize, M, TopLevelGlobal);
1604c066bd0SKuba Mracek }
1614c066bd0SKuba Mracek
1624c066bd0SKuba Mracek // (Swift-specific) relative-pointer support starts here.
1634c066bd0SKuba Mracek if (auto *CI = dyn_cast<ConstantInt>(I)) {
1644c066bd0SKuba Mracek if (Offset == 0 && CI->getZExtValue() == 0) {
1654c066bd0SKuba Mracek return I;
1664c066bd0SKuba Mracek }
1674c066bd0SKuba Mracek }
1684c066bd0SKuba Mracek if (auto *C = dyn_cast<ConstantExpr>(I)) {
1694c066bd0SKuba Mracek switch (C->getOpcode()) {
1704c066bd0SKuba Mracek case Instruction::Trunc:
1714c066bd0SKuba Mracek case Instruction::PtrToInt:
1724c066bd0SKuba Mracek return getPointerAtOffset(cast<Constant>(C->getOperand(0)), Offset, M,
1734c066bd0SKuba Mracek TopLevelGlobal);
1744c066bd0SKuba Mracek case Instruction::Sub: {
1754c066bd0SKuba Mracek auto *Operand0 = cast<Constant>(C->getOperand(0));
1764c066bd0SKuba Mracek auto *Operand1 = cast<Constant>(C->getOperand(1));
177e80ee4cbSKuba Mracek
178e80ee4cbSKuba Mracek auto StripGEP = [](Constant *C) {
179e80ee4cbSKuba Mracek auto *CE = dyn_cast<ConstantExpr>(C);
180e80ee4cbSKuba Mracek if (!CE)
181e80ee4cbSKuba Mracek return C;
182e80ee4cbSKuba Mracek if (CE->getOpcode() != Instruction::GetElementPtr)
183e80ee4cbSKuba Mracek return C;
184e80ee4cbSKuba Mracek return CE->getOperand(0);
185e80ee4cbSKuba Mracek };
186e80ee4cbSKuba Mracek auto *Operand1TargetGlobal = StripGEP(getPointerAtOffset(Operand1, 0, M));
1874c066bd0SKuba Mracek
1884c066bd0SKuba Mracek // Check that in the "sub (@a, @b)" expression, @b points back to the top
189e80ee4cbSKuba Mracek // level global (or a GEP thereof) that we're processing. Otherwise bail.
1904c066bd0SKuba Mracek if (Operand1TargetGlobal != TopLevelGlobal)
1914c066bd0SKuba Mracek return nullptr;
1924c066bd0SKuba Mracek
1934c066bd0SKuba Mracek return getPointerAtOffset(Operand0, Offset, M, TopLevelGlobal);
1944c066bd0SKuba Mracek }
1954c066bd0SKuba Mracek default:
1964c066bd0SKuba Mracek return nullptr;
1974c066bd0SKuba Mracek }
1983b598b9cSOliver Stannard }
1993b598b9cSOliver Stannard return nullptr;
2003b598b9cSOliver Stannard }
2017329abf2SKuba Mracek
replaceRelativePointerUsersWithZero(Function * F)2027329abf2SKuba Mracek void llvm::replaceRelativePointerUsersWithZero(Function *F) {
2037329abf2SKuba Mracek for (auto *U : F->users()) {
2047329abf2SKuba Mracek auto *PtrExpr = dyn_cast<ConstantExpr>(U);
2057329abf2SKuba Mracek if (!PtrExpr || PtrExpr->getOpcode() != Instruction::PtrToInt)
2067329abf2SKuba Mracek continue;
2077329abf2SKuba Mracek
2087329abf2SKuba Mracek for (auto *PtrToIntUser : PtrExpr->users()) {
2097329abf2SKuba Mracek auto *SubExpr = dyn_cast<ConstantExpr>(PtrToIntUser);
2107329abf2SKuba Mracek if (!SubExpr || SubExpr->getOpcode() != Instruction::Sub)
2117329abf2SKuba Mracek continue;
2127329abf2SKuba Mracek
2137329abf2SKuba Mracek SubExpr->replaceNonMetadataUsesWith(
2147329abf2SKuba Mracek ConstantInt::get(SubExpr->getType(), 0));
2157329abf2SKuba Mracek }
2167329abf2SKuba Mracek }
2177329abf2SKuba Mracek }
218