1 //===--- CGClass.cpp - Emit LLVM Code for C++ classes ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This contains code dealing with C++ code generation of classes
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenFunction.h"
15 #include "clang/AST/CXXInheritance.h"
16 #include "clang/AST/RecordLayout.h"
17 
18 using namespace clang;
19 using namespace CodeGen;
20 
21 static uint64_t
22 ComputeNonVirtualBaseClassOffset(ASTContext &Context, CXXBasePaths &Paths,
23                                  unsigned Start) {
24   uint64_t Offset = 0;
25 
26   const CXXBasePath &Path = Paths.front();
27   for (unsigned i = Start, e = Path.size(); i != e; ++i) {
28     const CXXBasePathElement& Element = Path[i];
29 
30     // Get the layout.
31     const ASTRecordLayout &Layout = Context.getASTRecordLayout(Element.Class);
32 
33     const CXXBaseSpecifier *BS = Element.Base;
34     assert(!BS->isVirtual() && "Should not see virtual bases here!");
35 
36     const CXXRecordDecl *Base =
37       cast<CXXRecordDecl>(BS->getType()->getAs<RecordType>()->getDecl());
38 
39     // Add the offset.
40     Offset += Layout.getBaseClassOffset(Base) / 8;
41   }
42 
43   return Offset;
44 }
45 
46 llvm::Constant *
47 CodeGenModule::GetCXXBaseClassOffset(const CXXRecordDecl *ClassDecl,
48                                      const CXXRecordDecl *BaseClassDecl) {
49   if (ClassDecl == BaseClassDecl)
50     return 0;
51 
52   CXXBasePaths Paths(/*FindAmbiguities=*/false,
53                      /*RecordPaths=*/true, /*DetectVirtual=*/false);
54   if (!const_cast<CXXRecordDecl *>(ClassDecl)->
55         isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
56     assert(false && "Class must be derived from the passed in base class!");
57     return 0;
58   }
59 
60   uint64_t Offset = ComputeNonVirtualBaseClassOffset(getContext(), Paths, 0);
61   if (!Offset)
62     return 0;
63 
64   const llvm::Type *PtrDiffTy =
65     Types.ConvertType(getContext().getPointerDiffType());
66 
67   return llvm::ConstantInt::get(PtrDiffTy, Offset);
68 }
69 
70 static llvm::Value *GetCXXBaseClassOffset(CodeGenFunction &CGF,
71                                           llvm::Value *BaseValue,
72                                           const CXXRecordDecl *ClassDecl,
73                                           const CXXRecordDecl *BaseClassDecl) {
74   CXXBasePaths Paths(/*FindAmbiguities=*/false,
75                      /*RecordPaths=*/true, /*DetectVirtual=*/false);
76   if (!const_cast<CXXRecordDecl *>(ClassDecl)->
77         isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
78     assert(false && "Class must be derived from the passed in base class!");
79     return 0;
80   }
81 
82   unsigned Start = 0;
83   llvm::Value *VirtualOffset = 0;
84 
85   const CXXBasePath &Path = Paths.front();
86   const CXXRecordDecl *VBase = 0;
87   for (unsigned i = 0, e = Path.size(); i != e; ++i) {
88     const CXXBasePathElement& Element = Path[i];
89     if (Element.Base->isVirtual()) {
90       Start = i+1;
91       QualType VBaseType = Element.Base->getType();
92       VBase = cast<CXXRecordDecl>(VBaseType->getAs<RecordType>()->getDecl());
93     }
94   }
95   if (VBase)
96     VirtualOffset =
97       CGF.GetVirtualCXXBaseClassOffset(BaseValue, ClassDecl, VBase);
98 
99   uint64_t Offset =
100     ComputeNonVirtualBaseClassOffset(CGF.getContext(), Paths, Start);
101 
102   if (!Offset)
103     return VirtualOffset;
104 
105   const llvm::Type *PtrDiffTy =
106     CGF.ConvertType(CGF.getContext().getPointerDiffType());
107   llvm::Value *NonVirtualOffset = llvm::ConstantInt::get(PtrDiffTy, Offset);
108 
109   if (VirtualOffset)
110     return CGF.Builder.CreateAdd(VirtualOffset, NonVirtualOffset);
111 
112   return NonVirtualOffset;
113 }
114 
115 // FIXME: This probably belongs in CGVtable, but it relies on
116 // the static function ComputeNonVirtualBaseClassOffset, so we should make that
117 // a CodeGenModule member function as well.
118 ThunkAdjustment
119 CodeGenModule::ComputeThunkAdjustment(const CXXRecordDecl *ClassDecl,
120                                       const CXXRecordDecl *BaseClassDecl) {
121   CXXBasePaths Paths(/*FindAmbiguities=*/false,
122                      /*RecordPaths=*/true, /*DetectVirtual=*/false);
123   if (!const_cast<CXXRecordDecl *>(ClassDecl)->
124         isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
125     assert(false && "Class must be derived from the passed in base class!");
126     return ThunkAdjustment();
127   }
128 
129   unsigned Start = 0;
130   uint64_t VirtualOffset = 0;
131 
132   const CXXBasePath &Path = Paths.front();
133   const CXXRecordDecl *VBase = 0;
134   for (unsigned i = 0, e = Path.size(); i != e; ++i) {
135     const CXXBasePathElement& Element = Path[i];
136     if (Element.Base->isVirtual()) {
137       Start = i+1;
138       QualType VBaseType = Element.Base->getType();
139       VBase = cast<CXXRecordDecl>(VBaseType->getAs<RecordType>()->getDecl());
140     }
141   }
142   if (VBase)
143     VirtualOffset =
144       getVtableInfo().getVirtualBaseOffsetIndex(ClassDecl, BaseClassDecl);
145 
146   uint64_t Offset =
147     ComputeNonVirtualBaseClassOffset(getContext(), Paths, Start);
148   return ThunkAdjustment(Offset, VirtualOffset);
149 }
150 
151 llvm::Value *
152 CodeGenFunction::GetAddressOfBaseClass(llvm::Value *Value,
153                                        const CXXRecordDecl *ClassDecl,
154                                        const CXXRecordDecl *BaseClassDecl,
155                                        bool NullCheckValue) {
156   QualType BTy =
157     getContext().getCanonicalType(
158       getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
159   const llvm::Type *BasePtrTy = llvm::PointerType::getUnqual(ConvertType(BTy));
160 
161   if (ClassDecl == BaseClassDecl) {
162     // Just cast back.
163     return Builder.CreateBitCast(Value, BasePtrTy);
164   }
165 
166   llvm::BasicBlock *CastNull = 0;
167   llvm::BasicBlock *CastNotNull = 0;
168   llvm::BasicBlock *CastEnd = 0;
169 
170   if (NullCheckValue) {
171     CastNull = createBasicBlock("cast.null");
172     CastNotNull = createBasicBlock("cast.notnull");
173     CastEnd = createBasicBlock("cast.end");
174 
175     llvm::Value *IsNull =
176       Builder.CreateICmpEQ(Value,
177                            llvm::Constant::getNullValue(Value->getType()));
178     Builder.CreateCondBr(IsNull, CastNull, CastNotNull);
179     EmitBlock(CastNotNull);
180   }
181 
182   const llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(VMContext);
183 
184   llvm::Value *Offset =
185     GetCXXBaseClassOffset(*this, Value, ClassDecl, BaseClassDecl);
186 
187   if (Offset) {
188     // Apply the offset.
189     Value = Builder.CreateBitCast(Value, Int8PtrTy);
190     Value = Builder.CreateGEP(Value, Offset, "add.ptr");
191   }
192 
193   // Cast back.
194   Value = Builder.CreateBitCast(Value, BasePtrTy);
195 
196   if (NullCheckValue) {
197     Builder.CreateBr(CastEnd);
198     EmitBlock(CastNull);
199     Builder.CreateBr(CastEnd);
200     EmitBlock(CastEnd);
201 
202     llvm::PHINode *PHI = Builder.CreatePHI(Value->getType());
203     PHI->reserveOperandSpace(2);
204     PHI->addIncoming(Value, CastNotNull);
205     PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()),
206                      CastNull);
207     Value = PHI;
208   }
209 
210   return Value;
211 }
212 
213 llvm::Value *
214 CodeGenFunction::GetAddressOfDerivedClass(llvm::Value *Value,
215                                           const CXXRecordDecl *ClassDecl,
216                                           const CXXRecordDecl *DerivedClassDecl,
217                                           bool NullCheckValue) {
218   QualType DerivedTy =
219     getContext().getCanonicalType(
220     getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(DerivedClassDecl)));
221   const llvm::Type *DerivedPtrTy = ConvertType(DerivedTy)->getPointerTo();
222 
223   if (ClassDecl == DerivedClassDecl) {
224     // Just cast back.
225     return Builder.CreateBitCast(Value, DerivedPtrTy);
226   }
227 
228   llvm::BasicBlock *CastNull = 0;
229   llvm::BasicBlock *CastNotNull = 0;
230   llvm::BasicBlock *CastEnd = 0;
231 
232   if (NullCheckValue) {
233     CastNull = createBasicBlock("cast.null");
234     CastNotNull = createBasicBlock("cast.notnull");
235     CastEnd = createBasicBlock("cast.end");
236 
237     llvm::Value *IsNull =
238     Builder.CreateICmpEQ(Value,
239                          llvm::Constant::getNullValue(Value->getType()));
240     Builder.CreateCondBr(IsNull, CastNull, CastNotNull);
241     EmitBlock(CastNotNull);
242   }
243 
244   llvm::Value *Offset = GetCXXBaseClassOffset(*this, Value, DerivedClassDecl,
245                                               ClassDecl);
246   if (Offset) {
247     // Apply the offset.
248     Value = Builder.CreatePtrToInt(Value, Offset->getType());
249     Value = Builder.CreateSub(Value, Offset);
250     Value = Builder.CreateIntToPtr(Value, DerivedPtrTy);
251   } else {
252     // Just cast.
253     Value = Builder.CreateBitCast(Value, DerivedPtrTy);
254   }
255 
256   if (NullCheckValue) {
257     Builder.CreateBr(CastEnd);
258     EmitBlock(CastNull);
259     Builder.CreateBr(CastEnd);
260     EmitBlock(CastEnd);
261 
262     llvm::PHINode *PHI = Builder.CreatePHI(Value->getType());
263     PHI->reserveOperandSpace(2);
264     PHI->addIncoming(Value, CastNotNull);
265     PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()),
266                      CastNull);
267     Value = PHI;
268   }
269 
270   return Value;
271 }
272