1 //===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements semantic analysis for C++ Coroutines. 10 // 11 // This file contains references to sections of the Coroutines TS, which 12 // can be found at http://wg21.link/coroutines. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "CoroutineStmtBuilder.h" 17 #include "clang/AST/ASTLambda.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/AST/ExprCXX.h" 20 #include "clang/AST/StmtCXX.h" 21 #include "clang/Lex/Preprocessor.h" 22 #include "clang/Sema/Initialization.h" 23 #include "clang/Sema/Overload.h" 24 #include "clang/Sema/ScopeInfo.h" 25 #include "clang/Sema/SemaInternal.h" 26 27 using namespace clang; 28 using namespace sema; 29 30 static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 31 SourceLocation Loc, bool &Res) { 32 DeclarationName DN = S.PP.getIdentifierInfo(Name); 33 LookupResult LR(S, DN, Loc, Sema::LookupMemberName); 34 // Suppress diagnostics when a private member is selected. The same warnings 35 // will be produced again when building the call. 36 LR.suppressDiagnostics(); 37 Res = S.LookupQualifiedName(LR, RD); 38 return LR; 39 } 40 41 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 42 SourceLocation Loc) { 43 bool Res; 44 lookupMember(S, Name, RD, Loc, Res); 45 return Res; 46 } 47 48 /// Look up the std::coroutine_traits<...>::promise_type for the given 49 /// function type. 50 static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD, 51 SourceLocation KwLoc) { 52 const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>(); 53 const SourceLocation FuncLoc = FD->getLocation(); 54 // FIXME: Cache std::coroutine_traits once we've found it. 55 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 56 if (!StdExp) { 57 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 58 << "std::experimental::coroutine_traits"; 59 return QualType(); 60 } 61 62 ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc); 63 if (!CoroTraits) { 64 return QualType(); 65 } 66 67 // Form template argument list for coroutine_traits<R, P1, P2, ...> according 68 // to [dcl.fct.def.coroutine]3 69 TemplateArgumentListInfo Args(KwLoc, KwLoc); 70 auto AddArg = [&](QualType T) { 71 Args.addArgument(TemplateArgumentLoc( 72 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); 73 }; 74 AddArg(FnType->getReturnType()); 75 // If the function is a non-static member function, add the type 76 // of the implicit object parameter before the formal parameters. 77 if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) { 78 if (MD->isInstance()) { 79 // [over.match.funcs]4 80 // For non-static member functions, the type of the implicit object 81 // parameter is 82 // -- "lvalue reference to cv X" for functions declared without a 83 // ref-qualifier or with the & ref-qualifier 84 // -- "rvalue reference to cv X" for functions declared with the && 85 // ref-qualifier 86 QualType T = MD->getThisType()->getAs<PointerType>()->getPointeeType(); 87 T = FnType->getRefQualifier() == RQ_RValue 88 ? S.Context.getRValueReferenceType(T) 89 : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true); 90 AddArg(T); 91 } 92 } 93 for (QualType T : FnType->getParamTypes()) 94 AddArg(T); 95 96 // Build the template-id. 97 QualType CoroTrait = 98 S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); 99 if (CoroTrait.isNull()) 100 return QualType(); 101 if (S.RequireCompleteType(KwLoc, CoroTrait, 102 diag::err_coroutine_type_missing_specialization)) 103 return QualType(); 104 105 auto *RD = CoroTrait->getAsCXXRecordDecl(); 106 assert(RD && "specialization of class template is not a class?"); 107 108 // Look up the ::promise_type member. 109 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, 110 Sema::LookupOrdinaryName); 111 S.LookupQualifiedName(R, RD); 112 auto *Promise = R.getAsSingle<TypeDecl>(); 113 if (!Promise) { 114 S.Diag(FuncLoc, 115 diag::err_implied_std_coroutine_traits_promise_type_not_found) 116 << RD; 117 return QualType(); 118 } 119 // The promise type is required to be a class type. 120 QualType PromiseType = S.Context.getTypeDeclType(Promise); 121 122 auto buildElaboratedType = [&]() { 123 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); 124 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 125 CoroTrait.getTypePtr()); 126 return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 127 }; 128 129 if (!PromiseType->getAsCXXRecordDecl()) { 130 S.Diag(FuncLoc, 131 diag::err_implied_std_coroutine_traits_promise_type_not_class) 132 << buildElaboratedType(); 133 return QualType(); 134 } 135 if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), 136 diag::err_coroutine_promise_type_incomplete)) 137 return QualType(); 138 139 return PromiseType; 140 } 141 142 /// Look up the std::experimental::coroutine_handle<PromiseType>. 143 static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, 144 SourceLocation Loc) { 145 if (PromiseType.isNull()) 146 return QualType(); 147 148 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 149 assert(StdExp && "Should already be diagnosed"); 150 151 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"), 152 Loc, Sema::LookupOrdinaryName); 153 if (!S.LookupQualifiedName(Result, StdExp)) { 154 S.Diag(Loc, diag::err_implied_coroutine_type_not_found) 155 << "std::experimental::coroutine_handle"; 156 return QualType(); 157 } 158 159 ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); 160 if (!CoroHandle) { 161 Result.suppressDiagnostics(); 162 // We found something weird. Complain about the first thing we found. 163 NamedDecl *Found = *Result.begin(); 164 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); 165 return QualType(); 166 } 167 168 // Form template argument list for coroutine_handle<Promise>. 169 TemplateArgumentListInfo Args(Loc, Loc); 170 Args.addArgument(TemplateArgumentLoc( 171 TemplateArgument(PromiseType), 172 S.Context.getTrivialTypeSourceInfo(PromiseType, Loc))); 173 174 // Build the template-id. 175 QualType CoroHandleType = 176 S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args); 177 if (CoroHandleType.isNull()) 178 return QualType(); 179 if (S.RequireCompleteType(Loc, CoroHandleType, 180 diag::err_coroutine_type_missing_specialization)) 181 return QualType(); 182 183 return CoroHandleType; 184 } 185 186 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, 187 StringRef Keyword) { 188 // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within 189 // a function body. 190 // FIXME: This also covers [expr.await]p2: "An await-expression shall not 191 // appear in a default argument." But the diagnostic QoI here could be 192 // improved to inform the user that default arguments specifically are not 193 // allowed. 194 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 195 if (!FD) { 196 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 197 ? diag::err_coroutine_objc_method 198 : diag::err_coroutine_outside_function) << Keyword; 199 return false; 200 } 201 202 // An enumeration for mapping the diagnostic type to the correct diagnostic 203 // selection index. 204 enum InvalidFuncDiag { 205 DiagCtor = 0, 206 DiagDtor, 207 DiagCopyAssign, 208 DiagMoveAssign, 209 DiagMain, 210 DiagConstexpr, 211 DiagAutoRet, 212 DiagVarargs, 213 }; 214 bool Diagnosed = false; 215 auto DiagInvalid = [&](InvalidFuncDiag ID) { 216 S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword; 217 Diagnosed = true; 218 return false; 219 }; 220 221 // Diagnose when a constructor, destructor, copy/move assignment operator, 222 // or the function 'main' are declared as a coroutine. 223 auto *MD = dyn_cast<CXXMethodDecl>(FD); 224 // [class.ctor]p6: "A constructor shall not be a coroutine." 225 if (MD && isa<CXXConstructorDecl>(MD)) 226 return DiagInvalid(DiagCtor); 227 // [class.dtor]p17: "A destructor shall not be a coroutine." 228 else if (MD && isa<CXXDestructorDecl>(MD)) 229 return DiagInvalid(DiagDtor); 230 // N4499 [special]p6: "A special member function shall not be a coroutine." 231 // Per C++ [special]p1, special member functions are the "default constructor, 232 // copy constructor and copy assignment operator, move constructor and move 233 // assignment operator, and destructor." 234 else if (MD && MD->isCopyAssignmentOperator()) 235 return DiagInvalid(DiagCopyAssign); 236 else if (MD && MD->isMoveAssignmentOperator()) 237 return DiagInvalid(DiagMoveAssign); 238 // [basic.start.main]p3: "The function main shall not be a coroutine." 239 else if (FD->isMain()) 240 return DiagInvalid(DiagMain); 241 242 // Emit a diagnostics for each of the following conditions which is not met. 243 // [expr.const]p2: "An expression e is a core constant expression unless the 244 // evaluation of e [...] would evaluate one of the following expressions: 245 // [...] an await-expression [...] a yield-expression." 246 if (FD->isConstexpr()) 247 DiagInvalid(DiagConstexpr); 248 // [dcl.spec.auto]p15: "A function declared with a return type that uses a 249 // placeholder type shall not be a coroutine." 250 if (FD->getReturnType()->isUndeducedType()) 251 DiagInvalid(DiagAutoRet); 252 // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the 253 // coroutine shall not terminate with an ellipsis that is not part of a 254 // parameter-declaration." 255 if (FD->isVariadic()) 256 DiagInvalid(DiagVarargs); 257 258 return !Diagnosed; 259 } 260 261 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, 262 SourceLocation Loc) { 263 DeclarationName OpName = 264 SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); 265 LookupResult Operators(SemaRef, OpName, SourceLocation(), 266 Sema::LookupOperatorName); 267 SemaRef.LookupName(Operators, S); 268 269 assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); 270 const auto &Functions = Operators.asUnresolvedSet(); 271 bool IsOverloaded = 272 Functions.size() > 1 || 273 (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); 274 Expr *CoawaitOp = UnresolvedLookupExpr::Create( 275 SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), 276 DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, 277 Functions.begin(), Functions.end()); 278 assert(CoawaitOp); 279 return CoawaitOp; 280 } 281 282 /// Build a call to 'operator co_await' if there is a suitable operator for 283 /// the given expression. 284 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, 285 Expr *E, 286 UnresolvedLookupExpr *Lookup) { 287 UnresolvedSet<16> Functions; 288 Functions.append(Lookup->decls_begin(), Lookup->decls_end()); 289 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 290 } 291 292 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 293 SourceLocation Loc, Expr *E) { 294 ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); 295 if (R.isInvalid()) 296 return ExprError(); 297 return buildOperatorCoawaitCall(SemaRef, Loc, E, 298 cast<UnresolvedLookupExpr>(R.get())); 299 } 300 301 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, 302 MultiExprArg CallArgs) { 303 StringRef Name = S.Context.BuiltinInfo.getName(Id); 304 LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); 305 S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); 306 307 auto *BuiltInDecl = R.getAsSingle<FunctionDecl>(); 308 assert(BuiltInDecl && "failed to find builtin declaration"); 309 310 ExprResult DeclRef = 311 S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc); 312 assert(DeclRef.isUsable() && "Builtin reference cannot fail"); 313 314 ExprResult Call = 315 S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc); 316 317 assert(!Call.isInvalid() && "Call to builtin cannot fail!"); 318 return Call.get(); 319 } 320 321 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType, 322 SourceLocation Loc) { 323 QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc); 324 if (CoroHandleType.isNull()) 325 return ExprError(); 326 327 DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType); 328 LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc, 329 Sema::LookupOrdinaryName); 330 if (!S.LookupQualifiedName(Found, LookupCtx)) { 331 S.Diag(Loc, diag::err_coroutine_handle_missing_member) 332 << "from_address"; 333 return ExprError(); 334 } 335 336 Expr *FramePtr = 337 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 338 339 CXXScopeSpec SS; 340 ExprResult FromAddr = 341 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 342 if (FromAddr.isInvalid()) 343 return ExprError(); 344 345 return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc); 346 } 347 348 struct ReadySuspendResumeResult { 349 enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume }; 350 Expr *Results[3]; 351 OpaqueValueExpr *OpaqueValue; 352 bool IsInvalid; 353 }; 354 355 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 356 StringRef Name, MultiExprArg Args) { 357 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 358 359 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 360 CXXScopeSpec SS; 361 ExprResult Result = S.BuildMemberReferenceExpr( 362 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 363 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 364 /*Scope=*/nullptr); 365 if (Result.isInvalid()) 366 return ExprError(); 367 368 // We meant exactly what we asked for. No need for typo correction. 369 if (auto *TE = dyn_cast<TypoExpr>(Result.get())) { 370 S.clearDelayedTypo(TE); 371 S.Diag(Loc, diag::err_no_member) 372 << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl() 373 << Base->getSourceRange(); 374 return ExprError(); 375 } 376 377 return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 378 } 379 380 // See if return type is coroutine-handle and if so, invoke builtin coro-resume 381 // on its address. This is to enable experimental support for coroutine-handle 382 // returning await_suspend that results in a guaranteed tail call to the target 383 // coroutine. 384 static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, 385 SourceLocation Loc) { 386 if (RetType->isReferenceType()) 387 return nullptr; 388 Type const *T = RetType.getTypePtr(); 389 if (!T->isClassType() && !T->isStructureType()) 390 return nullptr; 391 392 // FIXME: Add convertability check to coroutine_handle<>. Possibly via 393 // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment 394 // a private function in SemaExprCXX.cpp 395 396 ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None); 397 if (AddressExpr.isInvalid()) 398 return nullptr; 399 400 Expr *JustAddress = AddressExpr.get(); 401 // FIXME: Check that the type of AddressExpr is void* 402 return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume, 403 JustAddress); 404 } 405 406 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 407 /// expression. 408 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, 409 SourceLocation Loc, Expr *E) { 410 OpaqueValueExpr *Operand = new (S.Context) 411 OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 412 413 // Assume invalid until we see otherwise. 414 ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; 415 416 ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); 417 if (CoroHandleRes.isInvalid()) 418 return Calls; 419 Expr *CoroHandle = CoroHandleRes.get(); 420 421 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 422 MultiExprArg Args[] = {None, CoroHandle, None}; 423 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 424 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); 425 if (Result.isInvalid()) 426 return Calls; 427 Calls.Results[I] = Result.get(); 428 } 429 430 // Assume the calls are valid; all further checking should make them invalid. 431 Calls.IsInvalid = false; 432 433 using ACT = ReadySuspendResumeResult::AwaitCallType; 434 CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]); 435 if (!AwaitReady->getType()->isDependentType()) { 436 // [expr.await]p3 [...] 437 // — await-ready is the expression e.await_ready(), contextually converted 438 // to bool. 439 ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady); 440 if (Conv.isInvalid()) { 441 S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(), 442 diag::note_await_ready_no_bool_conversion); 443 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 444 << AwaitReady->getDirectCallee() << E->getSourceRange(); 445 Calls.IsInvalid = true; 446 } 447 Calls.Results[ACT::ACT_Ready] = Conv.get(); 448 } 449 CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]); 450 if (!AwaitSuspend->getType()->isDependentType()) { 451 // [expr.await]p3 [...] 452 // - await-suspend is the expression e.await_suspend(h), which shall be 453 // a prvalue of type void or bool. 454 QualType RetType = AwaitSuspend->getCallReturnType(S.Context); 455 456 // Experimental support for coroutine_handle returning await_suspend. 457 if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc)) 458 Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; 459 else { 460 // non-class prvalues always have cv-unqualified types 461 if (RetType->isReferenceType() || 462 (!RetType->isBooleanType() && !RetType->isVoidType())) { 463 S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), 464 diag::err_await_suspend_invalid_return_type) 465 << RetType; 466 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 467 << AwaitSuspend->getDirectCallee(); 468 Calls.IsInvalid = true; 469 } 470 } 471 } 472 473 return Calls; 474 } 475 476 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, 477 SourceLocation Loc, StringRef Name, 478 MultiExprArg Args) { 479 480 // Form a reference to the promise. 481 ExprResult PromiseRef = S.BuildDeclRefExpr( 482 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 483 if (PromiseRef.isInvalid()) 484 return ExprError(); 485 486 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 487 } 488 489 VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { 490 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 491 auto *FD = cast<FunctionDecl>(CurContext); 492 bool IsThisDependentType = [&] { 493 if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD)) 494 return MD->isInstance() && MD->getThisType()->isDependentType(); 495 else 496 return false; 497 }(); 498 499 QualType T = FD->getType()->isDependentType() || IsThisDependentType 500 ? Context.DependentTy 501 : lookupPromiseType(*this, FD, Loc); 502 if (T.isNull()) 503 return nullptr; 504 505 auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), 506 &PP.getIdentifierTable().get("__promise"), T, 507 Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 508 CheckVariableDeclarationType(VD); 509 if (VD->isInvalidDecl()) 510 return nullptr; 511 512 auto *ScopeInfo = getCurFunction(); 513 // Build a list of arguments, based on the coroutine functions arguments, 514 // that will be passed to the promise type's constructor. 515 llvm::SmallVector<Expr *, 4> CtorArgExprs; 516 517 // Add implicit object parameter. 518 if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) { 519 if (MD->isInstance() && !isLambdaCallOperator(MD)) { 520 ExprResult ThisExpr = ActOnCXXThis(Loc); 521 if (ThisExpr.isInvalid()) 522 return nullptr; 523 ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get()); 524 if (ThisExpr.isInvalid()) 525 return nullptr; 526 CtorArgExprs.push_back(ThisExpr.get()); 527 } 528 } 529 530 auto &Moves = ScopeInfo->CoroutineParameterMoves; 531 for (auto *PD : FD->parameters()) { 532 if (PD->getType()->isDependentType()) 533 continue; 534 535 auto RefExpr = ExprEmpty(); 536 auto Move = Moves.find(PD); 537 assert(Move != Moves.end() && 538 "Coroutine function parameter not inserted into move map"); 539 // If a reference to the function parameter exists in the coroutine 540 // frame, use that reference. 541 auto *MoveDecl = 542 cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl()); 543 RefExpr = 544 BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(), 545 ExprValueKind::VK_LValue, FD->getLocation()); 546 if (RefExpr.isInvalid()) 547 return nullptr; 548 CtorArgExprs.push_back(RefExpr.get()); 549 } 550 551 // Create an initialization sequence for the promise type using the 552 // constructor arguments, wrapped in a parenthesized list expression. 553 Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(), 554 CtorArgExprs, FD->getLocation()); 555 InitializedEntity Entity = InitializedEntity::InitializeVariable(VD); 556 InitializationKind Kind = InitializationKind::CreateForInit( 557 VD->getLocation(), /*DirectInit=*/true, PLE); 558 InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs, 559 /*TopLevelOfInitList=*/false, 560 /*TreatUnavailableAsInvalid=*/false); 561 562 // Attempt to initialize the promise type with the arguments. 563 // If that fails, fall back to the promise type's default constructor. 564 if (InitSeq) { 565 ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs); 566 if (Result.isInvalid()) { 567 VD->setInvalidDecl(); 568 } else if (Result.get()) { 569 VD->setInit(MaybeCreateExprWithCleanups(Result.get())); 570 VD->setInitStyle(VarDecl::CallInit); 571 CheckCompleteVariableDeclaration(VD); 572 } 573 } else 574 ActOnUninitializedDecl(VD); 575 576 FD->addDecl(VD); 577 return VD; 578 } 579 580 /// Check that this is a context in which a coroutine suspension can appear. 581 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, 582 StringRef Keyword, 583 bool IsImplicit = false) { 584 if (!isValidCoroutineContext(S, Loc, Keyword)) 585 return nullptr; 586 587 assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); 588 589 auto *ScopeInfo = S.getCurFunction(); 590 assert(ScopeInfo && "missing function scope for function"); 591 592 if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) 593 ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); 594 595 if (ScopeInfo->CoroutinePromise) 596 return ScopeInfo; 597 598 if (!S.buildCoroutineParameterMoves(Loc)) 599 return nullptr; 600 601 ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); 602 if (!ScopeInfo->CoroutinePromise) 603 return nullptr; 604 605 return ScopeInfo; 606 } 607 608 bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc, 609 StringRef Keyword) { 610 if (!checkCoroutineContext(*this, KWLoc, Keyword)) 611 return false; 612 auto *ScopeInfo = getCurFunction(); 613 assert(ScopeInfo->CoroutinePromise); 614 615 // If we have existing coroutine statements then we have already built 616 // the initial and final suspend points. 617 if (!ScopeInfo->NeedsCoroutineSuspends) 618 return true; 619 620 ScopeInfo->setNeedsCoroutineSuspends(false); 621 622 auto *Fn = cast<FunctionDecl>(CurContext); 623 SourceLocation Loc = Fn->getLocation(); 624 // Build the initial suspend point 625 auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { 626 ExprResult Suspend = 627 buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None); 628 if (Suspend.isInvalid()) 629 return StmtError(); 630 Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get()); 631 if (Suspend.isInvalid()) 632 return StmtError(); 633 Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(), 634 /*IsImplicit*/ true); 635 Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false); 636 if (Suspend.isInvalid()) { 637 Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required) 638 << ((Name == "initial_suspend") ? 0 : 1); 639 Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; 640 return StmtError(); 641 } 642 return cast<Stmt>(Suspend.get()); 643 }; 644 645 StmtResult InitSuspend = buildSuspends("initial_suspend"); 646 if (InitSuspend.isInvalid()) 647 return true; 648 649 StmtResult FinalSuspend = buildSuspends("final_suspend"); 650 if (FinalSuspend.isInvalid()) 651 return true; 652 653 ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); 654 655 return true; 656 } 657 658 // Recursively walks up the scope hierarchy until either a 'catch' or a function 659 // scope is found, whichever comes first. 660 static bool isWithinCatchScope(Scope *S) { 661 // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but 662 // lambdas that use 'co_await' are allowed. The loop below ends when a 663 // function scope is found in order to ensure the following behavior: 664 // 665 // void foo() { // <- function scope 666 // try { // 667 // co_await x; // <- 'co_await' is OK within a function scope 668 // } catch { // <- catch scope 669 // co_await x; // <- 'co_await' is not OK within a catch scope 670 // []() { // <- function scope 671 // co_await x; // <- 'co_await' is OK within a function scope 672 // }(); 673 // } 674 // } 675 while (S && !(S->getFlags() & Scope::FnScope)) { 676 if (S->getFlags() & Scope::CatchScope) 677 return true; 678 S = S->getParent(); 679 } 680 return false; 681 } 682 683 // [expr.await]p2, emphasis added: "An await-expression shall appear only in 684 // a *potentially evaluated* expression within the compound-statement of a 685 // function-body *outside of a handler* [...] A context within a function 686 // where an await-expression can appear is called a suspension context of the 687 // function." 688 static void checkSuspensionContext(Sema &S, SourceLocation Loc, 689 StringRef Keyword) { 690 // First emphasis of [expr.await]p2: must be a potentially evaluated context. 691 // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of 692 // \c sizeof. 693 if (S.isUnevaluatedContext()) 694 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 695 696 // Second emphasis of [expr.await]p2: must be outside of an exception handler. 697 if (isWithinCatchScope(S.getCurScope())) 698 S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword; 699 } 700 701 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 702 if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) { 703 CorrectDelayedTyposInExpr(E); 704 return ExprError(); 705 } 706 707 checkSuspensionContext(*this, Loc, "co_await"); 708 709 if (E->getType()->isPlaceholderType()) { 710 ExprResult R = CheckPlaceholderExpr(E); 711 if (R.isInvalid()) return ExprError(); 712 E = R.get(); 713 } 714 ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); 715 if (Lookup.isInvalid()) 716 return ExprError(); 717 return BuildUnresolvedCoawaitExpr(Loc, E, 718 cast<UnresolvedLookupExpr>(Lookup.get())); 719 } 720 721 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, 722 UnresolvedLookupExpr *Lookup) { 723 auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); 724 if (!FSI) 725 return ExprError(); 726 727 if (E->getType()->isPlaceholderType()) { 728 ExprResult R = CheckPlaceholderExpr(E); 729 if (R.isInvalid()) 730 return ExprError(); 731 E = R.get(); 732 } 733 734 auto *Promise = FSI->CoroutinePromise; 735 if (Promise->getType()->isDependentType()) { 736 Expr *Res = 737 new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); 738 return Res; 739 } 740 741 auto *RD = Promise->getType()->getAsCXXRecordDecl(); 742 if (lookupMember(*this, "await_transform", RD, Loc)) { 743 ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); 744 if (R.isInvalid()) { 745 Diag(Loc, 746 diag::note_coroutine_promise_implicit_await_transform_required_here) 747 << E->getSourceRange(); 748 return ExprError(); 749 } 750 E = R.get(); 751 } 752 ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); 753 if (Awaitable.isInvalid()) 754 return ExprError(); 755 756 return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); 757 } 758 759 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, 760 bool IsImplicit) { 761 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); 762 if (!Coroutine) 763 return ExprError(); 764 765 if (E->getType()->isPlaceholderType()) { 766 ExprResult R = CheckPlaceholderExpr(E); 767 if (R.isInvalid()) return ExprError(); 768 E = R.get(); 769 } 770 771 if (E->getType()->isDependentType()) { 772 Expr *Res = new (Context) 773 CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); 774 return Res; 775 } 776 777 // If the expression is a temporary, materialize it as an lvalue so that we 778 // can use it multiple times. 779 if (E->getValueKind() == VK_RValue) 780 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 781 782 // The location of the `co_await` token cannot be used when constructing 783 // the member call expressions since it's before the location of `Expr`, which 784 // is used as the start of the member call expression. 785 SourceLocation CallLoc = E->getExprLoc(); 786 787 // Build the await_ready, await_suspend, await_resume calls. 788 ReadySuspendResumeResult RSS = 789 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E); 790 if (RSS.IsInvalid) 791 return ExprError(); 792 793 Expr *Res = 794 new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 795 RSS.Results[2], RSS.OpaqueValue, IsImplicit); 796 797 return Res; 798 } 799 800 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 801 if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) { 802 CorrectDelayedTyposInExpr(E); 803 return ExprError(); 804 } 805 806 checkSuspensionContext(*this, Loc, "co_yield"); 807 808 // Build yield_value call. 809 ExprResult Awaitable = buildPromiseCall( 810 *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); 811 if (Awaitable.isInvalid()) 812 return ExprError(); 813 814 // Build 'operator co_await' call. 815 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 816 if (Awaitable.isInvalid()) 817 return ExprError(); 818 819 return BuildCoyieldExpr(Loc, Awaitable.get()); 820 } 821 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 822 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 823 if (!Coroutine) 824 return ExprError(); 825 826 if (E->getType()->isPlaceholderType()) { 827 ExprResult R = CheckPlaceholderExpr(E); 828 if (R.isInvalid()) return ExprError(); 829 E = R.get(); 830 } 831 832 if (E->getType()->isDependentType()) { 833 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 834 return Res; 835 } 836 837 // If the expression is a temporary, materialize it as an lvalue so that we 838 // can use it multiple times. 839 if (E->getValueKind() == VK_RValue) 840 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 841 842 // Build the await_ready, await_suspend, await_resume calls. 843 ReadySuspendResumeResult RSS = 844 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 845 if (RSS.IsInvalid) 846 return ExprError(); 847 848 Expr *Res = 849 new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 850 RSS.Results[2], RSS.OpaqueValue); 851 852 return Res; 853 } 854 855 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { 856 if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) { 857 CorrectDelayedTyposInExpr(E); 858 return StmtError(); 859 } 860 return BuildCoreturnStmt(Loc, E); 861 } 862 863 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, 864 bool IsImplicit) { 865 auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); 866 if (!FSI) 867 return StmtError(); 868 869 if (E && E->getType()->isPlaceholderType() && 870 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 871 ExprResult R = CheckPlaceholderExpr(E); 872 if (R.isInvalid()) return StmtError(); 873 E = R.get(); 874 } 875 876 // Move the return value if we can 877 if (E) { 878 auto NRVOCandidate = this->getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove); 879 if (NRVOCandidate) { 880 InitializedEntity Entity = 881 InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate); 882 ExprResult MoveResult = this->PerformMoveOrCopyInitialization( 883 Entity, NRVOCandidate, E->getType(), E); 884 if (MoveResult.get()) 885 E = MoveResult.get(); 886 } 887 } 888 889 // FIXME: If the operand is a reference to a variable that's about to go out 890 // of scope, we should treat the operand as an xvalue for this overload 891 // resolution. 892 VarDecl *Promise = FSI->CoroutinePromise; 893 ExprResult PC; 894 if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { 895 PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); 896 } else { 897 E = MakeFullDiscardedValueExpr(E).get(); 898 PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); 899 } 900 if (PC.isInvalid()) 901 return StmtError(); 902 903 Expr *PCE = ActOnFinishFullExpr(PC.get(), /*DiscardedValue*/ false).get(); 904 905 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); 906 return Res; 907 } 908 909 /// Look up the std::nothrow object. 910 static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) { 911 NamespaceDecl *Std = S.getStdNamespace(); 912 assert(Std && "Should already be diagnosed"); 913 914 LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc, 915 Sema::LookupOrdinaryName); 916 if (!S.LookupQualifiedName(Result, Std)) { 917 // FIXME: <experimental/coroutine> should have been included already. 918 // If we require it to include <new> then this diagnostic is no longer 919 // needed. 920 S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found); 921 return nullptr; 922 } 923 924 auto *VD = Result.getAsSingle<VarDecl>(); 925 if (!VD) { 926 Result.suppressDiagnostics(); 927 // We found something weird. Complain about the first thing we found. 928 NamedDecl *Found = *Result.begin(); 929 S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow); 930 return nullptr; 931 } 932 933 ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc); 934 if (DR.isInvalid()) 935 return nullptr; 936 937 return DR.get(); 938 } 939 940 // Find an appropriate delete for the promise. 941 static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, 942 QualType PromiseType) { 943 FunctionDecl *OperatorDelete = nullptr; 944 945 DeclarationName DeleteName = 946 S.Context.DeclarationNames.getCXXOperatorName(OO_Delete); 947 948 auto *PointeeRD = PromiseType->getAsCXXRecordDecl(); 949 assert(PointeeRD && "PromiseType must be a CxxRecordDecl type"); 950 951 if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete)) 952 return nullptr; 953 954 if (!OperatorDelete) { 955 // Look for a global declaration. 956 const bool CanProvideSize = S.isCompleteType(Loc, PromiseType); 957 const bool Overaligned = false; 958 OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize, 959 Overaligned, DeleteName); 960 } 961 S.MarkFunctionReferenced(Loc, OperatorDelete); 962 return OperatorDelete; 963 } 964 965 966 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { 967 FunctionScopeInfo *Fn = getCurFunction(); 968 assert(Fn && Fn->isCoroutine() && "not a coroutine"); 969 if (!Body) { 970 assert(FD->isInvalidDecl() && 971 "a null body is only allowed for invalid declarations"); 972 return; 973 } 974 // We have a function that uses coroutine keywords, but we failed to build 975 // the promise type. 976 if (!Fn->CoroutinePromise) 977 return FD->setInvalidDecl(); 978 979 if (isa<CoroutineBodyStmt>(Body)) { 980 // Nothing todo. the body is already a transformed coroutine body statement. 981 return; 982 } 983 984 // Coroutines [stmt.return]p1: 985 // A return statement shall not appear in a coroutine. 986 if (Fn->FirstReturnLoc.isValid()) { 987 assert(Fn->FirstCoroutineStmtLoc.isValid() && 988 "first coroutine location not set"); 989 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 990 Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 991 << Fn->getFirstCoroutineStmtKeyword(); 992 } 993 CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body); 994 if (Builder.isInvalid() || !Builder.buildStatements()) 995 return FD->setInvalidDecl(); 996 997 // Build body for the coroutine wrapper statement. 998 Body = CoroutineBodyStmt::Create(Context, Builder); 999 } 1000 1001 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, 1002 sema::FunctionScopeInfo &Fn, 1003 Stmt *Body) 1004 : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), 1005 IsPromiseDependentType( 1006 !Fn.CoroutinePromise || 1007 Fn.CoroutinePromise->getType()->isDependentType()) { 1008 this->Body = Body; 1009 1010 for (auto KV : Fn.CoroutineParameterMoves) 1011 this->ParamMovesVector.push_back(KV.second); 1012 this->ParamMoves = this->ParamMovesVector; 1013 1014 if (!IsPromiseDependentType) { 1015 PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); 1016 assert(PromiseRecordDecl && "Type should have already been checked"); 1017 } 1018 this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend(); 1019 } 1020 1021 bool CoroutineStmtBuilder::buildStatements() { 1022 assert(this->IsValid && "coroutine already invalid"); 1023 this->IsValid = makeReturnObject(); 1024 if (this->IsValid && !IsPromiseDependentType) 1025 buildDependentStatements(); 1026 return this->IsValid; 1027 } 1028 1029 bool CoroutineStmtBuilder::buildDependentStatements() { 1030 assert(this->IsValid && "coroutine already invalid"); 1031 assert(!this->IsPromiseDependentType && 1032 "coroutine cannot have a dependent promise type"); 1033 this->IsValid = makeOnException() && makeOnFallthrough() && 1034 makeGroDeclAndReturnStmt() && makeReturnOnAllocFailure() && 1035 makeNewAndDeleteExpr(); 1036 return this->IsValid; 1037 } 1038 1039 bool CoroutineStmtBuilder::makePromiseStmt() { 1040 // Form a declaration statement for the promise declaration, so that AST 1041 // visitors can more easily find it. 1042 StmtResult PromiseStmt = 1043 S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); 1044 if (PromiseStmt.isInvalid()) 1045 return false; 1046 1047 this->Promise = PromiseStmt.get(); 1048 return true; 1049 } 1050 1051 bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() { 1052 if (Fn.hasInvalidCoroutineSuspends()) 1053 return false; 1054 this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first); 1055 this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second); 1056 return true; 1057 } 1058 1059 static bool diagReturnOnAllocFailure(Sema &S, Expr *E, 1060 CXXRecordDecl *PromiseRecordDecl, 1061 FunctionScopeInfo &Fn) { 1062 auto Loc = E->getExprLoc(); 1063 if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) { 1064 auto *Decl = DeclRef->getDecl(); 1065 if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) { 1066 if (Method->isStatic()) 1067 return true; 1068 else 1069 Loc = Decl->getLocation(); 1070 } 1071 } 1072 1073 S.Diag( 1074 Loc, 1075 diag::err_coroutine_promise_get_return_object_on_allocation_failure) 1076 << PromiseRecordDecl; 1077 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1078 << Fn.getFirstCoroutineStmtKeyword(); 1079 return false; 1080 } 1081 1082 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() { 1083 assert(!IsPromiseDependentType && 1084 "cannot make statement while the promise type is dependent"); 1085 1086 // [dcl.fct.def.coroutine]/8 1087 // The unqualified-id get_return_object_on_allocation_failure is looked up in 1088 // the scope of class P by class member access lookup (3.4.5). ... 1089 // If an allocation function returns nullptr, ... the coroutine return value 1090 // is obtained by a call to ... get_return_object_on_allocation_failure(). 1091 1092 DeclarationName DN = 1093 S.PP.getIdentifierInfo("get_return_object_on_allocation_failure"); 1094 LookupResult Found(S, DN, Loc, Sema::LookupMemberName); 1095 if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) 1096 return true; 1097 1098 CXXScopeSpec SS; 1099 ExprResult DeclNameExpr = 1100 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 1101 if (DeclNameExpr.isInvalid()) 1102 return false; 1103 1104 if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)) 1105 return false; 1106 1107 ExprResult ReturnObjectOnAllocationFailure = 1108 S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc); 1109 if (ReturnObjectOnAllocationFailure.isInvalid()) 1110 return false; 1111 1112 StmtResult ReturnStmt = 1113 S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get()); 1114 if (ReturnStmt.isInvalid()) { 1115 S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here) 1116 << DN; 1117 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1118 << Fn.getFirstCoroutineStmtKeyword(); 1119 return false; 1120 } 1121 1122 this->ReturnStmtOnAllocFailure = ReturnStmt.get(); 1123 return true; 1124 } 1125 1126 bool CoroutineStmtBuilder::makeNewAndDeleteExpr() { 1127 // Form and check allocation and deallocation calls. 1128 assert(!IsPromiseDependentType && 1129 "cannot make statement while the promise type is dependent"); 1130 QualType PromiseType = Fn.CoroutinePromise->getType(); 1131 1132 if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) 1133 return false; 1134 1135 const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr; 1136 1137 // [dcl.fct.def.coroutine]/7 1138 // Lookup allocation functions using a parameter list composed of the 1139 // requested size of the coroutine state being allocated, followed by 1140 // the coroutine function's arguments. If a matching allocation function 1141 // exists, use it. Otherwise, use an allocation function that just takes 1142 // the requested size. 1143 1144 FunctionDecl *OperatorNew = nullptr; 1145 FunctionDecl *OperatorDelete = nullptr; 1146 FunctionDecl *UnusedResult = nullptr; 1147 bool PassAlignment = false; 1148 SmallVector<Expr *, 1> PlacementArgs; 1149 1150 // [dcl.fct.def.coroutine]/7 1151 // "The allocation function’s name is looked up in the scope of P. 1152 // [...] If the lookup finds an allocation function in the scope of P, 1153 // overload resolution is performed on a function call created by assembling 1154 // an argument list. The first argument is the amount of space requested, 1155 // and has type std::size_t. The lvalues p1 ... pn are the succeeding 1156 // arguments." 1157 // 1158 // ...where "p1 ... pn" are defined earlier as: 1159 // 1160 // [dcl.fct.def.coroutine]/3 1161 // "For a coroutine f that is a non-static member function, let P1 denote the 1162 // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types 1163 // of the function parameters; otherwise let P1 ... Pn be the types of the 1164 // function parameters. Let p1 ... pn be lvalues denoting those objects." 1165 if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) { 1166 if (MD->isInstance() && !isLambdaCallOperator(MD)) { 1167 ExprResult ThisExpr = S.ActOnCXXThis(Loc); 1168 if (ThisExpr.isInvalid()) 1169 return false; 1170 ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get()); 1171 if (ThisExpr.isInvalid()) 1172 return false; 1173 PlacementArgs.push_back(ThisExpr.get()); 1174 } 1175 } 1176 for (auto *PD : FD.parameters()) { 1177 if (PD->getType()->isDependentType()) 1178 continue; 1179 1180 // Build a reference to the parameter. 1181 auto PDLoc = PD->getLocation(); 1182 ExprResult PDRefExpr = 1183 S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), 1184 ExprValueKind::VK_LValue, PDLoc); 1185 if (PDRefExpr.isInvalid()) 1186 return false; 1187 1188 PlacementArgs.push_back(PDRefExpr.get()); 1189 } 1190 S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class, 1191 /*DeleteScope*/ Sema::AFS_Both, PromiseType, 1192 /*isArray*/ false, PassAlignment, PlacementArgs, 1193 OperatorNew, UnusedResult, /*Diagnose*/ false); 1194 1195 // [dcl.fct.def.coroutine]/7 1196 // "If no matching function is found, overload resolution is performed again 1197 // on a function call created by passing just the amount of space required as 1198 // an argument of type std::size_t." 1199 if (!OperatorNew && !PlacementArgs.empty()) { 1200 PlacementArgs.clear(); 1201 S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class, 1202 /*DeleteScope*/ Sema::AFS_Both, PromiseType, 1203 /*isArray*/ false, PassAlignment, PlacementArgs, 1204 OperatorNew, UnusedResult, /*Diagnose*/ false); 1205 } 1206 1207 // [dcl.fct.def.coroutine]/7 1208 // "The allocation function’s name is looked up in the scope of P. If this 1209 // lookup fails, the allocation function’s name is looked up in the global 1210 // scope." 1211 if (!OperatorNew) { 1212 S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global, 1213 /*DeleteScope*/ Sema::AFS_Both, PromiseType, 1214 /*isArray*/ false, PassAlignment, PlacementArgs, 1215 OperatorNew, UnusedResult); 1216 } 1217 1218 bool IsGlobalOverload = 1219 OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext()); 1220 // If we didn't find a class-local new declaration and non-throwing new 1221 // was is required then we need to lookup the non-throwing global operator 1222 // instead. 1223 if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) { 1224 auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc); 1225 if (!StdNoThrow) 1226 return false; 1227 PlacementArgs = {StdNoThrow}; 1228 OperatorNew = nullptr; 1229 S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both, 1230 /*DeleteScope*/ Sema::AFS_Both, PromiseType, 1231 /*isArray*/ false, PassAlignment, PlacementArgs, 1232 OperatorNew, UnusedResult); 1233 } 1234 1235 if (!OperatorNew) 1236 return false; 1237 1238 if (RequiresNoThrowAlloc) { 1239 const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>(); 1240 if (!FT->isNothrow(/*ResultIfDependent*/ false)) { 1241 S.Diag(OperatorNew->getLocation(), 1242 diag::err_coroutine_promise_new_requires_nothrow) 1243 << OperatorNew; 1244 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 1245 << OperatorNew; 1246 return false; 1247 } 1248 } 1249 1250 if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr) 1251 return false; 1252 1253 Expr *FramePtr = 1254 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 1255 1256 Expr *FrameSize = 1257 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); 1258 1259 // Make new call. 1260 1261 ExprResult NewRef = 1262 S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); 1263 if (NewRef.isInvalid()) 1264 return false; 1265 1266 SmallVector<Expr *, 2> NewArgs(1, FrameSize); 1267 for (auto Arg : PlacementArgs) 1268 NewArgs.push_back(Arg); 1269 1270 ExprResult NewExpr = 1271 S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc); 1272 NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false); 1273 if (NewExpr.isInvalid()) 1274 return false; 1275 1276 // Make delete call. 1277 1278 QualType OpDeleteQualType = OperatorDelete->getType(); 1279 1280 ExprResult DeleteRef = 1281 S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); 1282 if (DeleteRef.isInvalid()) 1283 return false; 1284 1285 Expr *CoroFree = 1286 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); 1287 1288 SmallVector<Expr *, 2> DeleteArgs{CoroFree}; 1289 1290 // Check if we need to pass the size. 1291 const auto *OpDeleteType = 1292 OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); 1293 if (OpDeleteType->getNumParams() > 1) 1294 DeleteArgs.push_back(FrameSize); 1295 1296 ExprResult DeleteExpr = 1297 S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); 1298 DeleteExpr = 1299 S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false); 1300 if (DeleteExpr.isInvalid()) 1301 return false; 1302 1303 this->Allocate = NewExpr.get(); 1304 this->Deallocate = DeleteExpr.get(); 1305 1306 return true; 1307 } 1308 1309 bool CoroutineStmtBuilder::makeOnFallthrough() { 1310 assert(!IsPromiseDependentType && 1311 "cannot make statement while the promise type is dependent"); 1312 1313 // [dcl.fct.def.coroutine]/4 1314 // The unqualified-ids 'return_void' and 'return_value' are looked up in 1315 // the scope of class P. If both are found, the program is ill-formed. 1316 bool HasRVoid, HasRValue; 1317 LookupResult LRVoid = 1318 lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid); 1319 LookupResult LRValue = 1320 lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue); 1321 1322 StmtResult Fallthrough; 1323 if (HasRVoid && HasRValue) { 1324 // FIXME Improve this diagnostic 1325 S.Diag(FD.getLocation(), 1326 diag::err_coroutine_promise_incompatible_return_functions) 1327 << PromiseRecordDecl; 1328 S.Diag(LRVoid.getRepresentativeDecl()->getLocation(), 1329 diag::note_member_first_declared_here) 1330 << LRVoid.getLookupName(); 1331 S.Diag(LRValue.getRepresentativeDecl()->getLocation(), 1332 diag::note_member_first_declared_here) 1333 << LRValue.getLookupName(); 1334 return false; 1335 } else if (!HasRVoid && !HasRValue) { 1336 // FIXME: The PDTS currently specifies this case as UB, not ill-formed. 1337 // However we still diagnose this as an error since until the PDTS is fixed. 1338 S.Diag(FD.getLocation(), 1339 diag::err_coroutine_promise_requires_return_function) 1340 << PromiseRecordDecl; 1341 S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here) 1342 << PromiseRecordDecl; 1343 return false; 1344 } else if (HasRVoid) { 1345 // If the unqualified-id return_void is found, flowing off the end of a 1346 // coroutine is equivalent to a co_return with no operand. Otherwise, 1347 // flowing off the end of a coroutine results in undefined behavior. 1348 Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, 1349 /*IsImplicit*/false); 1350 Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); 1351 if (Fallthrough.isInvalid()) 1352 return false; 1353 } 1354 1355 this->OnFallthrough = Fallthrough.get(); 1356 return true; 1357 } 1358 1359 bool CoroutineStmtBuilder::makeOnException() { 1360 // Try to form 'p.unhandled_exception();' 1361 assert(!IsPromiseDependentType && 1362 "cannot make statement while the promise type is dependent"); 1363 1364 const bool RequireUnhandledException = S.getLangOpts().CXXExceptions; 1365 1366 if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) { 1367 auto DiagID = 1368 RequireUnhandledException 1369 ? diag::err_coroutine_promise_unhandled_exception_required 1370 : diag:: 1371 warn_coroutine_promise_unhandled_exception_required_with_exceptions; 1372 S.Diag(Loc, DiagID) << PromiseRecordDecl; 1373 S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here) 1374 << PromiseRecordDecl; 1375 return !RequireUnhandledException; 1376 } 1377 1378 // If exceptions are disabled, don't try to build OnException. 1379 if (!S.getLangOpts().CXXExceptions) 1380 return true; 1381 1382 ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, 1383 "unhandled_exception", None); 1384 UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc, 1385 /*DiscardedValue*/ false); 1386 if (UnhandledException.isInvalid()) 1387 return false; 1388 1389 // Since the body of the coroutine will be wrapped in try-catch, it will 1390 // be incompatible with SEH __try if present in a function. 1391 if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) { 1392 S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions); 1393 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1394 << Fn.getFirstCoroutineStmtKeyword(); 1395 return false; 1396 } 1397 1398 this->OnException = UnhandledException.get(); 1399 return true; 1400 } 1401 1402 bool CoroutineStmtBuilder::makeReturnObject() { 1403 // Build implicit 'p.get_return_object()' expression and form initialization 1404 // of return type from it. 1405 ExprResult ReturnObject = 1406 buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); 1407 if (ReturnObject.isInvalid()) 1408 return false; 1409 1410 this->ReturnValue = ReturnObject.get(); 1411 return true; 1412 } 1413 1414 static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) { 1415 if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) { 1416 auto *MethodDecl = MbrRef->getMethodDecl(); 1417 S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here) 1418 << MethodDecl; 1419 } 1420 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1421 << Fn.getFirstCoroutineStmtKeyword(); 1422 } 1423 1424 bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() { 1425 assert(!IsPromiseDependentType && 1426 "cannot make statement while the promise type is dependent"); 1427 assert(this->ReturnValue && "ReturnValue must be already formed"); 1428 1429 QualType const GroType = this->ReturnValue->getType(); 1430 assert(!GroType->isDependentType() && 1431 "get_return_object type must no longer be dependent"); 1432 1433 QualType const FnRetType = FD.getReturnType(); 1434 assert(!FnRetType->isDependentType() && 1435 "get_return_object type must no longer be dependent"); 1436 1437 if (FnRetType->isVoidType()) { 1438 ExprResult Res = 1439 S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false); 1440 if (Res.isInvalid()) 1441 return false; 1442 1443 this->ResultDecl = Res.get(); 1444 return true; 1445 } 1446 1447 if (GroType->isVoidType()) { 1448 // Trigger a nice error message. 1449 InitializedEntity Entity = 1450 InitializedEntity::InitializeResult(Loc, FnRetType, false); 1451 S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue); 1452 noteMemberDeclaredHere(S, ReturnValue, Fn); 1453 return false; 1454 } 1455 1456 auto *GroDecl = VarDecl::Create( 1457 S.Context, &FD, FD.getLocation(), FD.getLocation(), 1458 &S.PP.getIdentifierTable().get("__coro_gro"), GroType, 1459 S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None); 1460 1461 S.CheckVariableDeclarationType(GroDecl); 1462 if (GroDecl->isInvalidDecl()) 1463 return false; 1464 1465 InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl); 1466 ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType, 1467 this->ReturnValue); 1468 if (Res.isInvalid()) 1469 return false; 1470 1471 Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false); 1472 if (Res.isInvalid()) 1473 return false; 1474 1475 S.AddInitializerToDecl(GroDecl, Res.get(), 1476 /*DirectInit=*/false); 1477 1478 S.FinalizeDeclaration(GroDecl); 1479 1480 // Form a declaration statement for the return declaration, so that AST 1481 // visitors can more easily find it. 1482 StmtResult GroDeclStmt = 1483 S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc); 1484 if (GroDeclStmt.isInvalid()) 1485 return false; 1486 1487 this->ResultDecl = GroDeclStmt.get(); 1488 1489 ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc); 1490 if (declRef.isInvalid()) 1491 return false; 1492 1493 StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get()); 1494 if (ReturnStmt.isInvalid()) { 1495 noteMemberDeclaredHere(S, ReturnValue, Fn); 1496 return false; 1497 } 1498 if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl) 1499 GroDecl->setNRVOVariable(true); 1500 1501 this->ReturnStmt = ReturnStmt.get(); 1502 return true; 1503 } 1504 1505 // Create a static_cast\<T&&>(expr). 1506 static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) { 1507 if (T.isNull()) 1508 T = E->getType(); 1509 QualType TargetType = S.BuildReferenceType( 1510 T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName()); 1511 SourceLocation ExprLoc = E->getBeginLoc(); 1512 TypeSourceInfo *TargetLoc = 1513 S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc); 1514 1515 return S 1516 .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E, 1517 SourceRange(ExprLoc, ExprLoc), E->getSourceRange()) 1518 .get(); 1519 } 1520 1521 /// Build a variable declaration for move parameter. 1522 static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type, 1523 IdentifierInfo *II) { 1524 TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc); 1525 VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, 1526 TInfo, SC_None); 1527 Decl->setImplicit(); 1528 return Decl; 1529 } 1530 1531 // Build statements that move coroutine function parameters to the coroutine 1532 // frame, and store them on the function scope info. 1533 bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) { 1534 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 1535 auto *FD = cast<FunctionDecl>(CurContext); 1536 1537 auto *ScopeInfo = getCurFunction(); 1538 assert(ScopeInfo->CoroutineParameterMoves.empty() && 1539 "Should not build parameter moves twice"); 1540 1541 for (auto *PD : FD->parameters()) { 1542 if (PD->getType()->isDependentType()) 1543 continue; 1544 1545 ExprResult PDRefExpr = 1546 BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(), 1547 ExprValueKind::VK_LValue, Loc); // FIXME: scope? 1548 if (PDRefExpr.isInvalid()) 1549 return false; 1550 1551 Expr *CExpr = nullptr; 1552 if (PD->getType()->getAsCXXRecordDecl() || 1553 PD->getType()->isRValueReferenceType()) 1554 CExpr = castForMoving(*this, PDRefExpr.get()); 1555 else 1556 CExpr = PDRefExpr.get(); 1557 1558 auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); 1559 AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); 1560 1561 // Convert decl to a statement. 1562 StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); 1563 if (Stmt.isInvalid()) 1564 return false; 1565 1566 ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); 1567 } 1568 return true; 1569 } 1570 1571 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { 1572 CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); 1573 if (!Res) 1574 return StmtError(); 1575 return Res; 1576 } 1577 1578 ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc, 1579 SourceLocation FuncLoc) { 1580 if (!StdCoroutineTraitsCache) { 1581 if (auto StdExp = lookupStdExperimentalNamespace()) { 1582 LookupResult Result(*this, 1583 &PP.getIdentifierTable().get("coroutine_traits"), 1584 FuncLoc, LookupOrdinaryName); 1585 if (!LookupQualifiedName(Result, StdExp)) { 1586 Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 1587 << "std::experimental::coroutine_traits"; 1588 return nullptr; 1589 } 1590 if (!(StdCoroutineTraitsCache = 1591 Result.getAsSingle<ClassTemplateDecl>())) { 1592 Result.suppressDiagnostics(); 1593 NamedDecl *Found = *Result.begin(); 1594 Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 1595 return nullptr; 1596 } 1597 } 1598 } 1599 return StdCoroutineTraitsCache; 1600 } 1601