1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  This file implements semantic analysis for C++ Coroutines.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Sema/SemaInternal.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/ExprCXX.h"
17 #include "clang/AST/StmtCXX.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Sema/Initialization.h"
20 #include "clang/Sema/Overload.h"
21 using namespace clang;
22 using namespace sema;
23 
24 /// Look up the std::coroutine_traits<...>::promise_type for the given
25 /// function type.
26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
27                                   SourceLocation Loc) {
28   // FIXME: Cache std::coroutine_traits once we've found it.
29   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
30   if (!StdExp) {
31     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
32     return QualType();
33   }
34 
35   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
36                       Loc, Sema::LookupOrdinaryName);
37   if (!S.LookupQualifiedName(Result, StdExp)) {
38     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
39     return QualType();
40   }
41 
42   ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
43   if (!CoroTraits) {
44     Result.suppressDiagnostics();
45     // We found something weird. Complain about the first thing we found.
46     NamedDecl *Found = *Result.begin();
47     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
48     return QualType();
49   }
50 
51   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
52   TemplateArgumentListInfo Args(Loc, Loc);
53   Args.addArgument(TemplateArgumentLoc(
54       TemplateArgument(FnType->getReturnType()),
55       S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
56   // FIXME: If the function is a non-static member function, add the type
57   // of the implicit object parameter before the formal parameters.
58   for (QualType T : FnType->getParamTypes())
59     Args.addArgument(TemplateArgumentLoc(
60         TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
61 
62   // Build the template-id.
63   QualType CoroTrait =
64       S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
65   if (CoroTrait.isNull())
66     return QualType();
67   if (S.RequireCompleteType(Loc, CoroTrait,
68                             diag::err_coroutine_traits_missing_specialization))
69     return QualType();
70 
71   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
72   assert(RD && "specialization of class template is not a class?");
73 
74   // Look up the ::promise_type member.
75   LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
76                  Sema::LookupOrdinaryName);
77   S.LookupQualifiedName(R, RD);
78   auto *Promise = R.getAsSingle<TypeDecl>();
79   if (!Promise) {
80     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
81       << RD;
82     return QualType();
83   }
84 
85   // The promise type is required to be a class type.
86   QualType PromiseType = S.Context.getTypeDeclType(Promise);
87   if (!PromiseType->getAsCXXRecordDecl()) {
88     // Use the fully-qualified name of the type.
89     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
90     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
91                                       CoroTrait.getTypePtr());
92     PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
93 
94     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
95       << PromiseType;
96     return QualType();
97   }
98 
99   return PromiseType;
100 }
101 
102 /// Check that this is a context in which a coroutine suspension can appear.
103 static FunctionScopeInfo *
104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
105   // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
106   if (S.isUnevaluatedContext()) {
107     S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
108     return nullptr;
109   }
110 
111   // Any other usage must be within a function.
112   // FIXME: Reject a coroutine with a deduced return type.
113   auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
114   if (!FD) {
115     S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
116                     ? diag::err_coroutine_objc_method
117                     : diag::err_coroutine_outside_function) << Keyword;
118   } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
119     // Coroutines TS [special]/6:
120     //   A special member function shall not be a coroutine.
121     //
122     // FIXME: We assume that this really means that a coroutine cannot
123     //        be a constructor or destructor.
124     S.Diag(Loc, diag::err_coroutine_ctor_dtor)
125       << isa<CXXDestructorDecl>(FD) << Keyword;
126   } else if (FD->isConstexpr()) {
127     S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
128   } else if (FD->isVariadic()) {
129     S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
130   } else if (FD->isMain()) {
131     S.Diag(FD->getLocStart(), diag::err_coroutine_main);
132     S.Diag(Loc, diag::note_declared_coroutine_here)
133       << (Keyword == "co_await" ? 0 :
134           Keyword == "co_yield" ? 1 : 2);
135   } else {
136     auto *ScopeInfo = S.getCurFunction();
137     assert(ScopeInfo && "missing function scope for function");
138 
139     // If we don't have a promise variable, build one now.
140     if (!ScopeInfo->CoroutinePromise) {
141       QualType T =
142           FD->getType()->isDependentType()
143               ? S.Context.DependentTy
144               : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
145                                   Loc);
146       if (T.isNull())
147         return nullptr;
148 
149       // Create and default-initialize the promise.
150       ScopeInfo->CoroutinePromise =
151           VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
152                           &S.PP.getIdentifierTable().get("__promise"), T,
153                           S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
154       S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
155       if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
156         S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
157     }
158 
159     return ScopeInfo;
160   }
161 
162   return nullptr;
163 }
164 
165 /// Build a call to 'operator co_await' if there is a suitable operator for
166 /// the given expression.
167 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
168                                            SourceLocation Loc, Expr *E) {
169   UnresolvedSet<16> Functions;
170   SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
171                                        Functions);
172   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
173 }
174 
/// Result of building the three awaiter member calls for one operand.
struct ReadySuspendResumeResult {
  /// True if building any of the calls failed; the builder starts from true
  /// and only clears it once all three calls have been formed.
  bool IsInvalid;
  /// The await_ready, await_suspend and await_resume calls, in that order.
  /// Slots for calls that were not built are left null.
  Expr *Results[3];
};
179 
180 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
181                                   StringRef Name,
182                                   MutableArrayRef<Expr *> Args) {
183   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
184 
185   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
186   CXXScopeSpec SS;
187   ExprResult Result = S.BuildMemberReferenceExpr(
188       Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
189       SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
190       /*Scope=*/nullptr);
191   if (Result.isInvalid())
192     return ExprError();
193 
194   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
195 }
196 
197 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
198 /// expression.
199 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
200                                                   Expr *E) {
201   // Assume invalid until we see otherwise.
202   ReadySuspendResumeResult Calls = {true, {}};
203 
204   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
205   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
206     Expr *Operand = new (S.Context) OpaqueValueExpr(
207         Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
208 
209     // FIXME: Pass coroutine handle to await_suspend.
210     ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
211     if (Result.isInvalid())
212       return Calls;
213     Calls.Results[I] = Result.get();
214   }
215 
216   Calls.IsInvalid = false;
217   return Calls;
218 }
219 
220 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
221   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
222   if (!Coroutine) {
223     CorrectDelayedTyposInExpr(E);
224     return ExprError();
225   }
226   if (E->getType()->isPlaceholderType()) {
227     ExprResult R = CheckPlaceholderExpr(E);
228     if (R.isInvalid()) return ExprError();
229     E = R.get();
230   }
231 
232   ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
233   if (Awaitable.isInvalid())
234     return ExprError();
235 
236   return BuildCoawaitExpr(Loc, Awaitable.get());
237 }
238 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
239   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
240   if (!Coroutine)
241     return ExprError();
242 
243   if (E->getType()->isPlaceholderType()) {
244     ExprResult R = CheckPlaceholderExpr(E);
245     if (R.isInvalid()) return ExprError();
246     E = R.get();
247   }
248 
249   if (E->getType()->isDependentType()) {
250     Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
251     Coroutine->CoroutineStmts.push_back(Res);
252     return Res;
253   }
254 
255   // If the expression is a temporary, materialize it as an lvalue so that we
256   // can use it multiple times.
257   if (E->getValueKind() == VK_RValue)
258     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
259 
260   // Build the await_ready, await_suspend, await_resume calls.
261   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
262   if (RSS.IsInvalid)
263     return ExprError();
264 
265   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
266                                         RSS.Results[2]);
267   Coroutine->CoroutineStmts.push_back(Res);
268   return Res;
269 }
270 
271 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
272                                    SourceLocation Loc, StringRef Name,
273                                    MutableArrayRef<Expr *> Args) {
274   assert(Coroutine->CoroutinePromise && "no promise for coroutine");
275 
276   // Form a reference to the promise.
277   auto *Promise = Coroutine->CoroutinePromise;
278   ExprResult PromiseRef = S.BuildDeclRefExpr(
279       Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
280   if (PromiseRef.isInvalid())
281     return ExprError();
282 
283   // Call 'yield_value', passing in E.
284   return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
285 }
286 
287 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
288   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
289   if (!Coroutine) {
290     CorrectDelayedTyposInExpr(E);
291     return ExprError();
292   }
293 
294   // Build yield_value call.
295   ExprResult Awaitable =
296       buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
297   if (Awaitable.isInvalid())
298     return ExprError();
299 
300   // Build 'operator co_await' call.
301   Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
302   if (Awaitable.isInvalid())
303     return ExprError();
304 
305   return BuildCoyieldExpr(Loc, Awaitable.get());
306 }
307 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
308   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
309   if (!Coroutine)
310     return ExprError();
311 
312   if (E->getType()->isPlaceholderType()) {
313     ExprResult R = CheckPlaceholderExpr(E);
314     if (R.isInvalid()) return ExprError();
315     E = R.get();
316   }
317 
318   if (E->getType()->isDependentType()) {
319     Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
320     Coroutine->CoroutineStmts.push_back(Res);
321     return Res;
322   }
323 
324   // If the expression is a temporary, materialize it as an lvalue so that we
325   // can use it multiple times.
326   if (E->getValueKind() == VK_RValue)
327     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
328 
329   // Build the await_ready, await_suspend, await_resume calls.
330   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
331   if (RSS.IsInvalid)
332     return ExprError();
333 
334   Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
335                                         RSS.Results[2]);
336   Coroutine->CoroutineStmts.push_back(Res);
337   return Res;
338 }
339 
340 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
341   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
342   if (!Coroutine) {
343     CorrectDelayedTyposInExpr(E);
344     return StmtError();
345   }
346   return BuildCoreturnStmt(Loc, E);
347 }
348 
349 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
350   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
351   if (!Coroutine)
352     return StmtError();
353 
354   if (E && E->getType()->isPlaceholderType() &&
355       !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
356     ExprResult R = CheckPlaceholderExpr(E);
357     if (R.isInvalid()) return StmtError();
358     E = R.get();
359   }
360 
361   // FIXME: If the operand is a reference to a variable that's about to go out
362   // of scope, we should treat the operand as an xvalue for this overload
363   // resolution.
364   ExprResult PC;
365   if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
366     PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
367   } else {
368     E = MakeFullDiscardedValueExpr(E).get();
369     PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
370   }
371   if (PC.isInvalid())
372     return StmtError();
373 
374   Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
375 
376   Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
377   Coroutine->CoroutineStmts.push_back(Res);
378   return Res;
379 }
380 
/// Finish semantic analysis of a function body that contains coroutine
/// statements and wrap it in a CoroutineBodyStmt.
///
/// Diagnoses 'return' statements in the coroutine, notes (as an extension
/// diagnostic) a coroutine that has neither co_await nor co_yield, and then
/// builds the implicit pieces: the promise declaration statement, the
/// 'co_await p.initial_suspend()' and 'co_await p.final_suspend()'
/// expressions, and the 'p.get_return_object()' return value.
///
/// \param FD   the coroutine's function declaration; marked invalid if any
///             implicit piece cannot be built.
/// \param Body in/out: replaced with the wrapping CoroutineBodyStmt on
///             success, left untouched on failure.
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
  FunctionScopeInfo *Fn = getCurFunction();
  assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");

  // Coroutines [stmt.return]p1:
  //   A return statement shall not appear in a coroutine.
  if (Fn->FirstReturnLoc.isValid()) {
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    // Point at the first coroutine statement to explain why this function is
    // a coroutine; the note selects 0 for co_await, 1 for co_yield and 2
    // otherwise.
    auto *First = Fn->CoroutineStmts[0];
    Diag(First->getLocStart(), diag::note_declared_coroutine_here)
      << (isa<CoawaitExpr>(First) ? 0 :
          isa<CoyieldExpr>(First) ? 1 : 2);
  }

  // Scan the recorded statements before any implicit co_await expressions
  // are built below (building them appends to this same list).
  bool AnyCoawaits = false;
  bool AnyCoyields = false;
  for (auto *CoroutineStmt : Fn->CoroutineStmts) {
    AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
    AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
  }

  // Neither co_await nor co_yield was recorded: report the extension
  // diagnostic at the first coroutine statement.
  if (!AnyCoawaits && !AnyCoyields)
    Diag(Fn->CoroutineStmts.front()->getLocStart(),
         diag::ext_coroutine_without_co_await_co_yield);

  // All implicit constructs are located at the function's own location.
  SourceLocation Loc = FD->getLocation();

  // Form a declaration statement for the promise declaration, so that AST
  // visitors can more easily find it.
  StmtResult PromiseStmt =
      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
  if (PromiseStmt.isInvalid())
    return FD->setInvalidDecl();

  // Form and check implicit 'co_await p.initial_suspend();' statement.
  ExprResult InitialSuspend =
      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
  // FIXME: Support operator co_await here.
  if (!InitialSuspend.isInvalid())
    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
  // The validity check below covers both the promise call and the co_await
  // built from it.
  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
  if (InitialSuspend.isInvalid())
    return FD->setInvalidDecl();

  // Form and check implicit 'co_await p.final_suspend();' statement.
  ExprResult FinalSuspend =
      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
  // FIXME: Support operator co_await here.
  if (!FinalSuspend.isInvalid())
    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
  if (FinalSuspend.isInvalid())
    return FD->setInvalidDecl();

  // FIXME: Perform analysis of set_exception call.

  // FIXME: Try to form 'p.return_void();' expression statement to handle
  // control flowing off the end of the coroutine.

  // Build implicit 'p.get_return_object()' expression and form initialization
  // of return type from it.
  ExprResult ReturnObject =
    buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
  if (ReturnObject.isInvalid())
    return FD->setInvalidDecl();
  QualType RetType = FD->getReturnType();
  // The initialization is only performed for a non-dependent return type.
  if (!RetType->isDependentType()) {
    InitializedEntity Entity =
        InitializedEntity::InitializeResult(Loc, RetType, false);
    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
                                                   ReturnObject.get());
    if (ReturnObject.isInvalid())
      return FD->setInvalidDecl();
  }
  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
  if (ReturnObject.isInvalid())
    return FD->setInvalidDecl();

  // FIXME: Perform move-initialization of parameters into frame-local copies.
  // Until that is implemented the list of parameter moves is always empty.
  SmallVector<Expr*, 16> ParamMoves;

  // Build body for the coroutine wrapper statement.
  Body = new (Context) CoroutineBodyStmt(
      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
      /*SetException*/nullptr, /*Fallthrough*/nullptr,
      ReturnObject.get(), ParamMoves);
}
468