1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  This file implements semantic analysis for C++ Coroutines.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Sema/SemaInternal.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/ExprCXX.h"
17 #include "clang/AST/StmtCXX.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Sema/Overload.h"
20 using namespace clang;
21 using namespace sema;
22 
23 /// Look up the std::coroutine_traits<...>::promise_type for the given
24 /// function type.
25 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
26                                   SourceLocation Loc) {
27   // FIXME: Cache std::coroutine_traits once we've found it.
28   NamespaceDecl *Std = S.getStdNamespace();
29   if (!Std) {
30     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
31     return QualType();
32   }
33 
34   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
35                       Loc, Sema::LookupOrdinaryName);
36   if (!S.LookupQualifiedName(Result, Std)) {
37     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
38     return QualType();
39   }
40 
41   ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
42   if (!CoroTraits) {
43     Result.suppressDiagnostics();
44     // We found something weird. Complain about the first thing we found.
45     NamedDecl *Found = *Result.begin();
46     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
47     return QualType();
48   }
49 
50   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
51   TemplateArgumentListInfo Args(Loc, Loc);
52   Args.addArgument(TemplateArgumentLoc(
53       TemplateArgument(FnType->getReturnType()),
54       S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
55   for (QualType T : FnType->getParamTypes())
56     Args.addArgument(TemplateArgumentLoc(
57         TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
58 
59   // Build the template-id.
60   QualType CoroTrait =
61       S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
62   if (CoroTrait.isNull())
63     return QualType();
64   if (S.RequireCompleteType(Loc, CoroTrait,
65                             diag::err_coroutine_traits_missing_specialization))
66     return QualType();
67 
68   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
69   assert(RD && "specialization of class template is not a class?");
70 
71   // Look up the ::promise_type member.
72   LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
73                  Sema::LookupOrdinaryName);
74   S.LookupQualifiedName(R, RD);
75   auto *Promise = R.getAsSingle<TypeDecl>();
76   if (!Promise) {
77     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
78       << RD;
79     return QualType();
80   }
81 
82   // The promise type is required to be a class type.
83   QualType PromiseType = S.Context.getTypeDeclType(Promise);
84   if (!PromiseType->getAsCXXRecordDecl()) {
85     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
86       << PromiseType;
87     return QualType();
88   }
89 
90   return PromiseType;
91 }
92 
93 /// Check that this is a context in which a coroutine suspension can appear.
94 static FunctionScopeInfo *
95 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
96   // 'co_await' and 'co_yield' are permitted in unevaluated operands.
97   // FIXME: Not in 'noexcept'.
98   if (S.isUnevaluatedContext())
99     return nullptr;
100 
101   // Any other usage must be within a function.
102   auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
103   if (!FD) {
104     S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
105                     ? diag::err_coroutine_objc_method
106                     : diag::err_coroutine_outside_function) << Keyword;
107   } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
108     // Coroutines TS [special]/6:
109     //   A special member function shall not be a coroutine.
110     //
111     // FIXME: We assume that this really means that a coroutine cannot
112     //        be a constructor or destructor.
113     S.Diag(Loc, diag::err_coroutine_ctor_dtor)
114       << isa<CXXDestructorDecl>(FD) << Keyword;
115   } else if (FD->isConstexpr()) {
116     S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
117   } else if (FD->isVariadic()) {
118     S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
119   } else {
120     auto *ScopeInfo = S.getCurFunction();
121     assert(ScopeInfo && "missing function scope for function");
122 
123     // If we don't have a promise variable, build one now.
124     if (!ScopeInfo->CoroutinePromise && !FD->getType()->isDependentType()) {
125       QualType T =
126           lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(), Loc);
127       if (T.isNull())
128         return nullptr;
129 
130       // Create and default-initialize the promise.
131       ScopeInfo->CoroutinePromise =
132           VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
133                           &S.PP.getIdentifierTable().get("__promise"), T,
134                           S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
135       S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
136       if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
137         S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
138     }
139 
140     return ScopeInfo;
141   }
142 
143   return nullptr;
144 }
145 
146 /// Build a call to 'operator co_await' if there is a suitable operator for
147 /// the given expression.
148 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
149                                            SourceLocation Loc, Expr *E) {
150   UnresolvedSet<16> Functions;
151   SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
152                                        Functions);
153   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
154 }
155 
/// Outcome of building the await_ready / await_suspend / await_resume member
/// calls for a co_await expression (produced by buildCoawaitCalls).
///
/// NOTE: this is aggregate-initialized positionally ({IsInvalid, {Results}}),
/// so member order matters.
struct ReadySuspendResumeResult {
  /// True if any of the three calls could not be built.
  bool IsInvalid;
  /// The built calls, in order: await_ready, await_suspend, await_resume.
  /// May be null or only partially filled when IsInvalid is true.
  Expr *Results[3];
};
160 
161 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
162 /// expression.
163 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
164                                                   Expr *E) {
165   // Assume invalid until we see otherwise.
166   ReadySuspendResumeResult Calls = {true, {}};
167 
168   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
169   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
170     DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Funcs[I]), Loc);
171 
172     Expr *Operand = new (S.Context) OpaqueValueExpr(
173         Loc, E->getType(), E->getValueKind(), E->getObjectKind(), E);
174 
175     // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
176     CXXScopeSpec SS;
177     ExprResult Result = S.BuildMemberReferenceExpr(
178         Operand, Operand->getType(), Loc, /*IsPtr=*/false, SS,
179         SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
180         /*Scope=*/nullptr);
181     if (Result.isInvalid())
182       return Calls;
183 
184     // FIXME: Pass coroutine handle to await_suspend.
185     Result = S.ActOnCallExpr(nullptr, Result.get(), Loc, None, Loc, nullptr);
186     if (Result.isInvalid())
187       return Calls;
188     Calls.Results[I] = Result.get();
189   }
190 
191   Calls.IsInvalid = false;
192   return Calls;
193 }
194 
195 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
196   ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
197   if (Awaitable.isInvalid())
198     return ExprError();
199   return BuildCoawaitExpr(Loc, Awaitable.get());
200 }
201 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
202   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
203 
204   if (E->getType()->isDependentType()) {
205     Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
206     if (Coroutine)
207       Coroutine->CoroutineStmts.push_back(Res);
208     return Res;
209   }
210 
211   if (E->getType()->isPlaceholderType()) {
212     ExprResult R = CheckPlaceholderExpr(E);
213     if (R.isInvalid()) return ExprError();
214     E = R.get();
215   }
216 
217   // FIXME: If E is a prvalue, create a temporary.
218   // FIXME: If E is an xvalue, convert to lvalue.
219 
220   // Build the await_ready, await_suspend, await_resume calls.
221   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
222   if (RSS.IsInvalid)
223     return ExprError();
224 
225   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
226                                         RSS.Results[2]);
227   if (Coroutine)
228     Coroutine->CoroutineStmts.push_back(Res);
229   return Res;
230 }
231 
232 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
233   // FIXME: Build yield_value call.
234   ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
235   if (Awaitable.isInvalid())
236     return ExprError();
237   return BuildCoyieldExpr(Loc, Awaitable.get());
238 }
239 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
240   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
241 
242   // FIXME: Build await_* calls.
243   Expr *Res = new (Context) CoyieldExpr(Loc, Context.VoidTy, E);
244   if (Coroutine)
245     Coroutine->CoroutineStmts.push_back(Res);
246   return Res;
247 }
248 
249 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
250   return BuildCoreturnStmt(Loc, E);
251 }
252 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
253   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
254 
255   // FIXME: Build return_* calls.
256   Stmt *Res = new (Context) CoreturnStmt(Loc, E);
257   if (Coroutine)
258     Coroutine->CoroutineStmts.push_back(Res);
259   return Res;
260 }
261 
262 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *Body) {
263   FunctionScopeInfo *Fn = getCurFunction();
264   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
265 
266   // Coroutines [stmt.return]p1:
267   //   A return statement shall not appear in a coroutine.
268   if (Fn->FirstReturnLoc.isValid()) {
269     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
270     auto *First = Fn->CoroutineStmts[0];
271     Diag(First->getLocStart(), diag::note_declared_coroutine_here)
272       << (isa<CoawaitExpr>(First) ? 0 :
273           isa<CoyieldExpr>(First) ? 1 : 2);
274   }
275 
276   bool AnyCoawaits = false;
277   bool AnyCoyields = false;
278   for (auto *CoroutineStmt : Fn->CoroutineStmts) {
279     AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
280     AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
281   }
282 
283   if (!AnyCoawaits && !AnyCoyields)
284     Diag(Fn->CoroutineStmts.front()->getLocStart(),
285          diag::ext_coroutine_without_co_await_co_yield);
286 
287   // FIXME: Perform analysis of initial and final suspend,
288   // and set_exception call.
289 }
290