1 //===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file implements semantic analysis for C++ Coroutines.
10 //
11 //  This file contains references to sections of the Coroutines TS, which
12 //  can be found at http://wg21.link/coroutines.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "CoroutineStmtBuilder.h"
17 #include "clang/AST/ASTLambda.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/StmtCXX.h"
21 #include "clang/Basic/Builtins.h"
22 #include "clang/Lex/Preprocessor.h"
23 #include "clang/Sema/Initialization.h"
24 #include "clang/Sema/Overload.h"
25 #include "clang/Sema/ScopeInfo.h"
26 #include "clang/Sema/SemaInternal.h"
27 
28 using namespace clang;
29 using namespace sema;
30 
31 static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
32                                  SourceLocation Loc, bool &Res) {
33   DeclarationName DN = S.PP.getIdentifierInfo(Name);
34   LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
35   // Suppress diagnostics when a private member is selected. The same warnings
36   // will be produced again when building the call.
37   LR.suppressDiagnostics();
38   Res = S.LookupQualifiedName(LR, RD);
39   return LR;
40 }
41 
42 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
43                          SourceLocation Loc) {
44   bool Res;
45   lookupMember(S, Name, RD, Loc, Res);
46   return Res;
47 }
48 
/// Look up coroutine_traits<R, P1, ..., Pn>::promise_type for the given
/// function, following [dcl.fct.def.coroutine]p3.
///
/// The traits template is qualified by std::experimental: that namespace must
/// already be declared, and the fully qualified spelling is what diagnostics
/// about the promise type print.
///
/// \param S the Sema instance used for lookup and diagnostics.
/// \param FD the function that is being turned into a coroutine.
/// \param KwLoc location of the coroutine keyword; used for most diagnostics
///        and as the location of the synthesized template arguments.
/// \returns the (complete, class) promise type, or a null QualType on
///          failure. The null returns after lookupCoroutineTraits and
///          CheckTemplateIdType emit nothing here and presumably rely on
///          those callees having diagnosed (TODO confirm); every other
///          failure path diagnoses before returning.
static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
                                  SourceLocation KwLoc) {
  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
  const SourceLocation FuncLoc = FD->getLocation();
  // FIXME: Cache std::coroutine_traits once we've found it.
  // The promise type is spelled relative to std::experimental further down,
  // so give a targeted diagnostic if that namespace was never declared.
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
  if (!StdExp) {
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
        << "std::experimental::coroutine_traits";
    return QualType();
  }

  ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
  if (!CoroTraits) {
    return QualType();
  }

  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
  // to [dcl.fct.def.coroutine]3
  TemplateArgumentListInfo Args(KwLoc, KwLoc);
  // Append T as a type template argument, located at the coroutine keyword.
  auto AddArg = [&](QualType T) {
    Args.addArgument(TemplateArgumentLoc(
        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
  };
  AddArg(FnType->getReturnType());
  // If the function is a non-static member function, add the type
  // of the implicit object parameter before the formal parameters.
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
    if (MD->isInstance()) {
      // [over.match.funcs]4
      // For non-static member functions, the type of the implicit object
      // parameter is
      //  -- "lvalue reference to cv X" for functions declared without a
      //      ref-qualifier or with the & ref-qualifier
      //  -- "rvalue reference to cv X" for functions declared with the &&
      //      ref-qualifier
      QualType T = MD->getThisType()->castAs<PointerType>()->getPointeeType();
      T = FnType->getRefQualifier() == RQ_RValue
              ? S.Context.getRValueReferenceType(T)
              : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true);
      AddArg(T);
    }
  }
  for (QualType T : FnType->getParamTypes())
    AddArg(T);

  // Build the template-id.
  QualType CoroTrait =
      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
  if (CoroTrait.isNull())
    return QualType();
  // The specialization must be complete for its members to be looked up.
  if (S.RequireCompleteType(KwLoc, CoroTrait,
                            diag::err_coroutine_type_missing_specialization))
    return QualType();

  auto *RD = CoroTrait->getAsCXXRecordDecl();
  assert(RD && "specialization of class template is not a class?");

  // Look up the ::promise_type member. It must resolve to exactly one type.
  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
                 Sema::LookupOrdinaryName);
  S.LookupQualifiedName(R, RD);
  auto *Promise = R.getAsSingle<TypeDecl>();
  if (!Promise) {
    S.Diag(FuncLoc,
           diag::err_implied_std_coroutine_traits_promise_type_not_found)
        << RD;
    return QualType();
  }
  // The promise type is required to be a class type.
  QualType PromiseType = S.Context.getTypeDeclType(Promise);

  // Spell the promise type as
  // std::experimental::coroutine_traits<...>::promise_type so diagnostics
  // show where the type came from.
  auto buildElaboratedType = [&]() {
    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
                                      CoroTrait.getTypePtr());
    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
  };

  if (!PromiseType->getAsCXXRecordDecl()) {
    S.Diag(FuncLoc,
           diag::err_implied_std_coroutine_traits_promise_type_not_class)
        << buildElaboratedType();
    return QualType();
  }
  // The promise object is declared in the coroutine frame, so it must be a
  // complete type; diagnose against the qualified spelling.
  if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
                            diag::err_coroutine_promise_type_incomplete))
    return QualType();

  return PromiseType;
}
142 
/// Look up the std::experimental::coroutine_handle<PromiseType>.
///
/// \param PromiseType the coroutine's promise type. A null type (an earlier
///        failure) yields a null result without emitting a new diagnostic.
/// \param Loc location used for diagnostics and the synthesized template
///        argument.
/// \returns the complete specialization coroutine_handle<PromiseType>, or a
///          null QualType on failure.
static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
                                          SourceLocation Loc) {
  if (PromiseType.isNull())
    return QualType();

  // Presumably a valid promise type means lookupPromiseType already found
  // (or diagnosed the absence of) std::experimental.
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
  assert(StdExp && "Should already be diagnosed");

  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
                      Loc, Sema::LookupOrdinaryName);
  if (!S.LookupQualifiedName(Result, StdExp)) {
    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
        << "std::experimental::coroutine_handle";
    return QualType();
  }

  // The name must denote a single class template.
  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
  if (!CoroHandle) {
    Result.suppressDiagnostics();
    // We found something weird. Complain about the first thing we found.
    NamedDecl *Found = *Result.begin();
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
    return QualType();
  }

  // Form template argument list for coroutine_handle<Promise>.
  TemplateArgumentListInfo Args(Loc, Loc);
  Args.addArgument(TemplateArgumentLoc(
      TemplateArgument(PromiseType),
      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));

  // Build the template-id.
  QualType CoroHandleType =
      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
  if (CoroHandleType.isNull())
    return QualType();
  // Completeness is needed so that members such as from_address can be found.
  if (S.RequireCompleteType(Loc, CoroHandleType,
                            diag::err_coroutine_type_missing_specialization))
    return QualType();

  return CoroHandleType;
}
186 
/// Check whether the current declaration context (S.CurContext) may contain
/// the coroutine keyword \p Keyword at \p Loc.
///
/// Outside of a FunctionDecl (including Objective-C methods), one diagnostic
/// is emitted and false is returned. Constructors, destructors, and main are
/// rejected after a single diagnostic. For the remaining rules (constexpr /
/// consteval, undeduced return type, C varargs) every violated rule is
/// diagnosed before returning.
///
/// \returns true if no diagnostic was emitted.
static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
                                    StringRef Keyword) {
  // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
  // a function body.
  // FIXME: This also covers [expr.await]p2: "An await-expression shall not
  // appear in a default argument." But the diagnostic QoI here could be
  // improved to inform the user that default arguments specifically are not
  // allowed.
  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
  if (!FD) {
    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
                    ? diag::err_coroutine_objc_method
                    : diag::err_coroutine_outside_function) << Keyword;
    return false;
  }

  // An enumeration for mapping the diagnostic type to the correct diagnostic
  // selection index.
  // NOTE(review): the order presumably has to match the %select in
  // err_coroutine_invalid_func_context; confirm before reordering.
  enum InvalidFuncDiag {
    DiagCtor = 0,
    DiagDtor,
    DiagMain,
    DiagConstexpr,
    DiagAutoRet,
    DiagVarargs,
    DiagConsteval,
  };
  bool Diagnosed = false;
  // Emit the diagnostic for ID and remember that something was diagnosed.
  // Always returns false, so it can be used directly in a return statement.
  auto DiagInvalid = [&](InvalidFuncDiag ID) {
    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
    Diagnosed = true;
    return false;
  };

  // Diagnose when a constructor, destructor
  // or the function 'main' are declared as a coroutine.
  auto *MD = dyn_cast<CXXMethodDecl>(FD);
  // [class.ctor]p11: "A constructor shall not be a coroutine."
  if (MD && isa<CXXConstructorDecl>(MD))
    return DiagInvalid(DiagCtor);
  // [class.dtor]p17: "A destructor shall not be a coroutine."
  else if (MD && isa<CXXDestructorDecl>(MD))
    return DiagInvalid(DiagDtor);
  // [basic.start.main]p3: "The function main shall not be a coroutine."
  else if (FD->isMain())
    return DiagInvalid(DiagMain);

  // Emit a diagnostics for each of the following conditions which is not met.
  // [expr.const]p2: "An expression e is a core constant expression unless the
  // evaluation of e [...] would evaluate one of the following expressions:
  // [...] an await-expression [...] a yield-expression."
  if (FD->isConstexpr())
    DiagInvalid(FD->isConsteval() ? DiagConsteval : DiagConstexpr);
  // [dcl.spec.auto]p15: "A function declared with a return type that uses a
  // placeholder type shall not be a coroutine."
  if (FD->getReturnType()->isUndeducedType())
    DiagInvalid(DiagAutoRet);
  // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
  // coroutine shall not terminate with an ellipsis that is not part of a
  // parameter-declaration."
  if (FD->isVariadic())
    DiagInvalid(DiagVarargs);

  return !Diagnosed;
}
252 
253 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
254                                                  SourceLocation Loc) {
255   DeclarationName OpName =
256       SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
257   LookupResult Operators(SemaRef, OpName, SourceLocation(),
258                          Sema::LookupOperatorName);
259   SemaRef.LookupName(Operators, S);
260 
261   assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
262   const auto &Functions = Operators.asUnresolvedSet();
263   bool IsOverloaded =
264       Functions.size() > 1 ||
265       (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
266   Expr *CoawaitOp = UnresolvedLookupExpr::Create(
267       SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
268       DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
269       Functions.begin(), Functions.end());
270   assert(CoawaitOp);
271   return CoawaitOp;
272 }
273 
274 /// Build a call to 'operator co_await' if there is a suitable operator for
275 /// the given expression.
276 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
277                                            Expr *E,
278                                            UnresolvedLookupExpr *Lookup) {
279   UnresolvedSet<16> Functions;
280   Functions.append(Lookup->decls_begin(), Lookup->decls_end());
281   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
282 }
283 
284 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
285                                            SourceLocation Loc, Expr *E) {
286   ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
287   if (R.isInvalid())
288     return ExprError();
289   return buildOperatorCoawaitCall(SemaRef, Loc, E,
290                                   cast<UnresolvedLookupExpr>(R.get()));
291 }
292 
293 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
294                               MultiExprArg CallArgs) {
295   StringRef Name = S.Context.BuiltinInfo.getName(Id);
296   LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
297   S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
298 
299   auto *BuiltInDecl = R.getAsSingle<FunctionDecl>();
300   assert(BuiltInDecl && "failed to find builtin declaration");
301 
302   ExprResult DeclRef =
303       S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc);
304   assert(DeclRef.isUsable() && "Builtin reference cannot fail");
305 
306   ExprResult Call =
307       S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
308 
309   assert(!Call.isInvalid() && "Call to builtin cannot fail!");
310   return Call.get();
311 }
312 
/// Build the expression
/// `coroutine_handle<PromiseType>::from_address(__builtin_coro_frame())`,
/// i.e. a handle to the current coroutine.
///
/// \returns ExprError() if the handle type cannot be formed, if it has no
///          member named from_address (diagnosed here), or if referencing or
///          calling that member fails.
static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
                                       SourceLocation Loc) {
  QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
  if (CoroHandleType.isNull())
    return ExprError();

  DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType);
  LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc,
                     Sema::LookupOrdinaryName);
  if (!S.LookupQualifiedName(Found, LookupCtx)) {
    S.Diag(Loc, diag::err_coroutine_handle_missing_member)
        << "from_address";
    return ExprError();
  }

  // Address of the current coroutine frame, passed to from_address.
  Expr *FramePtr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  CXXScopeSpec SS;
  ExprResult FromAddr =
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
  if (FromAddr.isInvalid())
    return ExprError();

  return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc);
}
339 
/// Result of buildCoawaitCalls: the three awaiter calls that make up a
/// co_await expression, plus the operand they are all invoked on.
struct ReadySuspendResumeResult {
  /// Index of each awaiter call within Results.
  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
  /// e.await_ready(), e.await_suspend(h), e.await_resume(), indexed by
  /// AwaitCallType. Entries may be null when IsInvalid is set.
  Expr *Results[3];
  /// OpaqueValueExpr wrapping the awaiter expression shared by the calls.
  OpaqueValueExpr *OpaqueValue;
  /// True if any call could not be built or failed its type requirements.
  bool IsInvalid;
};
346 
/// Build the call `Base.Name(Args...)` (Base is an object, not a pointer).
///
/// \returns ExprError() if the member reference or call cannot be formed.
///          If member lookup yields a TypoExpr, the pending typo correction
///          is dropped and err_no_member is emitted instead, because the
///          caller asked for this exact name.
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
                                  StringRef Name, MultiExprArg Args) {
  DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);

  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
  CXXScopeSpec SS;
  ExprResult Result = S.BuildMemberReferenceExpr(
      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
      SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
      /*Scope=*/nullptr);
  if (Result.isInvalid())
    return ExprError();

  // We meant exactly what we asked for. No need for typo correction.
  if (auto *TE = dyn_cast<TypoExpr>(Result.get())) {
    S.clearDelayedTypo(TE);
    S.Diag(Loc, diag::err_no_member)
        << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl()
        << Base->getSourceRange();
    return ExprError();
  }

  return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
}
371 
// See if return type is coroutine-handle and if so, invoke builtin coro-resume
// on its address. This is to enable experimental support for coroutine-handle
// returning await_suspend that results in a guaranteed tail call to the target
// coroutine.
//
// \param RetType the return type of the await_suspend call \p E.
// \returns `__builtin_coro_resume(E.address())`, or nullptr when RetType is a
//          reference or not a class/struct type, or when `E.address()` cannot
//          be built. Note that a class return type without a usable address()
//          member makes buildMemberCall emit a diagnostic before nullptr is
//          returned.
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
                           SourceLocation Loc) {
  if (RetType->isReferenceType())
    return nullptr;
  Type const *T = RetType.getTypePtr();
  if (!T->isClassType() && !T->isStructureType())
    return nullptr;

  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
  // a private function in SemaExprCXX.cpp

  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
  if (AddressExpr.isInvalid())
    return nullptr;

  Expr *JustAddress = AddressExpr.get();
  // FIXME: Check that the type of AddressExpr is void*
  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
                          JustAddress);
}
397 
398 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
399 /// expression.
400 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
401                                                   SourceLocation Loc, Expr *E) {
402   OpaqueValueExpr *Operand = new (S.Context)
403       OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
404 
405   // Assume invalid until we see otherwise.
406   ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
407 
408   ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
409   if (CoroHandleRes.isInvalid())
410     return Calls;
411   Expr *CoroHandle = CoroHandleRes.get();
412 
413   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
414   MultiExprArg Args[] = {None, CoroHandle, None};
415   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
416     ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
417     if (Result.isInvalid())
418       return Calls;
419     Calls.Results[I] = Result.get();
420   }
421 
422   // Assume the calls are valid; all further checking should make them invalid.
423   Calls.IsInvalid = false;
424 
425   using ACT = ReadySuspendResumeResult::AwaitCallType;
426   CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
427   if (!AwaitReady->getType()->isDependentType()) {
428     // [expr.await]p3 [...]
429     // — await-ready is the expression e.await_ready(), contextually converted
430     // to bool.
431     ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
432     if (Conv.isInvalid()) {
433       S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
434              diag::note_await_ready_no_bool_conversion);
435       S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
436           << AwaitReady->getDirectCallee() << E->getSourceRange();
437       Calls.IsInvalid = true;
438     }
439     Calls.Results[ACT::ACT_Ready] = Conv.get();
440   }
441   CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
442   if (!AwaitSuspend->getType()->isDependentType()) {
443     // [expr.await]p3 [...]
444     //   - await-suspend is the expression e.await_suspend(h), which shall be
445     //     a prvalue of type void or bool.
446     QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
447 
448     // Experimental support for coroutine_handle returning await_suspend.
449     if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
450       Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
451     else {
452       // non-class prvalues always have cv-unqualified types
453       if (RetType->isReferenceType() ||
454           (!RetType->isBooleanType() && !RetType->isVoidType())) {
455         S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
456                diag::err_await_suspend_invalid_return_type)
457             << RetType;
458         S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
459             << AwaitSuspend->getDirectCallee();
460         Calls.IsInvalid = true;
461       }
462     }
463   }
464 
465   return Calls;
466 }
467 
468 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
469                                    SourceLocation Loc, StringRef Name,
470                                    MultiExprArg Args) {
471 
472   // Form a reference to the promise.
473   ExprResult PromiseRef = S.BuildDeclRefExpr(
474       Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
475   if (PromiseRef.isInvalid())
476     return ExprError();
477 
478   return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
479 }
480 
/// Create the implicit `__promise` variable for the coroutine in CurContext
/// and add it to the function.
///
/// The variable's type is Context.DependentTy when the function type (or, for
/// an instance method, the `this` type) is dependent, and otherwise the type
/// found by lookupPromiseType. When there is at least one argument available
/// (`*this` for a non-lambda instance method, plus references to the moved
/// parameter copies recorded in CoroutineParameterMoves), direct
/// initialization with those arguments is attempted first. If that
/// initialization sequence is not viable, or there are no arguments, the
/// variable goes through ActOnUninitializedDecl instead.
///
/// \returns the promise variable, or nullptr if the type lookup, the
///          variable's type check, or forming `*this` / a parameter reference
///          fails. If the viable initialization sequence fails while being
///          performed, the variable is marked invalid but still returned.
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
  auto *FD = cast<FunctionDecl>(CurContext);
  bool IsThisDependentType = [&] {
    if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD))
      return MD->isInstance() && MD->getThisType()->isDependentType();
    else
      return false;
  }();

  QualType T = FD->getType()->isDependentType() || IsThisDependentType
                   ? Context.DependentTy
                   : lookupPromiseType(*this, FD, Loc);
  if (T.isNull())
    return nullptr;

  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
                             &PP.getIdentifierTable().get("__promise"), T,
                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
  CheckVariableDeclarationType(VD);
  if (VD->isInvalidDecl())
    return nullptr;

  auto *ScopeInfo = getCurFunction();

  // Build a list of arguments, based on the coroutine function's arguments,
  // that if present will be passed to the promise type's constructor.
  llvm::SmallVector<Expr *, 4> CtorArgExprs;

  // Add implicit object parameter.
  // A lambda's call operator does not pass the closure object.
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
      ExprResult ThisExpr = ActOnCXXThis(Loc);
      if (ThisExpr.isInvalid())
        return nullptr;
      ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
      if (ThisExpr.isInvalid())
        return nullptr;
      CtorArgExprs.push_back(ThisExpr.get());
    }
  }

  // Add the coroutine function's parameters.
  // Parameters of dependent type are skipped entirely.
  auto &Moves = ScopeInfo->CoroutineParameterMoves;
  for (auto *PD : FD->parameters()) {
    if (PD->getType()->isDependentType())
      continue;

    auto RefExpr = ExprEmpty();
    auto Move = Moves.find(PD);
    assert(Move != Moves.end() &&
           "Coroutine function parameter not inserted into move map");
    // If a reference to the function parameter exists in the coroutine
    // frame, use that reference.
    auto *MoveDecl =
        cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
    RefExpr =
        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
                         ExprValueKind::VK_LValue, FD->getLocation());
    if (RefExpr.isInvalid())
      return nullptr;
    CtorArgExprs.push_back(RefExpr.get());
  }

  // If we have a non-zero number of constructor arguments, try to use them.
  // Otherwise, fall back to the promise type's default constructor.
  if (!CtorArgExprs.empty()) {
    // Create an initialization sequence for the promise type using the
    // constructor arguments, wrapped in a parenthesized list expression.
    Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(),
                                      CtorArgExprs, FD->getLocation());
    InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
    InitializationKind Kind = InitializationKind::CreateForInit(
        VD->getLocation(), /*DirectInit=*/true, PLE);
    InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
                                   /*TopLevelOfInitList=*/false,
                                   /*TreatUnavailableAsInvalid=*/false);

    // Attempt to initialize the promise type with the arguments.
    // If that fails, fall back to the promise type's default constructor.
    if (InitSeq) {
      ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
      if (Result.isInvalid()) {
        VD->setInvalidDecl();
      } else if (Result.get()) {
        VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
        VD->setInitStyle(VarDecl::CallInit);
        CheckCompleteVariableDeclaration(VD);
      }
    } else
      ActOnUninitializedDecl(VD);
  } else
    ActOnUninitializedDecl(VD);

  FD->addDecl(VD);
  return VD;
}
578 
579 /// Check that this is a context in which a coroutine suspension can appear.
580 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
581                                                 StringRef Keyword,
582                                                 bool IsImplicit = false) {
583   if (!isValidCoroutineContext(S, Loc, Keyword))
584     return nullptr;
585 
586   assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
587 
588   auto *ScopeInfo = S.getCurFunction();
589   assert(ScopeInfo && "missing function scope for function");
590 
591   if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit)
592     ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
593 
594   if (ScopeInfo->CoroutinePromise)
595     return ScopeInfo;
596 
597   if (!S.buildCoroutineParameterMoves(Loc))
598     return nullptr;
599 
600   ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
601   if (!ScopeInfo->CoroutinePromise)
602     return nullptr;
603 
604   return ScopeInfo;
605 }
606 
/// Called for each coroutine keyword: make sure the current function is set up
/// as a coroutine and, the first time, build its implicit
/// `co_await promise.initial_suspend()` and `co_await promise.final_suspend()`
/// statements.
///
/// \returns false only if checkCoroutineContext rejects the context. Failure
///          to build either implicit suspend point is diagnosed (with notes
///          pointing at \p KWLoc) but still returns true, leaving the suspends
///          unset on the function scope.
bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
                                   StringRef Keyword) {
  if (!checkCoroutineContext(*this, KWLoc, Keyword))
    return false;
  auto *ScopeInfo = getCurFunction();
  assert(ScopeInfo->CoroutinePromise);

  // If we have existing coroutine statements then we have already built
  // the initial and final suspend points.
  if (!ScopeInfo->NeedsCoroutineSuspends)
    return true;

  ScopeInfo->setNeedsCoroutineSuspends(false);

  auto *Fn = cast<FunctionDecl>(CurContext);
  SourceLocation Loc = Fn->getLocation();
  // Build the initial suspend point
  // (this lambda builds either suspend point: promise.Name(), then operator
  // co_await, then the implicit co_await expression, as a full-expression).
  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
    ExprResult Suspend =
        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
    if (Suspend.isInvalid())
      return StmtError();
    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
    if (Suspend.isInvalid())
      return StmtError();
    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
                                       /*IsImplicit*/ true);
    // An invalid result from BuildResolvedCoawaitExpr flows into
    // ActOnFinishFullExpr as a null expression and is reported below.
    Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
    if (Suspend.isInvalid()) {
      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
          << ((Name == "initial_suspend") ? 0 : 1);
      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
      return StmtError();
    }
    return cast<Stmt>(Suspend.get());
  };

  StmtResult InitSuspend = buildSuspends("initial_suspend");
  if (InitSuspend.isInvalid())
    return true;

  StmtResult FinalSuspend = buildSuspends("final_suspend");
  if (FinalSuspend.isInvalid())
    return true;

  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());

  return true;
}
656 
657 // Recursively walks up the scope hierarchy until either a 'catch' or a function
658 // scope is found, whichever comes first.
659 static bool isWithinCatchScope(Scope *S) {
660   // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
661   // lambdas that use 'co_await' are allowed. The loop below ends when a
662   // function scope is found in order to ensure the following behavior:
663   //
664   // void foo() {      // <- function scope
665   //   try {           //
666   //     co_await x;   // <- 'co_await' is OK within a function scope
667   //   } catch {       // <- catch scope
668   //     co_await x;   // <- 'co_await' is not OK within a catch scope
669   //     []() {        // <- function scope
670   //       co_await x; // <- 'co_await' is OK within a function scope
671   //     }();
672   //   }
673   // }
674   while (S && !(S->getFlags() & Scope::FnScope)) {
675     if (S->getFlags() & Scope::CatchScope)
676       return true;
677     S = S->getParent();
678   }
679   return false;
680 }
681 
682 // [expr.await]p2, emphasis added: "An await-expression shall appear only in
683 // a *potentially evaluated* expression within the compound-statement of a
684 // function-body *outside of a handler* [...] A context within a function
685 // where an await-expression can appear is called a suspension context of the
686 // function."
687 static void checkSuspensionContext(Sema &S, SourceLocation Loc,
688                                    StringRef Keyword) {
689   // First emphasis of [expr.await]p2: must be a potentially evaluated context.
690   // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
691   // \c sizeof.
692   if (S.isUnevaluatedContext())
693     S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
694 
695   // Second emphasis of [expr.await]p2: must be outside of an exception handler.
696   if (isWithinCatchScope(S.getCurScope()))
697     S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
698 }
699 
700 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
701   if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
702     CorrectDelayedTyposInExpr(E);
703     return ExprError();
704   }
705 
706   checkSuspensionContext(*this, Loc, "co_await");
707 
708   if (E->getType()->isPlaceholderType()) {
709     ExprResult R = CheckPlaceholderExpr(E);
710     if (R.isInvalid()) return ExprError();
711     E = R.get();
712   }
713   ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
714   if (Lookup.isInvalid())
715     return ExprError();
716   return BuildUnresolvedCoawaitExpr(Loc, E,
717                                    cast<UnresolvedLookupExpr>(Lookup.get()));
718 }
719 
/// Build a co_await on the operand \p E before 'operator co_await' has been
/// applied.
///
/// With a dependent promise type, a DependentCoawaitExpr carrying \p Lookup
/// is returned. Otherwise \p E is first passed through
/// `promise.await_transform(E)` when the promise type has a member of that
/// name, then 'operator co_await' is resolved using the candidates in
/// \p Lookup, and the result is handed to BuildResolvedCoawaitExpr.
ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                            UnresolvedLookupExpr *Lookup) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
  if (!FSI)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult R = CheckPlaceholderExpr(E);
    if (R.isInvalid())
      return ExprError();
    E = R.get();
  }

  auto *Promise = FSI->CoroutinePromise;
  if (Promise->getType()->isDependentType()) {
    Expr *Res =
        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
    return Res;
  }

  auto *RD = Promise->getType()->getAsCXXRecordDecl();
  // [expr.await]p3: the operand is transformed by the promise when it
  // declares an await_transform member.
  if (lookupMember(*this, "await_transform", RD, Loc)) {
    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
    if (R.isInvalid()) {
      Diag(Loc,
           diag::note_coroutine_promise_implicit_await_transform_required_here)
          << E->getSourceRange();
      return ExprError();
    }
    E = R.get();
  }
  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
}
757 
/// Build a CoawaitExpr whose operand \p E has already had 'operator co_await'
/// applied.
///
/// A dependent operand yields a CoawaitExpr of dependent type. Otherwise a
/// prvalue operand is materialized as an lvalue (it is referenced by three
/// calls), and the await_ready / await_suspend / await_resume calls are built
/// and checked by buildCoawaitCalls.
///
/// \param IsImplicit true for the compiler-generated initial/final suspend
///        points; passed to checkCoroutineContext and stored on the result.
ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                  bool IsImplicit) {
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
  if (!Coroutine)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult R = CheckPlaceholderExpr(E);
    if (R.isInvalid()) return ExprError();
    E = R.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *Res = new (Context)
        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
    return Res;
  }

  // If the expression is a temporary, materialize it as an lvalue so that we
  // can use it multiple times.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);

  // The location of the `co_await` token cannot be used when constructing
  // the member call expressions since it's before the location of `Expr`, which
  // is used as the start of the member call expression.
  SourceLocation CallLoc = E->getExprLoc();

  // Build the await_ready, await_suspend, await_resume calls.
  ReadySuspendResumeResult RSS =
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E);
  if (RSS.IsInvalid)
    return ExprError();

  Expr *Res =
      new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
                                RSS.Results[2], RSS.OpaqueValue, IsImplicit);

  return Res;
}
798 
799 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
800   if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
801     CorrectDelayedTyposInExpr(E);
802     return ExprError();
803   }
804 
805   checkSuspensionContext(*this, Loc, "co_yield");
806 
807   // Build yield_value call.
808   ExprResult Awaitable = buildPromiseCall(
809       *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
810   if (Awaitable.isInvalid())
811     return ExprError();
812 
813   // Build 'operator co_await' call.
814   Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
815   if (Awaitable.isInvalid())
816     return ExprError();
817 
818   return BuildCoyieldExpr(Loc, Awaitable.get());
819 }
820 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
821   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
822   if (!Coroutine)
823     return ExprError();
824 
825   if (E->getType()->isPlaceholderType()) {
826     ExprResult R = CheckPlaceholderExpr(E);
827     if (R.isInvalid()) return ExprError();
828     E = R.get();
829   }
830 
831   if (E->getType()->isDependentType()) {
832     Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
833     return Res;
834   }
835 
836   // If the expression is a temporary, materialize it as an lvalue so that we
837   // can use it multiple times.
838   if (E->getValueKind() == VK_RValue)
839     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
840 
841   // Build the await_ready, await_suspend, await_resume calls.
842   ReadySuspendResumeResult RSS =
843       buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
844   if (RSS.IsInvalid)
845     return ExprError();
846 
847   Expr *Res =
848       new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
849                                 RSS.Results[2], RSS.OpaqueValue);
850 
851   return Res;
852 }
853 
854 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
855   if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) {
856     CorrectDelayedTyposInExpr(E);
857     return StmtError();
858   }
859   return BuildCoreturnStmt(Loc, E);
860 }
861 
/// Build a 'co_return' statement (explicit, or implicit when \p IsImplicit).
///
/// \p E is the optional operand. When it is present and is either an
/// InitListExpr or of non-void type, the promise call is 'p.return_value(E)'.
/// Otherwise the operand (if any) is turned into a discarded-value full
/// expression and the promise call is 'p.return_void()'.
///
/// Returns StmtError if the coroutine context check, placeholder resolution,
/// or forming the promise call fails.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!FSI)
    return StmtError();

  // Resolve placeholder-typed operands, except overload sets. Those are left
  // untouched, presumably so they can be resolved against return_value's
  // parameter type — confirm.
  if (E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    ExprResult R = CheckPlaceholderExpr(E);
    if (R.isInvalid()) return StmtError();
    E = R.get();
  }

  // Move the return value if we can. When the operand is a copy-elision
  // candidate (checked as if by std::move), replace it with the result of a
  // move-or-copy initialization. If that produces no expression, the
  // original operand is kept.
  if (E) {
    auto NRVOCandidate = this->getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove);
    if (NRVOCandidate) {
      InitializedEntity Entity =
          InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate);
      ExprResult MoveResult = this->PerformMoveOrCopyInitialization(
          Entity, NRVOCandidate, E->getType(), E);
      if (MoveResult.get())
        E = MoveResult.get();
    }
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  // NOTE(review): the move step above appears to address part of this FIXME;
  // confirm whether the FIXME is still accurate.
  VarDecl *Promise = FSI->CoroutinePromise;
  ExprResult PC;
  if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
  } else {
    // No operand, or a void operand: evaluate it for side effects only.
    // NOTE(review): if MakeFullDiscardedValueExpr fails, E becomes null and
    // the statement is built without an operand — confirm this is intended.
    E = MakeFullDiscardedValueExpr(E).get();
    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
  }
  if (PC.isInvalid())
    return StmtError();

  // Finish the promise call as a full-expression. NOTE(review): the result's
  // validity is not checked, so PCE may be null on failure — confirm.
  Expr *PCE = ActOnFinishFullExpr(PC.get(), /*DiscardedValue*/ false).get();

  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
  return Res;
}
907 
908 /// Look up the std::nothrow object.
909 static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
910   NamespaceDecl *Std = S.getStdNamespace();
911   assert(Std && "Should already be diagnosed");
912 
913   LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
914                       Sema::LookupOrdinaryName);
915   if (!S.LookupQualifiedName(Result, Std)) {
916     // FIXME: <experimental/coroutine> should have been included already.
917     // If we require it to include <new> then this diagnostic is no longer
918     // needed.
919     S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
920     return nullptr;
921   }
922 
923   auto *VD = Result.getAsSingle<VarDecl>();
924   if (!VD) {
925     Result.suppressDiagnostics();
926     // We found something weird. Complain about the first thing we found.
927     NamedDecl *Found = *Result.begin();
928     S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
929     return nullptr;
930   }
931 
932   ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
933   if (DR.isInvalid())
934     return nullptr;
935 
936   return DR.get();
937 }
938 
939 // Find an appropriate delete for the promise.
940 static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
941                                           QualType PromiseType) {
942   FunctionDecl *OperatorDelete = nullptr;
943 
944   DeclarationName DeleteName =
945       S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
946 
947   auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
948   assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
949 
950   if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
951     return nullptr;
952 
953   if (!OperatorDelete) {
954     // Look for a global declaration.
955     const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
956     const bool Overaligned = false;
957     OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
958                                                      Overaligned, DeleteName);
959   }
960   S.MarkFunctionReferenced(Loc, OperatorDelete);
961   return OperatorDelete;
962 }
963 
964 
965 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
966   FunctionScopeInfo *Fn = getCurFunction();
967   assert(Fn && Fn->isCoroutine() && "not a coroutine");
968   if (!Body) {
969     assert(FD->isInvalidDecl() &&
970            "a null body is only allowed for invalid declarations");
971     return;
972   }
973   // We have a function that uses coroutine keywords, but we failed to build
974   // the promise type.
975   if (!Fn->CoroutinePromise)
976     return FD->setInvalidDecl();
977 
978   if (isa<CoroutineBodyStmt>(Body)) {
979     // Nothing todo. the body is already a transformed coroutine body statement.
980     return;
981   }
982 
983   // Coroutines [stmt.return]p1:
984   //   A return statement shall not appear in a coroutine.
985   if (Fn->FirstReturnLoc.isValid()) {
986     assert(Fn->FirstCoroutineStmtLoc.isValid() &&
987                    "first coroutine location not set");
988     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
989     Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
990             << Fn->getFirstCoroutineStmtKeyword();
991   }
992   CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
993   if (Builder.isInvalid() || !Builder.buildStatements())
994     return FD->setInvalidDecl();
995 
996   // Build body for the coroutine wrapper statement.
997   Body = CoroutineBodyStmt::Create(Context, Builder);
998 }
999 
1000 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
1001                                            sema::FunctionScopeInfo &Fn,
1002                                            Stmt *Body)
1003     : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
1004       IsPromiseDependentType(
1005           !Fn.CoroutinePromise ||
1006           Fn.CoroutinePromise->getType()->isDependentType()) {
1007   this->Body = Body;
1008 
1009   for (auto KV : Fn.CoroutineParameterMoves)
1010     this->ParamMovesVector.push_back(KV.second);
1011   this->ParamMoves = this->ParamMovesVector;
1012 
1013   if (!IsPromiseDependentType) {
1014     PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
1015     assert(PromiseRecordDecl && "Type should have already been checked");
1016   }
1017   this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
1018 }
1019 
1020 bool CoroutineStmtBuilder::buildStatements() {
1021   assert(this->IsValid && "coroutine already invalid");
1022   this->IsValid = makeReturnObject();
1023   if (this->IsValid && !IsPromiseDependentType)
1024     buildDependentStatements();
1025   return this->IsValid;
1026 }
1027 
1028 bool CoroutineStmtBuilder::buildDependentStatements() {
1029   assert(this->IsValid && "coroutine already invalid");
1030   assert(!this->IsPromiseDependentType &&
1031          "coroutine cannot have a dependent promise type");
1032   this->IsValid = makeOnException() && makeOnFallthrough() &&
1033                   makeGroDeclAndReturnStmt() && makeReturnOnAllocFailure() &&
1034                   makeNewAndDeleteExpr();
1035   return this->IsValid;
1036 }
1037 
1038 bool CoroutineStmtBuilder::makePromiseStmt() {
1039   // Form a declaration statement for the promise declaration, so that AST
1040   // visitors can more easily find it.
1041   StmtResult PromiseStmt =
1042       S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1043   if (PromiseStmt.isInvalid())
1044     return false;
1045 
1046   this->Promise = PromiseStmt.get();
1047   return true;
1048 }
1049 
1050 bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1051   if (Fn.hasInvalidCoroutineSuspends())
1052     return false;
1053   this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1054   this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1055   return true;
1056 }
1057 
1058 static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1059                                      CXXRecordDecl *PromiseRecordDecl,
1060                                      FunctionScopeInfo &Fn) {
1061   auto Loc = E->getExprLoc();
1062   if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1063     auto *Decl = DeclRef->getDecl();
1064     if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1065       if (Method->isStatic())
1066         return true;
1067       else
1068         Loc = Decl->getLocation();
1069     }
1070   }
1071 
1072   S.Diag(
1073       Loc,
1074       diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1075       << PromiseRecordDecl;
1076   S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1077       << Fn.getFirstCoroutineStmtKeyword();
1078   return false;
1079 }
1080 
1081 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
1082   assert(!IsPromiseDependentType &&
1083          "cannot make statement while the promise type is dependent");
1084 
1085   // [dcl.fct.def.coroutine]/8
1086   // The unqualified-id get_return_object_on_allocation_failure is looked up in
1087   // the scope of class P by class member access lookup (3.4.5). ...
1088   // If an allocation function returns nullptr, ... the coroutine return value
1089   // is obtained by a call to ... get_return_object_on_allocation_failure().
1090 
1091   DeclarationName DN =
1092       S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
1093   LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
1094   if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
1095     return true;
1096 
1097   CXXScopeSpec SS;
1098   ExprResult DeclNameExpr =
1099       S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
1100   if (DeclNameExpr.isInvalid())
1101     return false;
1102 
1103   if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
1104     return false;
1105 
1106   ExprResult ReturnObjectOnAllocationFailure =
1107       S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
1108   if (ReturnObjectOnAllocationFailure.isInvalid())
1109     return false;
1110 
1111   StmtResult ReturnStmt =
1112       S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
1113   if (ReturnStmt.isInvalid()) {
1114     S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
1115         << DN;
1116     S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1117         << Fn.getFirstCoroutineStmtKeyword();
1118     return false;
1119   }
1120 
1121   this->ReturnStmtOnAllocFailure = ReturnStmt.get();
1122   return true;
1123 }
1124 
/// Build the coroutine frame allocation and deallocation calls.
///
/// On success, stores:
/// - Allocate: 'operator new(__builtin_coro_size(), <placement args>)',
///   finished as a full-expression;
/// - Deallocate: 'operator delete(__builtin_coro_free(__builtin_coro_frame())
///   [, __builtin_coro_size()])', finished as a full-expression. The size is
///   passed only when the chosen operator delete has more than one parameter.
///
/// When makeReturnOnAllocFailure produced a statement (ReturnStmtOnAllocFailure
/// is non-null), the chosen operator new must be non-throwing. In that case a
/// global or missing operator new is replaced by a lookup taking std::nothrow.
///
/// Returns false if the promise type is incomplete, no suitable operator new
/// or operator delete is found, a required nothrow guarantee is missing, or
/// any of the call expressions fails to build.
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
  // Form and check allocation and deallocation calls.
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");
  QualType PromiseType = Fn.CoroutinePromise->getType();

  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
    return false;

  // Set only when makeReturnOnAllocFailure has already built its statement.
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;

  // [dcl.fct.def.coroutine]/7
  // Lookup allocation functions using a parameter list composed of the
  // requested size of the coroutine state being allocated, followed by
  // the coroutine function's arguments. If a matching allocation function
  // exists, use it. Otherwise, use an allocation function that just takes
  // the requested size.

  FunctionDecl *OperatorNew = nullptr;
  FunctionDecl *OperatorDelete = nullptr;
  FunctionDecl *UnusedResult = nullptr;
  // NOTE(review): FindAllocationFunctions takes PassAlignment by reference,
  // but it is never consulted afterwards and no alignment argument is passed
  // to the new call below. Confirm the behavior for over-aligned promises.
  bool PassAlignment = false;
  SmallVector<Expr *, 1> PlacementArgs;

  // [dcl.fct.def.coroutine]/7
  // "The allocation function’s name is looked up in the scope of P.
  // [...] If the lookup finds an allocation function in the scope of P,
  // overload resolution is performed on a function call created by assembling
  // an argument list. The first argument is the amount of space requested,
  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
  // arguments."
  //
  // ...where "p1 ... pn" are defined earlier as:
  //
  // [dcl.fct.def.coroutine]/3
  // "For a coroutine f that is a non-static member function, let P1 denote the
  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
  // of the function parameters; otherwise let P1 ... Pn be the types of the
  // function parameters. Let p1 ... pn be lvalues denoting those objects."
  //
  // For instance member functions other than lambda call operators, '*this'
  // is therefore the first placement argument.
  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
      if (ThisExpr.isInvalid())
        return false;
      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
      if (ThisExpr.isInvalid())
        return false;
      PlacementArgs.push_back(ThisExpr.get());
    }
  }
  // Parameters with dependent types are skipped.
  for (auto *PD : FD.parameters()) {
    if (PD->getType()->isDependentType())
      continue;

    // Build a reference to the parameter.
    auto PDLoc = PD->getLocation();
    ExprResult PDRefExpr =
        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
                           ExprValueKind::VK_LValue, PDLoc);
    if (PDRefExpr.isInvalid())
      return false;

    PlacementArgs.push_back(PDRefExpr.get());
  }
  // First attempt: class-scope operator new with the placement arguments,
  // without diagnostics.
  S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                            /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                            /*isArray*/ false, PassAlignment, PlacementArgs,
                            OperatorNew, UnusedResult, /*Diagnose*/ false);

  // [dcl.fct.def.coroutine]/7
  // "If no matching function is found, overload resolution is performed again
  // on a function call created by passing just the amount of space required as
  // an argument of type std::size_t."
  if (!OperatorNew && !PlacementArgs.empty()) {
    PlacementArgs.clear();
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult, /*Diagnose*/ false);
  }

  // [dcl.fct.def.coroutine]/7
  // "The allocation function’s name is looked up in the scope of P. If this
  // lookup fails, the allocation function’s name is looked up in the global
  // scope."
  // This lookup uses FindAllocationFunctions' default Diagnose setting.
  if (!OperatorNew) {
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  bool IsGlobalOverload =
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
  // If we didn't find a class-local new declaration and a non-throwing new
  // is required, then we need to look up the non-throwing operator new
  // (taking std::nothrow) instead.
  if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) {
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
    if (!StdNoThrow)
      return false;
    PlacementArgs = {StdNoThrow};
    OperatorNew = nullptr;
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  if (!OperatorNew)
    return false;

  // A nothrow-requiring coroutine rejects an operator new that may throw.
  if (RequiresNoThrowAlloc) {
    const auto *FT = OperatorNew->getType()->castAs<FunctionProtoType>();
    if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
      S.Diag(OperatorNew->getLocation(),
             diag::err_coroutine_promise_new_requires_nothrow)
          << OperatorNew;
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
          << OperatorNew;
      return false;
    }
  }

  if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
    return false;

  Expr *FramePtr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  // The same size expression is used for both the new and the delete call.
  Expr *FrameSize =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});

  // Make new call.

  ExprResult NewRef =
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
  if (NewRef.isInvalid())
    return false;

  // Arguments: the frame size followed by the chosen placement arguments.
  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
  for (auto Arg : PlacementArgs)
    NewArgs.push_back(Arg);

  ExprResult NewExpr =
      S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
  if (NewExpr.isInvalid())
    return false;

  // Make delete call.

  QualType OpDeleteQualType = OperatorDelete->getType();

  ExprResult DeleteRef =
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
  if (DeleteRef.isInvalid())
    return false;

  Expr *CoroFree =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});

  SmallVector<Expr *, 2> DeleteArgs{CoroFree};

  // Check if we need to pass the size.
  const auto *OpDeleteType =
      OpDeleteQualType.getTypePtr()->castAs<FunctionProtoType>();
  if (OpDeleteType->getNumParams() > 1)
    DeleteArgs.push_back(FrameSize);

  ExprResult DeleteExpr =
      S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
  DeleteExpr =
      S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
  if (DeleteExpr.isInvalid())
    return false;

  this->Allocate = NewExpr.get();
  this->Deallocate = DeleteExpr.get();

  return true;
}
1307 
1308 bool CoroutineStmtBuilder::makeOnFallthrough() {
1309   assert(!IsPromiseDependentType &&
1310          "cannot make statement while the promise type is dependent");
1311 
1312   // [dcl.fct.def.coroutine]/4
1313   // The unqualified-ids 'return_void' and 'return_value' are looked up in
1314   // the scope of class P. If both are found, the program is ill-formed.
1315   bool HasRVoid, HasRValue;
1316   LookupResult LRVoid =
1317       lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1318   LookupResult LRValue =
1319       lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1320 
1321   StmtResult Fallthrough;
1322   if (HasRVoid && HasRValue) {
1323     // FIXME Improve this diagnostic
1324     S.Diag(FD.getLocation(),
1325            diag::err_coroutine_promise_incompatible_return_functions)
1326         << PromiseRecordDecl;
1327     S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1328            diag::note_member_first_declared_here)
1329         << LRVoid.getLookupName();
1330     S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1331            diag::note_member_first_declared_here)
1332         << LRValue.getLookupName();
1333     return false;
1334   } else if (!HasRVoid && !HasRValue) {
1335     // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1336     // However we still diagnose this as an error since until the PDTS is fixed.
1337     S.Diag(FD.getLocation(),
1338            diag::err_coroutine_promise_requires_return_function)
1339         << PromiseRecordDecl;
1340     S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1341         << PromiseRecordDecl;
1342     return false;
1343   } else if (HasRVoid) {
1344     // If the unqualified-id return_void is found, flowing off the end of a
1345     // coroutine is equivalent to a co_return with no operand. Otherwise,
1346     // flowing off the end of a coroutine results in undefined behavior.
1347     Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1348                                       /*IsImplicit*/false);
1349     Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1350     if (Fallthrough.isInvalid())
1351       return false;
1352   }
1353 
1354   this->OnFallthrough = Fallthrough.get();
1355   return true;
1356 }
1357 
1358 bool CoroutineStmtBuilder::makeOnException() {
1359   // Try to form 'p.unhandled_exception();'
1360   assert(!IsPromiseDependentType &&
1361          "cannot make statement while the promise type is dependent");
1362 
1363   const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1364 
1365   if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1366     auto DiagID =
1367         RequireUnhandledException
1368             ? diag::err_coroutine_promise_unhandled_exception_required
1369             : diag::
1370                   warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1371     S.Diag(Loc, DiagID) << PromiseRecordDecl;
1372     S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1373         << PromiseRecordDecl;
1374     return !RequireUnhandledException;
1375   }
1376 
1377   // If exceptions are disabled, don't try to build OnException.
1378   if (!S.getLangOpts().CXXExceptions)
1379     return true;
1380 
1381   ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1382                                                    "unhandled_exception", None);
1383   UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1384                                              /*DiscardedValue*/ false);
1385   if (UnhandledException.isInvalid())
1386     return false;
1387 
1388   // Since the body of the coroutine will be wrapped in try-catch, it will
1389   // be incompatible with SEH __try if present in a function.
1390   if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1391     S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1392     S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1393         << Fn.getFirstCoroutineStmtKeyword();
1394     return false;
1395   }
1396 
1397   this->OnException = UnhandledException.get();
1398   return true;
1399 }
1400 
1401 bool CoroutineStmtBuilder::makeReturnObject() {
1402   // Build implicit 'p.get_return_object()' expression and form initialization
1403   // of return type from it.
1404   ExprResult ReturnObject =
1405       buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1406   if (ReturnObject.isInvalid())
1407     return false;
1408 
1409   this->ReturnValue = ReturnObject.get();
1410   return true;
1411 }
1412 
1413 static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
1414   if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) {
1415     auto *MethodDecl = MbrRef->getMethodDecl();
1416     S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here)
1417         << MethodDecl;
1418   }
1419   S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1420       << Fn.getFirstCoroutineStmtKeyword();
1421 }
1422 
1423 bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1424   assert(!IsPromiseDependentType &&
1425          "cannot make statement while the promise type is dependent");
1426   assert(this->ReturnValue && "ReturnValue must be already formed");
1427 
1428   QualType const GroType = this->ReturnValue->getType();
1429   assert(!GroType->isDependentType() &&
1430          "get_return_object type must no longer be dependent");
1431 
1432   QualType const FnRetType = FD.getReturnType();
1433   assert(!FnRetType->isDependentType() &&
1434          "get_return_object type must no longer be dependent");
1435 
1436   if (FnRetType->isVoidType()) {
1437     ExprResult Res =
1438         S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1439     if (Res.isInvalid())
1440       return false;
1441 
1442     this->ResultDecl = Res.get();
1443     return true;
1444   }
1445 
1446   if (GroType->isVoidType()) {
1447     // Trigger a nice error message.
1448     InitializedEntity Entity =
1449         InitializedEntity::InitializeResult(Loc, FnRetType, false);
1450     S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1451     noteMemberDeclaredHere(S, ReturnValue, Fn);
1452     return false;
1453   }
1454 
1455   auto *GroDecl = VarDecl::Create(
1456       S.Context, &FD, FD.getLocation(), FD.getLocation(),
1457       &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1458       S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1459 
1460   S.CheckVariableDeclarationType(GroDecl);
1461   if (GroDecl->isInvalidDecl())
1462     return false;
1463 
1464   InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1465   ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1466                                                      this->ReturnValue);
1467   if (Res.isInvalid())
1468     return false;
1469 
1470   Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1471   if (Res.isInvalid())
1472     return false;
1473 
1474   S.AddInitializerToDecl(GroDecl, Res.get(),
1475                          /*DirectInit=*/false);
1476 
1477   S.FinalizeDeclaration(GroDecl);
1478 
1479   // Form a declaration statement for the return declaration, so that AST
1480   // visitors can more easily find it.
1481   StmtResult GroDeclStmt =
1482       S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1483   if (GroDeclStmt.isInvalid())
1484     return false;
1485 
1486   this->ResultDecl = GroDeclStmt.get();
1487 
1488   ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1489   if (declRef.isInvalid())
1490     return false;
1491 
1492   StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1493   if (ReturnStmt.isInvalid()) {
1494     noteMemberDeclaredHere(S, ReturnValue, Fn);
1495     return false;
1496   }
1497   if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1498     GroDecl->setNRVOVariable(true);
1499 
1500   this->ReturnStmt = ReturnStmt.get();
1501   return true;
1502 }
1503 
1504 // Create a static_cast\<T&&>(expr).
1505 static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1506   if (T.isNull())
1507     T = E->getType();
1508   QualType TargetType = S.BuildReferenceType(
1509       T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1510   SourceLocation ExprLoc = E->getBeginLoc();
1511   TypeSourceInfo *TargetLoc =
1512       S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1513 
1514   return S
1515       .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1516                          SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1517       .get();
1518 }
1519 
1520 /// Build a variable declaration for move parameter.
1521 static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1522                              IdentifierInfo *II) {
1523   TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1524   VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1525                                   TInfo, SC_None);
1526   Decl->setImplicit();
1527   return Decl;
1528 }
1529 
1530 // Build statements that move coroutine function parameters to the coroutine
1531 // frame, and store them on the function scope info.
1532 bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1533   assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1534   auto *FD = cast<FunctionDecl>(CurContext);
1535 
1536   auto *ScopeInfo = getCurFunction();
1537   if (!ScopeInfo->CoroutineParameterMoves.empty())
1538     return false;
1539 
1540   for (auto *PD : FD->parameters()) {
1541     if (PD->getType()->isDependentType())
1542       continue;
1543 
1544     ExprResult PDRefExpr =
1545         BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1546                          ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1547     if (PDRefExpr.isInvalid())
1548       return false;
1549 
1550     Expr *CExpr = nullptr;
1551     if (PD->getType()->getAsCXXRecordDecl() ||
1552         PD->getType()->isRValueReferenceType())
1553       CExpr = castForMoving(*this, PDRefExpr.get());
1554     else
1555       CExpr = PDRefExpr.get();
1556 
1557     auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1558     AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1559 
1560     // Convert decl to a statement.
1561     StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1562     if (Stmt.isInvalid())
1563       return false;
1564 
1565     ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1566   }
1567   return true;
1568 }
1569 
1570 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
1571   CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
1572   if (!Res)
1573     return StmtError();
1574   return Res;
1575 }
1576 
1577 ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
1578                                                SourceLocation FuncLoc) {
1579   if (!StdCoroutineTraitsCache) {
1580     if (auto StdExp = lookupStdExperimentalNamespace()) {
1581       LookupResult Result(*this,
1582                           &PP.getIdentifierTable().get("coroutine_traits"),
1583                           FuncLoc, LookupOrdinaryName);
1584       if (!LookupQualifiedName(Result, StdExp)) {
1585         Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
1586             << "std::experimental::coroutine_traits";
1587         return nullptr;
1588       }
1589       if (!(StdCoroutineTraitsCache =
1590                 Result.getAsSingle<ClassTemplateDecl>())) {
1591         Result.suppressDiagnostics();
1592         NamedDecl *Found = *Result.begin();
1593         Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
1594         return nullptr;
1595       }
1596     }
1597   }
1598   return StdCoroutineTraitsCache;
1599 }
1600