1 //===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains functions that make it easier to manipulate type metadata
11 // for devirtualization.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/TypeMetadataUtils.h"
16 #include "llvm/IR/Intrinsics.h"
17 #include "llvm/IR/Module.h"
18 
19 using namespace llvm;
20 
21 // Search for virtual calls that call FPtr and add them to DevirtCalls.
22 static void
23 findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
24                           Value *FPtr, uint64_t Offset) {
25   for (const Use &U : FPtr->uses()) {
26     Value *User = U.getUser();
27     if (isa<BitCastInst>(User)) {
28       findCallsAtConstantOffset(DevirtCalls, User, Offset);
29     } else if (auto CI = dyn_cast<CallInst>(User)) {
30       DevirtCalls.push_back({Offset, CI});
31     } else if (auto II = dyn_cast<InvokeInst>(User)) {
32       DevirtCalls.push_back({Offset, II});
33     }
34   }
35 }
36 
37 // Search for virtual calls that load from VPtr and add them to DevirtCalls.
38 static void
39 findLoadCallsAtConstantOffset(Module *M,
40                               SmallVectorImpl<DevirtCallSite> &DevirtCalls,
41                               Value *VPtr, uint64_t Offset) {
42   for (const Use &U : VPtr->uses()) {
43     Value *User = U.getUser();
44     if (isa<BitCastInst>(User)) {
45       findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset);
46     } else if (isa<LoadInst>(User)) {
47       findCallsAtConstantOffset(DevirtCalls, User, Offset);
48     } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
49       // Take into account the GEP offset.
50       if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
51         SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
52         uint64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
53             GEP->getSourceElementType(), Indices);
54         findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset);
55       }
56     }
57   }
58 }
59 
60 void llvm::findDevirtualizableCalls(
61     SmallVectorImpl<DevirtCallSite> &DevirtCalls,
62     SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) {
63   assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test);
64 
65   Module *M = CI->getParent()->getParent()->getParent();
66 
67   // Find llvm.assume intrinsics for this llvm.type.test call.
68   for (const Use &CIU : CI->uses()) {
69     auto AssumeCI = dyn_cast<CallInst>(CIU.getUser());
70     if (AssumeCI) {
71       Function *F = AssumeCI->getCalledFunction();
72       if (F && F->getIntrinsicID() == Intrinsic::assume)
73         Assumes.push_back(AssumeCI);
74     }
75   }
76 
77   // If we found any, search for virtual calls based on %p and add them to
78   // DevirtCalls.
79   if (!Assumes.empty())
80     findLoadCallsAtConstantOffset(M, DevirtCalls,
81                                   CI->getArgOperand(0)->stripPointerCasts(), 0);
82 }
83