1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CoroutineStmtBuilder.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Initialization.h" 20 #include "clang/Sema/Overload.h" 21 #include "clang/Sema/SemaInternal.h" 22 23 using namespace clang; 24 using namespace sema; 25 26 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 27 SourceLocation Loc) { 28 DeclarationName DN = S.PP.getIdentifierInfo(Name); 29 LookupResult LR(S, DN, Loc, Sema::LookupMemberName); 30 // Suppress diagnostics when a private member is selected. The same warnings 31 // will be produced again when building the call. 32 LR.suppressDiagnostics(); 33 return S.LookupQualifiedName(LR, RD); 34 } 35 36 /// Look up the std::coroutine_traits<...>::promise_type for the given 37 /// function type. 38 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, 39 SourceLocation KwLoc, 40 SourceLocation FuncLoc) { 41 // FIXME: Cache std::coroutine_traits once we've found it. 42 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 43 if (!StdExp) { 44 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 45 << "std::experimental::coroutine_traits"; 46 return QualType(); 47 } 48 49 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 50 FuncLoc, Sema::LookupOrdinaryName); 51 if (!S.LookupQualifiedName(Result, StdExp)) { 52 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 53 << "std::experimental::coroutine_traits"; 54 return QualType(); 55 } 56 57 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 58 if (!CoroTraits) { 59 Result.suppressDiagnostics(); 60 // We found something weird. Complain about the first thing we found. 61 NamedDecl *Found = *Result.begin(); 62 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 63 return QualType(); 64 } 65 66 // Form template argument list for coroutine_traits<R, P1, P2, ...>. 67 TemplateArgumentListInfo Args(KwLoc, KwLoc); 68 Args.addArgument(TemplateArgumentLoc( 69 TemplateArgument(FnType->getReturnType()), 70 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc))); 71 // FIXME: If the function is a non-static member function, add the type 72 // of the implicit object parameter before the formal parameters. 73 for (QualType T : FnType->getParamTypes()) 74 Args.addArgument(TemplateArgumentLoc( 75 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); 76 77 // Build the template-id. 78 QualType CoroTrait = 79 S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); 80 if (CoroTrait.isNull()) 81 return QualType(); 82 if (S.RequireCompleteType(KwLoc, CoroTrait, 83 diag::err_coroutine_type_missing_specialization)) 84 return QualType(); 85 86 auto *RD = CoroTrait->getAsCXXRecordDecl(); 87 assert(RD && "specialization of class template is not a class?"); 88 89 // Look up the ::promise_type member. 90 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, 91 Sema::LookupOrdinaryName); 92 S.LookupQualifiedName(R, RD); 93 auto *Promise = R.getAsSingle<TypeDecl>(); 94 if (!Promise) { 95 S.Diag(FuncLoc, 96 diag::err_implied_std_coroutine_traits_promise_type_not_found) 97 << RD; 98 return QualType(); 99 } 100 // The promise type is required to be a class type. 101 QualType PromiseType = S.Context.getTypeDeclType(Promise); 102 103 auto buildElaboratedType = [&]() { 104 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); 105 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 106 CoroTrait.getTypePtr()); 107 return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 108 }; 109 110 if (!PromiseType->getAsCXXRecordDecl()) { 111 S.Diag(FuncLoc, 112 diag::err_implied_std_coroutine_traits_promise_type_not_class) 113 << buildElaboratedType(); 114 return QualType(); 115 } 116 if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), 117 diag::err_coroutine_promise_type_incomplete)) 118 return QualType(); 119 120 return PromiseType; 121 } 122 123 /// Look up the std::coroutine_traits<...>::promise_type for the given 124 /// function type. 125 static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, 126 SourceLocation Loc) { 127 if (PromiseType.isNull()) 128 return QualType(); 129 130 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 131 assert(StdExp && "Should already be diagnosed"); 132 133 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"), 134 Loc, Sema::LookupOrdinaryName); 135 if (!S.LookupQualifiedName(Result, StdExp)) { 136 S.Diag(Loc, diag::err_implied_coroutine_type_not_found) 137 << "std::experimental::coroutine_handle"; 138 return QualType(); 139 } 140 141 ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); 142 if (!CoroHandle) { 143 Result.suppressDiagnostics(); 144 // We found something weird. Complain about the first thing we found. 145 NamedDecl *Found = *Result.begin(); 146 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); 147 return QualType(); 148 } 149 150 // Form template argument list for coroutine_handle<Promise>. 151 TemplateArgumentListInfo Args(Loc, Loc); 152 Args.addArgument(TemplateArgumentLoc( 153 TemplateArgument(PromiseType), 154 S.Context.getTrivialTypeSourceInfo(PromiseType, Loc))); 155 156 // Build the template-id. 157 QualType CoroHandleType = 158 S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args); 159 if (CoroHandleType.isNull()) 160 return QualType(); 161 if (S.RequireCompleteType(Loc, CoroHandleType, 162 diag::err_coroutine_type_missing_specialization)) 163 return QualType(); 164 165 return CoroHandleType; 166 } 167 168 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, 169 StringRef Keyword) { 170 // 'co_await' and 'co_yield' are not permitted in unevaluated operands. 171 if (S.isUnevaluatedContext()) { 172 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 173 return false; 174 } 175 176 // Any other usage must be within a function. 177 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 178 if (!FD) { 179 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 180 ? diag::err_coroutine_objc_method 181 : diag::err_coroutine_outside_function) << Keyword; 182 return false; 183 } 184 185 // An enumeration for mapping the diagnostic type to the correct diagnostic 186 // selection index. 187 enum InvalidFuncDiag { 188 DiagCtor = 0, 189 DiagDtor, 190 DiagCopyAssign, 191 DiagMoveAssign, 192 DiagMain, 193 DiagConstexpr, 194 DiagAutoRet, 195 DiagVarargs, 196 }; 197 bool Diagnosed = false; 198 auto DiagInvalid = [&](InvalidFuncDiag ID) { 199 S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword; 200 Diagnosed = true; 201 return false; 202 }; 203 204 // Diagnose when a constructor, destructor, copy/move assignment operator, 205 // or the function 'main' are declared as a coroutine. 206 auto *MD = dyn_cast<CXXMethodDecl>(FD); 207 if (MD && isa<CXXConstructorDecl>(MD)) 208 return DiagInvalid(DiagCtor); 209 else if (MD && isa<CXXDestructorDecl>(MD)) 210 return DiagInvalid(DiagDtor); 211 else if (MD && MD->isCopyAssignmentOperator()) 212 return DiagInvalid(DiagCopyAssign); 213 else if (MD && MD->isMoveAssignmentOperator()) 214 return DiagInvalid(DiagMoveAssign); 215 else if (FD->isMain()) 216 return DiagInvalid(DiagMain); 217 218 // Emit a diagnostics for each of the following conditions which is not met. 219 if (FD->isConstexpr()) 220 DiagInvalid(DiagConstexpr); 221 if (FD->getReturnType()->isUndeducedType()) 222 DiagInvalid(DiagAutoRet); 223 if (FD->isVariadic()) 224 DiagInvalid(DiagVarargs); 225 226 return !Diagnosed; 227 } 228 229 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, 230 SourceLocation Loc) { 231 DeclarationName OpName = 232 SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); 233 LookupResult Operators(SemaRef, OpName, SourceLocation(), 234 Sema::LookupOperatorName); 235 SemaRef.LookupName(Operators, S); 236 237 assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); 238 const auto &Functions = Operators.asUnresolvedSet(); 239 bool IsOverloaded = 240 Functions.size() > 1 || 241 (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); 242 Expr *CoawaitOp = UnresolvedLookupExpr::Create( 243 SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), 244 DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, 245 Functions.begin(), Functions.end()); 246 assert(CoawaitOp); 247 return CoawaitOp; 248 } 249 250 /// Build a call to 'operator co_await' if there is a suitable operator for 251 /// the given expression. 252 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, 253 Expr *E, 254 UnresolvedLookupExpr *Lookup) { 255 UnresolvedSet<16> Functions; 256 Functions.append(Lookup->decls_begin(), Lookup->decls_end()); 257 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 258 } 259 260 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 261 SourceLocation Loc, Expr *E) { 262 ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); 263 if (R.isInvalid()) 264 return ExprError(); 265 return buildOperatorCoawaitCall(SemaRef, Loc, E, 266 cast<UnresolvedLookupExpr>(R.get())); 267 } 268 269 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, 270 MultiExprArg CallArgs) { 271 StringRef Name = S.Context.BuiltinInfo.getName(Id); 272 LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); 273 S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); 274 275 auto *BuiltInDecl = R.getAsSingle<FunctionDecl>(); 276 assert(BuiltInDecl && "failed to find builtin declaration"); 277 278 ExprResult DeclRef = 279 S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc); 280 assert(DeclRef.isUsable() && "Builtin reference cannot fail"); 281 282 ExprResult Call = 283 S.ActOnCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc); 284 285 assert(!Call.isInvalid() && "Call to builtin cannot fail!"); 286 return Call.get(); 287 } 288 289 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType, 290 SourceLocation Loc) { 291 QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc); 292 if (CoroHandleType.isNull()) 293 return ExprError(); 294 295 DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType); 296 LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc, 297 Sema::LookupOrdinaryName); 298 if (!S.LookupQualifiedName(Found, LookupCtx)) { 299 S.Diag(Loc, diag::err_coroutine_handle_missing_member) 300 << "from_address"; 301 return ExprError(); 302 } 303 304 Expr *FramePtr = 305 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 306 307 CXXScopeSpec SS; 308 ExprResult FromAddr = 309 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 310 if (FromAddr.isInvalid()) 311 return ExprError(); 312 313 return S.ActOnCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc); 314 } 315 316 struct ReadySuspendResumeResult { 317 Expr *Results[3]; 318 OpaqueValueExpr *OpaqueValue; 319 bool IsInvalid; 320 }; 321 322 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 323 StringRef Name, MultiExprArg Args) { 324 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 325 326 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 327 CXXScopeSpec SS; 328 ExprResult Result = S.BuildMemberReferenceExpr( 329 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 330 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 331 /*Scope=*/nullptr); 332 if (Result.isInvalid()) 333 return ExprError(); 334 335 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 336 } 337 338 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 339 /// expression. 340 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, 341 SourceLocation Loc, Expr *E) { 342 OpaqueValueExpr *Operand = new (S.Context) 343 OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 344 345 // Assume invalid until we see otherwise. 346 ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; 347 348 ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); 349 if (CoroHandleRes.isInvalid()) 350 return Calls; 351 Expr *CoroHandle = CoroHandleRes.get(); 352 353 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 354 MultiExprArg Args[] = {None, CoroHandle, None}; 355 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 356 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); 357 if (Result.isInvalid()) 358 return Calls; 359 Calls.Results[I] = Result.get(); 360 } 361 362 Calls.IsInvalid = false; 363 return Calls; 364 } 365 366 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, 367 SourceLocation Loc, StringRef Name, 368 MultiExprArg Args) { 369 370 // Form a reference to the promise. 371 ExprResult PromiseRef = S.BuildDeclRefExpr( 372 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 373 if (PromiseRef.isInvalid()) 374 return ExprError(); 375 376 // Call 'yield_value', passing in E. 377 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 378 } 379 380 VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { 381 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 382 auto *FD = cast<FunctionDecl>(CurContext); 383 384 QualType T = 385 FD->getType()->isDependentType() 386 ? Context.DependentTy 387 : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(), 388 Loc, FD->getLocation()); 389 if (T.isNull()) 390 return nullptr; 391 392 auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), 393 &PP.getIdentifierTable().get("__promise"), T, 394 Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 395 CheckVariableDeclarationType(VD); 396 if (VD->isInvalidDecl()) 397 return nullptr; 398 ActOnUninitializedDecl(VD); 399 assert(!VD->isInvalidDecl()); 400 return VD; 401 } 402 403 /// Check that this is a context in which a coroutine suspension can appear. 404 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, 405 StringRef Keyword, 406 bool IsImplicit = false) { 407 if (!isValidCoroutineContext(S, Loc, Keyword)) 408 return nullptr; 409 410 assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); 411 412 auto *ScopeInfo = S.getCurFunction(); 413 assert(ScopeInfo && "missing function scope for function"); 414 415 if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) 416 ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); 417 418 if (ScopeInfo->CoroutinePromise) 419 return ScopeInfo; 420 421 ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); 422 if (!ScopeInfo->CoroutinePromise) 423 return nullptr; 424 425 return ScopeInfo; 426 } 427 428 static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, 429 StringRef Keyword) { 430 if (!checkCoroutineContext(S, KWLoc, Keyword)) 431 return false; 432 auto *ScopeInfo = S.getCurFunction(); 433 assert(ScopeInfo->CoroutinePromise); 434 435 // If we have existing coroutine statements then we have already built 436 // the initial and final suspend points. 437 if (!ScopeInfo->NeedsCoroutineSuspends) 438 return true; 439 440 ScopeInfo->setNeedsCoroutineSuspends(false); 441 442 auto *Fn = cast<FunctionDecl>(S.CurContext); 443 SourceLocation Loc = Fn->getLocation(); 444 // Build the initial suspend point 445 auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { 446 ExprResult Suspend = 447 buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); 448 if (Suspend.isInvalid()) 449 return StmtError(); 450 Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); 451 if (Suspend.isInvalid()) 452 return StmtError(); 453 Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(), 454 /*IsImplicit*/ true); 455 Suspend = S.ActOnFinishFullExpr(Suspend.get()); 456 if (Suspend.isInvalid()) { 457 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 458 << ((Name == "initial_suspend") ? 0 : 1); 459 S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; 460 return StmtError(); 461 } 462 return cast<Stmt>(Suspend.get()); 463 }; 464 465 StmtResult InitSuspend = buildSuspends("initial_suspend"); 466 if (InitSuspend.isInvalid()) 467 return true; 468 469 StmtResult FinalSuspend = buildSuspends("final_suspend"); 470 if (FinalSuspend.isInvalid()) 471 return true; 472 473 ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); 474 475 return true; 476 } 477 478 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 479 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { 480 CorrectDelayedTyposInExpr(E); 481 return ExprError(); 482 } 483 484 if (E->getType()->isPlaceholderType()) { 485 ExprResult R = CheckPlaceholderExpr(E); 486 if (R.isInvalid()) return ExprError(); 487 E = R.get(); 488 } 489 ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); 490 if (Lookup.isInvalid()) 491 return ExprError(); 492 return BuildUnresolvedCoawaitExpr(Loc, E, 493 cast<UnresolvedLookupExpr>(Lookup.get())); 494 } 495 496 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, 497 UnresolvedLookupExpr *Lookup) { 498 auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); 499 if (!FSI) 500 return ExprError(); 501 502 if (E->getType()->isPlaceholderType()) { 503 ExprResult R = CheckPlaceholderExpr(E); 504 if (R.isInvalid()) 505 return ExprError(); 506 E = R.get(); 507 } 508 509 auto *Promise = FSI->CoroutinePromise; 510 if (Promise->getType()->isDependentType()) { 511 Expr *Res = 512 new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); 513 return Res; 514 } 515 516 auto *RD = Promise->getType()->getAsCXXRecordDecl(); 517 if (lookupMember(*this, "await_transform", RD, Loc)) { 518 ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); 519 if (R.isInvalid()) { 520 Diag(Loc, 521 diag::note_coroutine_promise_implicit_await_transform_required_here) 522 << E->getSourceRange(); 523 return ExprError(); 524 } 525 E = R.get(); 526 } 527 ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); 528 if (Awaitable.isInvalid()) 529 return ExprError(); 530 531 return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); 532 } 533 534 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, 535 bool IsImplicit) { 536 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); 537 if (!Coroutine) 538 return ExprError(); 539 540 if (E->getType()->isPlaceholderType()) { 541 ExprResult R = CheckPlaceholderExpr(E); 542 if (R.isInvalid()) return ExprError(); 543 E = R.get(); 544 } 545 546 if (E->getType()->isDependentType()) { 547 Expr *Res = new (Context) 548 CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); 549 return Res; 550 } 551 552 // If the expression is a temporary, materialize it as an lvalue so that we 553 // can use it multiple times. 554 if (E->getValueKind() == VK_RValue) 555 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 556 557 // Build the await_ready, await_suspend, await_resume calls. 558 ReadySuspendResumeResult RSS = 559 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 560 if (RSS.IsInvalid) 561 return ExprError(); 562 563 Expr *Res = 564 new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 565 RSS.Results[2], RSS.OpaqueValue, IsImplicit); 566 567 return Res; 568 } 569 570 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 571 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { 572 CorrectDelayedTyposInExpr(E); 573 return ExprError(); 574 } 575 576 // Build yield_value call. 577 ExprResult Awaitable = buildPromiseCall( 578 *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); 579 if (Awaitable.isInvalid()) 580 return ExprError(); 581 582 // Build 'operator co_await' call. 583 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 584 if (Awaitable.isInvalid()) 585 return ExprError(); 586 587 return BuildCoyieldExpr(Loc, Awaitable.get()); 588 } 589 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 590 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 591 if (!Coroutine) 592 return ExprError(); 593 594 if (E->getType()->isPlaceholderType()) { 595 ExprResult R = CheckPlaceholderExpr(E); 596 if (R.isInvalid()) return ExprError(); 597 E = R.get(); 598 } 599 600 if (E->getType()->isDependentType()) { 601 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 602 return Res; 603 } 604 605 // If the expression is a temporary, materialize it as an lvalue so that we 606 // can use it multiple times. 607 if (E->getValueKind() == VK_RValue) 608 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 609 610 // Build the await_ready, await_suspend, await_resume calls. 611 ReadySuspendResumeResult RSS = 612 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 613 if (RSS.IsInvalid) 614 return ExprError(); 615 616 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 617 RSS.Results[2], RSS.OpaqueValue); 618 619 return Res; 620 } 621 622 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { 623 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { 624 CorrectDelayedTyposInExpr(E); 625 return StmtError(); 626 } 627 return BuildCoreturnStmt(Loc, E); 628 } 629 630 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, 631 bool IsImplicit) { 632 auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); 633 if (!FSI) 634 return StmtError(); 635 636 if (E && E->getType()->isPlaceholderType() && 637 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 638 ExprResult R = CheckPlaceholderExpr(E); 639 if (R.isInvalid()) return StmtError(); 640 E = R.get(); 641 } 642 643 // FIXME: If the operand is a reference to a variable that's about to go out 644 // of scope, we should treat the operand as an xvalue for this overload 645 // resolution. 646 VarDecl *Promise = FSI->CoroutinePromise; 647 ExprResult PC; 648 if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { 649 PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); 650 } else { 651 E = MakeFullDiscardedValueExpr(E).get(); 652 PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); 653 } 654 if (PC.isInvalid()) 655 return StmtError(); 656 657 Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); 658 659 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); 660 return Res; 661 } 662 663 // Find an appropriate delete for the promise. 664 static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, 665 QualType PromiseType) { 666 FunctionDecl *OperatorDelete = nullptr; 667 668 DeclarationName DeleteName = 669 S.Context.DeclarationNames.getCXXOperatorName(OO_Delete); 670 671 auto *PointeeRD = PromiseType->getAsCXXRecordDecl(); 672 assert(PointeeRD && "PromiseType must be a CxxRecordDecl type"); 673 674 if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete)) 675 return nullptr; 676 677 if (!OperatorDelete) { 678 // Look for a global declaration. 679 const bool CanProvideSize = S.isCompleteType(Loc, PromiseType); 680 const bool Overaligned = false; 681 OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize, 682 Overaligned, DeleteName); 683 } 684 S.MarkFunctionReferenced(Loc, OperatorDelete); 685 return OperatorDelete; 686 } 687 688 689 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { 690 FunctionScopeInfo *Fn = getCurFunction(); 691 assert(Fn && Fn->CoroutinePromise && "not a coroutine"); 692 693 if (!Body) { 694 assert(FD->isInvalidDecl() && 695 "a null body is only allowed for invalid declarations"); 696 return; 697 } 698 699 if (isa<CoroutineBodyStmt>(Body)) { 700 // FIXME(EricWF): Nothing todo. the body is already a transformed coroutine 701 // body statement. 702 return; 703 } 704 705 // Coroutines [stmt.return]p1: 706 // A return statement shall not appear in a coroutine. 707 if (Fn->FirstReturnLoc.isValid()) { 708 assert(Fn->FirstCoroutineStmtLoc.isValid() && 709 "first coroutine location not set"); 710 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 711 Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 712 << Fn->getFirstCoroutineStmtKeyword(); 713 } 714 CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body); 715 if (Builder.isInvalid() || !Builder.buildStatements()) 716 return FD->setInvalidDecl(); 717 718 // Build body for the coroutine wrapper statement. 719 Body = CoroutineBodyStmt::Create(Context, Builder); 720 } 721 722 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, 723 sema::FunctionScopeInfo &Fn, 724 Stmt *Body) 725 : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), 726 IsPromiseDependentType( 727 !Fn.CoroutinePromise || 728 Fn.CoroutinePromise->getType()->isDependentType()) { 729 this->Body = Body; 730 if (!IsPromiseDependentType) { 731 PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); 732 assert(PromiseRecordDecl && "Type should have already been checked"); 733 } 734 this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend(); 735 } 736 737 bool CoroutineStmtBuilder::buildStatements() { 738 assert(this->IsValid && "coroutine already invalid"); 739 this->IsValid = makeReturnObject() && makeParamMoves(); 740 if (this->IsValid && !IsPromiseDependentType) 741 buildDependentStatements(); 742 return this->IsValid; 743 } 744 745 bool CoroutineStmtBuilder::buildDependentStatements() { 746 assert(this->IsValid && "coroutine already invalid"); 747 assert(!this->IsPromiseDependentType && 748 "coroutine cannot have a dependent promise type"); 749 this->IsValid = makeOnException() && makeOnFallthrough() && 750 makeReturnOnAllocFailure() && makeNewAndDeleteExpr(); 751 return this->IsValid; 752 } 753 754 bool CoroutineStmtBuilder::makePromiseStmt() { 755 // Form a declaration statement for the promise declaration, so that AST 756 // visitors can more easily find it. 757 StmtResult PromiseStmt = 758 S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); 759 if (PromiseStmt.isInvalid()) 760 return false; 761 762 this->Promise = PromiseStmt.get(); 763 return true; 764 } 765 766 bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() { 767 if (Fn.hasInvalidCoroutineSuspends()) 768 return false; 769 this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first); 770 this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second); 771 return true; 772 } 773 774 static bool diagReturnOnAllocFailure(Sema &S, Expr *E, 775 CXXRecordDecl *PromiseRecordDecl, 776 FunctionScopeInfo &Fn) { 777 auto Loc = E->getExprLoc(); 778 if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) { 779 auto *Decl = DeclRef->getDecl(); 780 if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) { 781 if (Method->isStatic()) 782 return true; 783 else 784 Loc = Decl->getLocation(); 785 } 786 } 787 788 S.Diag( 789 Loc, 790 diag::err_coroutine_promise_get_return_object_on_allocation_failure) 791 << PromiseRecordDecl; 792 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 793 << Fn.getFirstCoroutineStmtKeyword(); 794 return false; 795 } 796 797 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() { 798 assert(!IsPromiseDependentType && 799 "cannot make statement while the promise type is dependent"); 800 801 // [dcl.fct.def.coroutine]/8 802 // The unqualified-id get_return_object_on_allocation_failure is looked up in 803 // the scope of class P by class member access lookup (3.4.5). ... 804 // If an allocation function returns nullptr, ... the coroutine return value 805 // is obtained by a call to ... get_return_object_on_allocation_failure(). 806 807 DeclarationName DN = 808 S.PP.getIdentifierInfo("get_return_object_on_allocation_failure"); 809 LookupResult Found(S, DN, Loc, Sema::LookupMemberName); 810 if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) 811 return true; 812 813 CXXScopeSpec SS; 814 ExprResult DeclNameExpr = 815 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 816 if (DeclNameExpr.isInvalid()) 817 return false; 818 819 if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)) 820 return false; 821 822 ExprResult ReturnObjectOnAllocationFailure = 823 S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc); 824 if (ReturnObjectOnAllocationFailure.isInvalid()) 825 return false; 826 827 // FIXME: ActOnReturnStmt expects a scope that is inside of the function, due 828 // to CheckJumpOutOfSEHFinally(*this, ReturnLoc, *CurScope->getFnParent()); 829 // S.getCurScope()->getFnParent() == nullptr at ActOnFinishFunctionBody when 830 // CoroutineBodyStmt is built. Figure it out and fix it. 831 // Use BuildReturnStmt here to unbreak sanitized tests. (Gor:3/27/2017) 832 StmtResult ReturnStmt = 833 S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get()); 834 if (ReturnStmt.isInvalid()) 835 return false; 836 837 this->ReturnStmtOnAllocFailure = ReturnStmt.get(); 838 return true; 839 } 840 841 bool CoroutineStmtBuilder::makeNewAndDeleteExpr() { 842 // Form and check allocation and deallocation calls. 843 assert(!IsPromiseDependentType && 844 "cannot make statement while the promise type is dependent"); 845 QualType PromiseType = Fn.CoroutinePromise->getType(); 846 847 if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) 848 return false; 849 850 // FIXME: Add nothrow_t placement arg for global alloc 851 // if ReturnStmtOnAllocFailure != nullptr. 852 // FIXME: Add support for stateful allocators. 853 854 FunctionDecl *OperatorNew = nullptr; 855 FunctionDecl *OperatorDelete = nullptr; 856 FunctionDecl *UnusedResult = nullptr; 857 bool PassAlignment = false; 858 859 S.FindAllocationFunctions(Loc, SourceRange(), 860 /*UseGlobal*/ false, PromiseType, 861 /*isArray*/ false, PassAlignment, 862 /*PlacementArgs*/ None, OperatorNew, UnusedResult); 863 864 OperatorDelete = findDeleteForPromise(S, Loc, PromiseType); 865 866 if (!OperatorDelete || !OperatorNew) 867 return false; 868 869 Expr *FramePtr = 870 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 871 872 Expr *FrameSize = 873 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); 874 875 // Make new call. 876 877 ExprResult NewRef = 878 S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); 879 if (NewRef.isInvalid()) 880 return false; 881 882 ExprResult NewExpr = 883 S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, FrameSize, Loc); 884 if (NewExpr.isInvalid()) 885 return false; 886 887 // Make delete call. 888 889 QualType OpDeleteQualType = OperatorDelete->getType(); 890 891 ExprResult DeleteRef = 892 S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); 893 if (DeleteRef.isInvalid()) 894 return false; 895 896 Expr *CoroFree = 897 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); 898 899 SmallVector<Expr *, 2> DeleteArgs{CoroFree}; 900 901 // Check if we need to pass the size. 902 const auto *OpDeleteType = 903 OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); 904 if (OpDeleteType->getNumParams() > 1) 905 DeleteArgs.push_back(FrameSize); 906 907 ExprResult DeleteExpr = 908 S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); 909 if (DeleteExpr.isInvalid()) 910 return false; 911 912 this->Allocate = NewExpr.get(); 913 this->Deallocate = DeleteExpr.get(); 914 915 return true; 916 } 917 918 bool CoroutineStmtBuilder::makeOnFallthrough() { 919 assert(!IsPromiseDependentType && 920 "cannot make statement while the promise type is dependent"); 921 922 // [dcl.fct.def.coroutine]/4 923 // The unqualified-ids 'return_void' and 'return_value' are looked up in 924 // the scope of class P. If both are found, the program is ill-formed. 925 const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc); 926 const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc); 927 928 StmtResult Fallthrough; 929 if (HasRVoid && HasRValue) { 930 // FIXME Improve this diagnostic 931 S.Diag(FD.getLocation(), diag::err_coroutine_promise_return_ill_formed) 932 << PromiseRecordDecl; 933 return false; 934 } else if (HasRVoid) { 935 // If the unqualified-id return_void is found, flowing off the end of a 936 // coroutine is equivalent to a co_return with no operand. Otherwise, 937 // flowing off the end of a coroutine results in undefined behavior. 938 Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, 939 /*IsImplicit*/false); 940 Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); 941 if (Fallthrough.isInvalid()) 942 return false; 943 } 944 945 this->OnFallthrough = Fallthrough.get(); 946 return true; 947 } 948 949 bool CoroutineStmtBuilder::makeOnException() { 950 // Try to form 'p.unhandled_exception();' 951 assert(!IsPromiseDependentType && 952 "cannot make statement while the promise type is dependent"); 953 954 const bool RequireUnhandledException = S.getLangOpts().CXXExceptions; 955 956 if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) { 957 auto DiagID = 958 RequireUnhandledException 959 ? diag::err_coroutine_promise_unhandled_exception_required 960 : diag:: 961 warn_coroutine_promise_unhandled_exception_required_with_exceptions; 962 S.Diag(Loc, DiagID) << PromiseRecordDecl; 963 return !RequireUnhandledException; 964 } 965 966 // If exceptions are disabled, don't try to build OnException. 967 if (!S.getLangOpts().CXXExceptions) 968 return true; 969 970 ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, 971 "unhandled_exception", None); 972 UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc); 973 if (UnhandledException.isInvalid()) 974 return false; 975 976 this->OnException = UnhandledException.get(); 977 return true; 978 } 979 980 bool CoroutineStmtBuilder::makeReturnObject() { 981 982 // Build implicit 'p.get_return_object()' expression and form initialization 983 // of return type from it. 984 ExprResult ReturnObject = 985 buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); 986 if (ReturnObject.isInvalid()) 987 return false; 988 QualType RetType = FD.getReturnType(); 989 if (!RetType->isDependentType()) { 990 InitializedEntity Entity = 991 InitializedEntity::InitializeResult(Loc, RetType, false); 992 ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType, 993 ReturnObject.get()); 994 if (ReturnObject.isInvalid()) 995 return false; 996 } 997 ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc); 998 if (ReturnObject.isInvalid()) 999 return false; 1000 1001 this->ReturnValue = ReturnObject.get(); 1002 return true; 1003 } 1004 1005 bool CoroutineStmtBuilder::makeParamMoves() { 1006 // FIXME: Perform move-initialization of parameters into frame-local copies. 1007 return true; 1008 } 1009 1010 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { 1011 CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); 1012 if (!Res) 1013 return StmtError(); 1014 return Res; 1015 } 1016