1 //===--- ByteCodeStmtGen.cpp - Code generator for expressions ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ByteCodeStmtGen.h"
10 #include "ByteCodeEmitter.h"
11 #include "ByteCodeGenError.h"
12 #include "Context.h"
13 #include "Function.h"
14 #include "PrimType.h"
15 
16 using namespace clang;
17 using namespace clang::interp;
18 
19 namespace clang {
20 namespace interp {
21 
22 /// Scope managing label targets.
23 template <class Emitter> class LabelScope {
24 public:
25   virtual ~LabelScope() {  }
26 
27 protected:
28   LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
29   /// ByteCodeStmtGen instance.
30   ByteCodeStmtGen<Emitter> *Ctx;
31 };
32 
33 /// Sets the context for break/continue statements.
34 template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
35 public:
36   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
37   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
38 
39   LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
40             LabelTy ContinueLabel)
41       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
42         OldContinueLabel(Ctx->ContinueLabel) {
43     this->Ctx->BreakLabel = BreakLabel;
44     this->Ctx->ContinueLabel = ContinueLabel;
45   }
46 
47   ~LoopScope() {
48     this->Ctx->BreakLabel = OldBreakLabel;
49     this->Ctx->ContinueLabel = OldContinueLabel;
50   }
51 
52 private:
53   OptLabelTy OldBreakLabel;
54   OptLabelTy OldContinueLabel;
55 };
56 
57 // Sets the context for a switch scope, mapping labels.
58 template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
59 public:
60   using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
61   using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
62   using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
63 
64   SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
65               LabelTy BreakLabel, OptLabelTy DefaultLabel)
66       : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
67         OldDefaultLabel(this->Ctx->DefaultLabel),
68         OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
69     this->Ctx->BreakLabel = BreakLabel;
70     this->Ctx->DefaultLabel = DefaultLabel;
71     this->Ctx->CaseLabels = std::move(CaseLabels);
72   }
73 
74   ~SwitchScope() {
75     this->Ctx->BreakLabel = OldBreakLabel;
76     this->Ctx->DefaultLabel = OldDefaultLabel;
77     this->Ctx->CaseLabels = std::move(OldCaseLabels);
78   }
79 
80 private:
81   OptLabelTy OldBreakLabel;
82   OptLabelTy OldDefaultLabel;
83   CaseMap OldCaseLabels;
84 };
85 
86 } // namespace interp
87 } // namespace clang
88 
/// Emits the body of a lambda's static invoker, which has an empty body in
/// the AST. The generated code forwards every parameter of \p MD to the
/// lambda's call operator and returns its result. Because the closure class
/// has no captures (asserted below), a null pointer is passed as the call
/// operator's instance pointer.
///
/// The operands are pushed in the order the call operator's Function
/// expects: the RVO pointer (if any), then the instance pointer, then the
/// forwarded arguments.
template <class Emitter>
bool ByteCodeStmtGen<Emitter>::emitLambdaStaticInvokerBody(
    const CXXMethodDecl *MD) {
  assert(MD->isLambdaStaticInvoker());
  assert(MD->hasBody());
  assert(cast<CompoundStmt>(MD->getBody())->body_empty());

  const CXXRecordDecl *ClosureClass = MD->getParent();
  const CXXMethodDecl *LambdaCallOp = ClosureClass->getLambdaCallOperator();
  assert(ClosureClass->captures_begin() == ClosureClass->captures_end());
  const Function *Func = this->getFunction(LambdaCallOp);
  if (!Func)
    return false;
  // The call operator takes the invoker's parameters plus an instance
  // pointer, plus an RVO pointer when it returns through one.
  assert(Func->hasThisPointer());
  assert(Func->getNumParams() == (MD->getNumParams() + 1 + Func->hasRVO()));

  // Forward this function's own RVO pointer so the call operator constructs
  // its result directly in our return slot.
  if (Func->hasRVO()) {
    if (!this->emitRVOPtr(MD))
      return false;
  }

  // The lambda call operator needs an instance pointer, but we don't have
  // one here, and we don't need one either because the lambda cannot have
  // any captures, as verified above. Emit a null pointer. This is then
  // special-cased when interpreting to not emit any misleading diagnostics.
  if (!this->emitNullPtr(MD))
    return false;

  // Forward all arguments from the static invoker to the lambda call operator.
  for (const ParmVarDecl *PVD : MD->parameters()) {
    auto It = this->Params.find(PVD);
    assert(It != this->Params.end());

    // We do the lvalue-to-rvalue conversion manually here, so no need
    // to care about references.
    // Non-primitive parameters are classified as PT_Ptr and forwarded as
    // pointers.
    PrimType ParamType = this->classify(PVD->getType()).value_or(PT_Ptr);
    if (!this->emitGetParam(ParamType, It->second.Offset, MD))
      return false;
  }

  if (!this->emitCall(Func, LambdaCallOp))
    return false;

  this->emitCleanup();
  // A primitive result is left on the stack by the call; return it.
  if (ReturnType)
    return this->emitRet(*ReturnType, MD);

  // Nothing to do, since we emitted the RVO pointer above.
  return this->emitRetVoid(MD);
}
139 
140 template <class Emitter>
141 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
142   // Classify the return type.
143   ReturnType = this->classify(F->getReturnType());
144 
145   // Emit custom code if this is a lambda static invoker.
146   if (const auto *MD = dyn_cast<CXXMethodDecl>(F);
147       MD && MD->isLambdaStaticInvoker())
148     return this->emitLambdaStaticInvokerBody(MD);
149 
150   // Constructor. Set up field initializers.
151   if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(F)) {
152     const RecordDecl *RD = Ctor->getParent();
153     const Record *R = this->getRecord(RD);
154     if (!R)
155       return false;
156 
157     for (const auto *Init : Ctor->inits()) {
158       // Scope needed for the initializers.
159       BlockScope<Emitter> Scope(this);
160 
161       const Expr *InitExpr = Init->getInit();
162       if (const FieldDecl *Member = Init->getMember()) {
163         const Record::Field *F = R->getField(Member);
164 
165         if (std::optional<PrimType> T = this->classify(InitExpr)) {
166           if (!this->visit(InitExpr))
167             return false;
168 
169           if (F->isBitField()) {
170             if (!this->emitInitThisBitField(*T, F, InitExpr))
171               return false;
172           } else {
173             if (!this->emitInitThisField(*T, F->Offset, InitExpr))
174               return false;
175           }
176         } else {
177           // Non-primitive case. Get a pointer to the field-to-initialize
178           // on the stack and call visitInitialzer() for it.
179           if (!this->emitGetPtrThisField(F->Offset, InitExpr))
180             return false;
181 
182           if (!this->visitInitializer(InitExpr))
183             return false;
184 
185           if (!this->emitPopPtr(InitExpr))
186             return false;
187         }
188       } else if (const Type *Base = Init->getBaseClass()) {
189         // Base class initializer.
190         // Get This Base and call initializer on it.
191         const auto *BaseDecl = Base->getAsCXXRecordDecl();
192         assert(BaseDecl);
193         const Record::Base *B = R->getBase(BaseDecl);
194         assert(B);
195         if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
196           return false;
197         if (!this->visitInitializer(InitExpr))
198           return false;
199         if (!this->emitInitPtrPop(InitExpr))
200           return false;
201       } else {
202         assert(Init->isDelegatingInitializer());
203         if (!this->emitThis(InitExpr))
204           return false;
205         if (!this->visitInitializer(Init->getInit()))
206           return false;
207         if (!this->emitPopPtr(InitExpr))
208           return false;
209       }
210     }
211   }
212 
213   if (const auto *Body = F->getBody())
214     if (!visitStmt(Body))
215       return false;
216 
217   // Emit a guard return to protect against a code path missing one.
218   if (F->getReturnType()->isVoidType())
219     return this->emitRetVoid(SourceInfo{});
220   else
221     return this->emitNoRet(SourceInfo{});
222 }
223 
224 template <class Emitter>
225 bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
226   switch (S->getStmtClass()) {
227   case Stmt::CompoundStmtClass:
228     return visitCompoundStmt(cast<CompoundStmt>(S));
229   case Stmt::DeclStmtClass:
230     return visitDeclStmt(cast<DeclStmt>(S));
231   case Stmt::ReturnStmtClass:
232     return visitReturnStmt(cast<ReturnStmt>(S));
233   case Stmt::IfStmtClass:
234     return visitIfStmt(cast<IfStmt>(S));
235   case Stmt::WhileStmtClass:
236     return visitWhileStmt(cast<WhileStmt>(S));
237   case Stmt::DoStmtClass:
238     return visitDoStmt(cast<DoStmt>(S));
239   case Stmt::ForStmtClass:
240     return visitForStmt(cast<ForStmt>(S));
241   case Stmt::CXXForRangeStmtClass:
242     return visitCXXForRangeStmt(cast<CXXForRangeStmt>(S));
243   case Stmt::BreakStmtClass:
244     return visitBreakStmt(cast<BreakStmt>(S));
245   case Stmt::ContinueStmtClass:
246     return visitContinueStmt(cast<ContinueStmt>(S));
247   case Stmt::SwitchStmtClass:
248     return visitSwitchStmt(cast<SwitchStmt>(S));
249   case Stmt::CaseStmtClass:
250     return visitCaseStmt(cast<CaseStmt>(S));
251   case Stmt::DefaultStmtClass:
252     return visitDefaultStmt(cast<DefaultStmt>(S));
253   case Stmt::GCCAsmStmtClass:
254   case Stmt::MSAsmStmtClass:
255     return visitAsmStmt(cast<AsmStmt>(S));
256   case Stmt::AttributedStmtClass:
257     return visitAttributedStmt(cast<AttributedStmt>(S));
258   case Stmt::CXXTryStmtClass:
259     return visitCXXTryStmt(cast<CXXTryStmt>(S));
260   case Stmt::NullStmtClass:
261     return true;
262   default: {
263     if (auto *Exp = dyn_cast<Expr>(S))
264       return this->discard(Exp);
265     return this->bail(S);
266   }
267   }
268 }
269 
270 /// Visits the given statment without creating a variable
271 /// scope for it in case it is a compound statement.
272 template <class Emitter>
273 bool ByteCodeStmtGen<Emitter>::visitLoopBody(const Stmt *S) {
274   if (isa<NullStmt>(S))
275     return true;
276 
277   if (const auto *CS = dyn_cast<CompoundStmt>(S)) {
278     for (auto *InnerStmt : CS->body())
279       if (!visitStmt(InnerStmt))
280         return false;
281     return true;
282   }
283 
284   return this->visitStmt(S);
285 }
286 
287 template <class Emitter>
288 bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
289     const CompoundStmt *CompoundStmt) {
290   BlockScope<Emitter> Scope(this);
291   for (auto *InnerStmt : CompoundStmt->body())
292     if (!visitStmt(InnerStmt))
293       return false;
294   return true;
295 }
296 
297 template <class Emitter>
298 bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
299   for (auto *D : DS->decls()) {
300     if (isa<StaticAssertDecl, TagDecl, TypedefNameDecl>(D))
301       continue;
302 
303     const auto *VD = dyn_cast<VarDecl>(D);
304     if (!VD)
305       return false;
306     if (!this->visitVarDecl(VD))
307       return false;
308   }
309 
310   return true;
311 }
312 
/// Compiles a return statement.
///
/// With a returned expression RE there are three cases:
///  - ReturnType is set (primitive return type): RE is computed onto the
///    stack and returned with Ret.
///  - RE has void type: RE is evaluated, then control falls through to the
///    plain void return at the end of the function.
///  - Otherwise: the value is constructed in place through the RVO pointer
///    and the function returns void.
/// Without an expression, only the void return is emitted.
template <class Emitter>
bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
  if (const Expr *RE = RS->getRetValue()) {
    // Scope for temporaries created while evaluating the return value; it
    // ends at the closing brace of this 'if'.
    ExprScope<Emitter> RetScope(this);
    if (ReturnType) {
      // Primitive types are simply returned.
      if (!this->visit(RE))
        return false;
      this->emitCleanup();
      return this->emitRet(*ReturnType, RS);
    } else if (RE->getType()->isVoidType()) {
      // e.g. 'return f();' where f returns void: evaluate for side effects,
      // then fall through to the void return below.
      if (!this->visit(RE))
        return false;
    } else {
      // RVO - construct the value in the return location.
      if (!this->emitRVOPtr(RE))
        return false;
      if (!this->visitInitializer(RE))
        return false;
      if (!this->emitPopPtr(RE))
        return false;

      this->emitCleanup();
      return this->emitRetVoid(RS);
    }
  }

  // Void return.
  this->emitCleanup();
  return this->emitRetVoid(RS);
}
344 
345 template <class Emitter>
346 bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
347   BlockScope<Emitter> IfScope(this);
348 
349   if (IS->isNonNegatedConsteval())
350     return visitStmt(IS->getThen());
351   if (IS->isNegatedConsteval())
352     return IS->getElse() ? visitStmt(IS->getElse()) : true;
353 
354   if (auto *CondInit = IS->getInit())
355     if (!visitStmt(CondInit))
356       return false;
357 
358   if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
359     if (!visitDeclStmt(CondDecl))
360       return false;
361 
362   if (!this->visitBool(IS->getCond()))
363     return false;
364 
365   if (const Stmt *Else = IS->getElse()) {
366     LabelTy LabelElse = this->getLabel();
367     LabelTy LabelEnd = this->getLabel();
368     if (!this->jumpFalse(LabelElse))
369       return false;
370     if (!visitStmt(IS->getThen()))
371       return false;
372     if (!this->jump(LabelEnd))
373       return false;
374     this->emitLabel(LabelElse);
375     if (!visitStmt(Else))
376       return false;
377     this->emitLabel(LabelEnd);
378   } else {
379     LabelTy LabelEnd = this->getLabel();
380     if (!this->jumpFalse(LabelEnd))
381       return false;
382     if (!visitStmt(IS->getThen()))
383       return false;
384     this->emitLabel(LabelEnd);
385   }
386 
387   return true;
388 }
389 
390 template <class Emitter>
391 bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
392   const Expr *Cond = S->getCond();
393   const Stmt *Body = S->getBody();
394 
395   LabelTy CondLabel = this->getLabel(); // Label before the condition.
396   LabelTy EndLabel = this->getLabel();  // Label after the loop.
397   LoopScope<Emitter> LS(this, EndLabel, CondLabel);
398 
399   this->emitLabel(CondLabel);
400   if (!this->visitBool(Cond))
401     return false;
402   if (!this->jumpFalse(EndLabel))
403     return false;
404 
405   LocalScope<Emitter> Scope(this);
406   {
407     DestructorScope<Emitter> DS(Scope);
408     if (!this->visitLoopBody(Body))
409       return false;
410   }
411 
412   if (!this->jump(CondLabel))
413     return false;
414   this->emitLabel(EndLabel);
415 
416   return true;
417 }
418 
/// Compiles a do-while loop:
///
///   StartLabel: body
///   CondLabel:  cond; jump-if-true StartLabel
///   EndLabel:
///
/// 'continue' jumps to CondLabel (so the condition is still evaluated) and
/// 'break' jumps to EndLabel.
template <class Emitter>
bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
  const Expr *Cond = S->getCond();
  const Stmt *Body = S->getBody();

  LabelTy StartLabel = this->getLabel();
  LabelTy EndLabel = this->getLabel();
  LabelTy CondLabel = this->getLabel();
  LoopScope<Emitter> LS(this, EndLabel, CondLabel);
  LocalScope<Emitter> Scope(this);

  this->emitLabel(StartLabel);
  {
    // Body and condition both run inside the per-iteration scope.
    DestructorScope<Emitter> DS(Scope);

    if (!this->visitLoopBody(Body))
      return false;
    this->emitLabel(CondLabel);
    if (!this->visitBool(Cond))
      return false;
  }
  // The condition's boolean result is consumed by this conditional jump.
  if (!this->jumpTrue(StartLabel))
    return false;

  this->emitLabel(EndLabel);
  return true;
}
446 
447 template <class Emitter>
448 bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
449   // for (Init; Cond; Inc) { Body }
450   const Stmt *Init = S->getInit();
451   const Expr *Cond = S->getCond();
452   const Expr *Inc = S->getInc();
453   const Stmt *Body = S->getBody();
454 
455   LabelTy EndLabel = this->getLabel();
456   LabelTy CondLabel = this->getLabel();
457   LabelTy IncLabel = this->getLabel();
458   LoopScope<Emitter> LS(this, EndLabel, IncLabel);
459   LocalScope<Emitter> Scope(this);
460 
461   if (Init && !this->visitStmt(Init))
462     return false;
463   this->emitLabel(CondLabel);
464   if (Cond) {
465     if (!this->visitBool(Cond))
466       return false;
467     if (!this->jumpFalse(EndLabel))
468       return false;
469   }
470 
471   {
472     DestructorScope<Emitter> DS(Scope);
473 
474     if (Body && !this->visitLoopBody(Body))
475       return false;
476     this->emitLabel(IncLabel);
477     if (Inc && !this->discard(Inc))
478       return false;
479   }
480 
481   if (!this->jump(CondLabel))
482     return false;
483   this->emitLabel(EndLabel);
484   return true;
485 }
486 
/// Compiles a range-based for loop using the desugared form Sema attaches to
/// the CXXForRangeStmt:
///
///              [init]; range decl; begin decl; end decl
///   CondLabel: cond; jump-if-false EndLabel
///              loop variable decl; body
///   IncLabel:  inc; jump CondLabel
///   EndLabel:
///
/// 'continue' jumps to IncLabel and 'break' jumps to EndLabel.
template <class Emitter>
bool ByteCodeStmtGen<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *S) {
  const Stmt *Init = S->getInit();
  const Expr *Cond = S->getCond();
  const Expr *Inc = S->getInc();
  const Stmt *Body = S->getBody();
  const Stmt *BeginStmt = S->getBeginStmt();
  const Stmt *RangeStmt = S->getRangeStmt();
  const Stmt *EndStmt = S->getEndStmt();
  const VarDecl *LoopVar = S->getLoopVariable();

  LabelTy EndLabel = this->getLabel();
  LabelTy CondLabel = this->getLabel();
  LabelTy IncLabel = this->getLabel();
  LoopScope<Emitter> LS(this, EndLabel, IncLabel);

  // Emit declarations needed in the loop.
  // These run once, before the first condition check.
  if (Init && !this->visitStmt(Init))
    return false;
  if (!this->visitStmt(RangeStmt))
    return false;
  if (!this->visitStmt(BeginStmt))
    return false;
  if (!this->visitStmt(EndStmt))
    return false;

  // Now the condition as well as the loop variable assignment.
  // The loop variable is emitted after CondLabel, so it is re-initialized
  // on every iteration.
  this->emitLabel(CondLabel);
  if (!this->visitBool(Cond))
    return false;
  if (!this->jumpFalse(EndLabel))
    return false;

  if (!this->visitVarDecl(LoopVar))
    return false;

  // Body.
  LocalScope<Emitter> Scope(this);
  {
    DestructorScope<Emitter> DS(Scope);

    if (!this->visitLoopBody(Body))
      return false;
    this->emitLabel(IncLabel);
    if (!this->discard(Inc))
      return false;
  }
  if (!this->jump(CondLabel))
    return false;

  this->emitLabel(EndLabel);
  return true;
}
540 
541 template <class Emitter>
542 bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
543   if (!BreakLabel)
544     return false;
545 
546   this->VarScope->emitDestructors();
547   return this->jump(*BreakLabel);
548 }
549 
550 template <class Emitter>
551 bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
552   if (!ContinueLabel)
553     return false;
554 
555   this->VarScope->emitDestructors();
556   return this->jump(*ContinueLabel);
557 }
558 
559 template <class Emitter>
560 bool ByteCodeStmtGen<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
561   const Expr *Cond = S->getCond();
562   PrimType CondT = this->classifyPrim(Cond->getType());
563 
564   LabelTy EndLabel = this->getLabel();
565   OptLabelTy DefaultLabel = std::nullopt;
566   unsigned CondVar = this->allocateLocalPrimitive(Cond, CondT, true, false);
567 
568   if (const auto *CondInit = S->getInit())
569     if (!visitStmt(CondInit))
570       return false;
571 
572   // Initialize condition variable.
573   if (!this->visit(Cond))
574     return false;
575   if (!this->emitSetLocal(CondT, CondVar, S))
576     return false;
577 
578   CaseMap CaseLabels;
579   // Create labels and comparison ops for all case statements.
580   for (const SwitchCase *SC = S->getSwitchCaseList(); SC;
581        SC = SC->getNextSwitchCase()) {
582     if (const auto *CS = dyn_cast<CaseStmt>(SC)) {
583       // FIXME: Implement ranges.
584       if (CS->caseStmtIsGNURange())
585         return false;
586       CaseLabels[SC] = this->getLabel();
587 
588       const Expr *Value = CS->getLHS();
589       PrimType ValueT = this->classifyPrim(Value->getType());
590 
591       // Compare the case statement's value to the switch condition.
592       if (!this->emitGetLocal(CondT, CondVar, CS))
593         return false;
594       if (!this->visit(Value))
595         return false;
596 
597       // Compare and jump to the case label.
598       if (!this->emitEQ(ValueT, S))
599         return false;
600       if (!this->jumpTrue(CaseLabels[CS]))
601         return false;
602     } else {
603       assert(!DefaultLabel);
604       DefaultLabel = this->getLabel();
605     }
606   }
607 
608   // If none of the conditions above were true, fall through to the default
609   // statement or jump after the switch statement.
610   if (DefaultLabel) {
611     if (!this->jump(*DefaultLabel))
612       return false;
613   } else {
614     if (!this->jump(EndLabel))
615       return false;
616   }
617 
618   SwitchScope<Emitter> SS(this, std::move(CaseLabels), EndLabel, DefaultLabel);
619   if (!this->visitStmt(S->getBody()))
620     return false;
621   this->emitLabel(EndLabel);
622   return true;
623 }
624 
625 template <class Emitter>
626 bool ByteCodeStmtGen<Emitter>::visitCaseStmt(const CaseStmt *S) {
627   this->emitLabel(CaseLabels[S]);
628   return this->visitStmt(S->getSubStmt());
629 }
630 
631 template <class Emitter>
632 bool ByteCodeStmtGen<Emitter>::visitDefaultStmt(const DefaultStmt *S) {
633   this->emitLabel(*DefaultLabel);
634   return this->visitStmt(S->getSubStmt());
635 }
636 
637 template <class Emitter>
638 bool ByteCodeStmtGen<Emitter>::visitAsmStmt(const AsmStmt *S) {
639   return this->emitInvalid(S);
640 }
641 
642 template <class Emitter>
643 bool ByteCodeStmtGen<Emitter>::visitAttributedStmt(const AttributedStmt *S) {
644   // Ignore all attributes.
645   return this->visitStmt(S->getSubStmt());
646 }
647 
648 template <class Emitter>
649 bool ByteCodeStmtGen<Emitter>::visitCXXTryStmt(const CXXTryStmt *S) {
650   // Ignore all handlers.
651   return this->visitStmt(S->getTryBlock());
652 }
653 
namespace clang {
namespace interp {

// Explicit instantiation for the bytecode emitter: the member templates are
// defined in this .cpp file, so they are instantiated here.
template class ByteCodeStmtGen<ByteCodeEmitter>;

} // namespace interp
} // namespace clang
661