1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CoroutineStmtBuilder.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Initialization.h" 20 #include "clang/Sema/Overload.h" 21 #include "clang/Sema/SemaInternal.h" 22 23 using namespace clang; 24 using namespace sema; 25 26 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 27 SourceLocation Loc) { 28 DeclarationName DN = S.PP.getIdentifierInfo(Name); 29 LookupResult LR(S, DN, Loc, Sema::LookupMemberName); 30 // Suppress diagnostics when a private member is selected. The same warnings 31 // will be produced again when building the call. 32 LR.suppressDiagnostics(); 33 return S.LookupQualifiedName(LR, RD); 34 } 35 36 /// Look up the std::coroutine_traits<...>::promise_type for the given 37 /// function type. 38 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, 39 SourceLocation KwLoc, 40 SourceLocation FuncLoc) { 41 // FIXME: Cache std::coroutine_traits once we've found it. 42 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 43 if (!StdExp) { 44 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 45 << "std::experimental::coroutine_traits"; 46 return QualType(); 47 } 48 49 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 50 FuncLoc, Sema::LookupOrdinaryName); 51 if (!S.LookupQualifiedName(Result, StdExp)) { 52 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 53 << "std::experimental::coroutine_traits"; 54 return QualType(); 55 } 56 57 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 58 if (!CoroTraits) { 59 Result.suppressDiagnostics(); 60 // We found something weird. Complain about the first thing we found. 61 NamedDecl *Found = *Result.begin(); 62 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 63 return QualType(); 64 } 65 66 // Form template argument list for coroutine_traits<R, P1, P2, ...>. 67 TemplateArgumentListInfo Args(KwLoc, KwLoc); 68 Args.addArgument(TemplateArgumentLoc( 69 TemplateArgument(FnType->getReturnType()), 70 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc))); 71 // FIXME: If the function is a non-static member function, add the type 72 // of the implicit object parameter before the formal parameters. 73 for (QualType T : FnType->getParamTypes()) 74 Args.addArgument(TemplateArgumentLoc( 75 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); 76 77 // Build the template-id. 78 QualType CoroTrait = 79 S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); 80 if (CoroTrait.isNull()) 81 return QualType(); 82 if (S.RequireCompleteType(KwLoc, CoroTrait, 83 diag::err_coroutine_type_missing_specialization)) 84 return QualType(); 85 86 auto *RD = CoroTrait->getAsCXXRecordDecl(); 87 assert(RD && "specialization of class template is not a class?"); 88 89 // Look up the ::promise_type member. 90 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, 91 Sema::LookupOrdinaryName); 92 S.LookupQualifiedName(R, RD); 93 auto *Promise = R.getAsSingle<TypeDecl>(); 94 if (!Promise) { 95 S.Diag(FuncLoc, 96 diag::err_implied_std_coroutine_traits_promise_type_not_found) 97 << RD; 98 return QualType(); 99 } 100 // The promise type is required to be a class type. 101 QualType PromiseType = S.Context.getTypeDeclType(Promise); 102 103 auto buildElaboratedType = [&]() { 104 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); 105 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 106 CoroTrait.getTypePtr()); 107 return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 108 }; 109 110 if (!PromiseType->getAsCXXRecordDecl()) { 111 S.Diag(FuncLoc, 112 diag::err_implied_std_coroutine_traits_promise_type_not_class) 113 << buildElaboratedType(); 114 return QualType(); 115 } 116 if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), 117 diag::err_coroutine_promise_type_incomplete)) 118 return QualType(); 119 120 return PromiseType; 121 } 122 123 /// Look up the std::coroutine_traits<...>::promise_type for the given 124 /// function type. 125 static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, 126 SourceLocation Loc) { 127 if (PromiseType.isNull()) 128 return QualType(); 129 130 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 131 assert(StdExp && "Should already be diagnosed"); 132 133 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"), 134 Loc, Sema::LookupOrdinaryName); 135 if (!S.LookupQualifiedName(Result, StdExp)) { 136 S.Diag(Loc, diag::err_implied_coroutine_type_not_found) 137 << "std::experimental::coroutine_handle"; 138 return QualType(); 139 } 140 141 ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); 142 if (!CoroHandle) { 143 Result.suppressDiagnostics(); 144 // We found something weird. Complain about the first thing we found. 145 NamedDecl *Found = *Result.begin(); 146 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); 147 return QualType(); 148 } 149 150 // Form template argument list for coroutine_handle<Promise>. 151 TemplateArgumentListInfo Args(Loc, Loc); 152 Args.addArgument(TemplateArgumentLoc( 153 TemplateArgument(PromiseType), 154 S.Context.getTrivialTypeSourceInfo(PromiseType, Loc))); 155 156 // Build the template-id. 157 QualType CoroHandleType = 158 S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args); 159 if (CoroHandleType.isNull()) 160 return QualType(); 161 if (S.RequireCompleteType(Loc, CoroHandleType, 162 diag::err_coroutine_type_missing_specialization)) 163 return QualType(); 164 165 return CoroHandleType; 166 } 167 168 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, 169 StringRef Keyword) { 170 // 'co_await' and 'co_yield' are not permitted in unevaluated operands. 171 if (S.isUnevaluatedContext()) { 172 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 173 return false; 174 } 175 176 // Any other usage must be within a function. 177 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 178 if (!FD) { 179 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 180 ? diag::err_coroutine_objc_method 181 : diag::err_coroutine_outside_function) << Keyword; 182 return false; 183 } 184 185 // An enumeration for mapping the diagnostic type to the correct diagnostic 186 // selection index. 187 enum InvalidFuncDiag { 188 DiagCtor = 0, 189 DiagDtor, 190 DiagCopyAssign, 191 DiagMoveAssign, 192 DiagMain, 193 DiagConstexpr, 194 DiagAutoRet, 195 DiagVarargs, 196 }; 197 bool Diagnosed = false; 198 auto DiagInvalid = [&](InvalidFuncDiag ID) { 199 S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword; 200 Diagnosed = true; 201 return false; 202 }; 203 204 // Diagnose when a constructor, destructor, copy/move assignment operator, 205 // or the function 'main' are declared as a coroutine. 206 auto *MD = dyn_cast<CXXMethodDecl>(FD); 207 if (MD && isa<CXXConstructorDecl>(MD)) 208 return DiagInvalid(DiagCtor); 209 else if (MD && isa<CXXDestructorDecl>(MD)) 210 return DiagInvalid(DiagDtor); 211 else if (MD && MD->isCopyAssignmentOperator()) 212 return DiagInvalid(DiagCopyAssign); 213 else if (MD && MD->isMoveAssignmentOperator()) 214 return DiagInvalid(DiagMoveAssign); 215 else if (FD->isMain()) 216 return DiagInvalid(DiagMain); 217 218 // Emit a diagnostics for each of the following conditions which is not met. 219 if (FD->isConstexpr()) 220 DiagInvalid(DiagConstexpr); 221 if (FD->getReturnType()->isUndeducedType()) 222 DiagInvalid(DiagAutoRet); 223 if (FD->isVariadic()) 224 DiagInvalid(DiagVarargs); 225 226 return !Diagnosed; 227 } 228 229 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, 230 SourceLocation Loc) { 231 DeclarationName OpName = 232 SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); 233 LookupResult Operators(SemaRef, OpName, SourceLocation(), 234 Sema::LookupOperatorName); 235 SemaRef.LookupName(Operators, S); 236 237 assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); 238 const auto &Functions = Operators.asUnresolvedSet(); 239 bool IsOverloaded = 240 Functions.size() > 1 || 241 (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); 242 Expr *CoawaitOp = UnresolvedLookupExpr::Create( 243 SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), 244 DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, 245 Functions.begin(), Functions.end()); 246 assert(CoawaitOp); 247 return CoawaitOp; 248 } 249 250 /// Build a call to 'operator co_await' if there is a suitable operator for 251 /// the given expression. 252 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, 253 Expr *E, 254 UnresolvedLookupExpr *Lookup) { 255 UnresolvedSet<16> Functions; 256 Functions.append(Lookup->decls_begin(), Lookup->decls_end()); 257 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 258 } 259 260 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 261 SourceLocation Loc, Expr *E) { 262 ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); 263 if (R.isInvalid()) 264 return ExprError(); 265 return buildOperatorCoawaitCall(SemaRef, Loc, E, 266 cast<UnresolvedLookupExpr>(R.get())); 267 } 268 269 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, 270 MultiExprArg CallArgs) { 271 StringRef Name = S.Context.BuiltinInfo.getName(Id); 272 LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); 273 S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); 274 275 auto *BuiltInDecl = R.getAsSingle<FunctionDecl>(); 276 assert(BuiltInDecl && "failed to find builtin declaration"); 277 278 ExprResult DeclRef = 279 S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc); 280 assert(DeclRef.isUsable() && "Builtin reference cannot fail"); 281 282 ExprResult Call = 283 S.ActOnCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc); 284 285 assert(!Call.isInvalid() && "Call to builtin cannot fail!"); 286 return Call.get(); 287 } 288 289 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType, 290 SourceLocation Loc) { 291 QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc); 292 if (CoroHandleType.isNull()) 293 return ExprError(); 294 295 DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType); 296 LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc, 297 Sema::LookupOrdinaryName); 298 if (!S.LookupQualifiedName(Found, LookupCtx)) { 299 S.Diag(Loc, diag::err_coroutine_handle_missing_member) 300 << "from_address"; 301 return ExprError(); 302 } 303 304 Expr *FramePtr = 305 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 306 307 CXXScopeSpec SS; 308 ExprResult FromAddr = 309 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 310 if (FromAddr.isInvalid()) 311 return ExprError(); 312 313 return S.ActOnCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc); 314 } 315 316 struct ReadySuspendResumeResult { 317 Expr *Results[3]; 318 OpaqueValueExpr *OpaqueValue; 319 bool IsInvalid; 320 }; 321 322 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 323 StringRef Name, MultiExprArg Args) { 324 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 325 326 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 327 CXXScopeSpec SS; 328 ExprResult Result = S.BuildMemberReferenceExpr( 329 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 330 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 331 /*Scope=*/nullptr); 332 if (Result.isInvalid()) 333 return ExprError(); 334 335 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 336 } 337 338 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 339 /// expression. 340 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, 341 SourceLocation Loc, Expr *E) { 342 OpaqueValueExpr *Operand = new (S.Context) 343 OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 344 345 // Assume invalid until we see otherwise. 346 ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; 347 348 ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); 349 if (CoroHandleRes.isInvalid()) 350 return Calls; 351 Expr *CoroHandle = CoroHandleRes.get(); 352 353 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 354 MultiExprArg Args[] = {None, CoroHandle, None}; 355 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 356 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); 357 if (Result.isInvalid()) 358 return Calls; 359 Calls.Results[I] = Result.get(); 360 } 361 362 Calls.IsInvalid = false; 363 return Calls; 364 } 365 366 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, 367 SourceLocation Loc, StringRef Name, 368 MultiExprArg Args) { 369 370 // Form a reference to the promise. 371 ExprResult PromiseRef = S.BuildDeclRefExpr( 372 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 373 if (PromiseRef.isInvalid()) 374 return ExprError(); 375 376 // Call 'yield_value', passing in E. 377 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 378 } 379 380 VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { 381 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 382 auto *FD = cast<FunctionDecl>(CurContext); 383 384 QualType T = 385 FD->getType()->isDependentType() 386 ? Context.DependentTy 387 : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(), 388 Loc, FD->getLocation()); 389 if (T.isNull()) 390 return nullptr; 391 392 auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), 393 &PP.getIdentifierTable().get("__promise"), T, 394 Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 395 CheckVariableDeclarationType(VD); 396 if (VD->isInvalidDecl()) 397 return nullptr; 398 ActOnUninitializedDecl(VD); 399 assert(!VD->isInvalidDecl()); 400 return VD; 401 } 402 403 /// Check that this is a context in which a coroutine suspension can appear. 404 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, 405 StringRef Keyword, 406 bool IsImplicit = false) { 407 if (!isValidCoroutineContext(S, Loc, Keyword)) 408 return nullptr; 409 410 assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); 411 412 auto *ScopeInfo = S.getCurFunction(); 413 assert(ScopeInfo && "missing function scope for function"); 414 415 if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) 416 ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); 417 418 if (ScopeInfo->CoroutinePromise) 419 return ScopeInfo; 420 421 ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); 422 if (!ScopeInfo->CoroutinePromise) 423 return nullptr; 424 425 return ScopeInfo; 426 } 427 428 static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, 429 StringRef Keyword) { 430 if (!checkCoroutineContext(S, KWLoc, Keyword)) 431 return false; 432 auto *ScopeInfo = S.getCurFunction(); 433 assert(ScopeInfo->CoroutinePromise); 434 435 // If we have existing coroutine statements then we have already built 436 // the initial and final suspend points. 437 if (!ScopeInfo->NeedsCoroutineSuspends) 438 return true; 439 440 ScopeInfo->setNeedsCoroutineSuspends(false); 441 442 auto *Fn = cast<FunctionDecl>(S.CurContext); 443 SourceLocation Loc = Fn->getLocation(); 444 // Build the initial suspend point 445 auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { 446 ExprResult Suspend = 447 buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); 448 if (Suspend.isInvalid()) 449 return StmtError(); 450 Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); 451 if (Suspend.isInvalid()) 452 return StmtError(); 453 Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(), 454 /*IsImplicit*/ true); 455 Suspend = S.ActOnFinishFullExpr(Suspend.get()); 456 if (Suspend.isInvalid()) { 457 S.Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required) 458 << ((Name == "initial_suspend") ? 0 : 1); 459 S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; 460 return StmtError(); 461 } 462 return cast<Stmt>(Suspend.get()); 463 }; 464 465 StmtResult InitSuspend = buildSuspends("initial_suspend"); 466 if (InitSuspend.isInvalid()) 467 return true; 468 469 StmtResult FinalSuspend = buildSuspends("final_suspend"); 470 if (FinalSuspend.isInvalid()) 471 return true; 472 473 ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); 474 475 return true; 476 } 477 478 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 479 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { 480 CorrectDelayedTyposInExpr(E); 481 return ExprError(); 482 } 483 484 if (E->getType()->isPlaceholderType()) { 485 ExprResult R = CheckPlaceholderExpr(E); 486 if (R.isInvalid()) return ExprError(); 487 E = R.get(); 488 } 489 ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); 490 if (Lookup.isInvalid()) 491 return ExprError(); 492 return BuildUnresolvedCoawaitExpr(Loc, E, 493 cast<UnresolvedLookupExpr>(Lookup.get())); 494 } 495 496 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, 497 UnresolvedLookupExpr *Lookup) { 498 auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); 499 if (!FSI) 500 return ExprError(); 501 502 if (E->getType()->isPlaceholderType()) { 503 ExprResult R = CheckPlaceholderExpr(E); 504 if (R.isInvalid()) 505 return ExprError(); 506 E = R.get(); 507 } 508 509 auto *Promise = FSI->CoroutinePromise; 510 if (Promise->getType()->isDependentType()) { 511 Expr *Res = 512 new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); 513 return Res; 514 } 515 516 auto *RD = Promise->getType()->getAsCXXRecordDecl(); 517 if (lookupMember(*this, "await_transform", RD, Loc)) { 518 ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); 519 if (R.isInvalid()) { 520 Diag(Loc, 521 diag::note_coroutine_promise_implicit_await_transform_required_here) 522 << E->getSourceRange(); 523 return ExprError(); 524 } 525 E = R.get(); 526 } 527 ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); 528 if (Awaitable.isInvalid()) 529 return ExprError(); 530 531 return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); 532 } 533 534 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, 535 bool IsImplicit) { 536 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); 537 if (!Coroutine) 538 return ExprError(); 539 540 if (E->getType()->isPlaceholderType()) { 541 ExprResult R = CheckPlaceholderExpr(E); 542 if (R.isInvalid()) return ExprError(); 543 E = R.get(); 544 } 545 546 if (E->getType()->isDependentType()) { 547 Expr *Res = new (Context) 548 CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); 549 return Res; 550 } 551 552 // If the expression is a temporary, materialize it as an lvalue so that we 553 // can use it multiple times. 554 if (E->getValueKind() == VK_RValue) 555 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 556 557 // Build the await_ready, await_suspend, await_resume calls. 558 ReadySuspendResumeResult RSS = 559 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 560 if (RSS.IsInvalid) 561 return ExprError(); 562 563 Expr *Res = 564 new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 565 RSS.Results[2], RSS.OpaqueValue, IsImplicit); 566 567 return Res; 568 } 569 570 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 571 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { 572 CorrectDelayedTyposInExpr(E); 573 return ExprError(); 574 } 575 576 // Build yield_value call. 577 ExprResult Awaitable = buildPromiseCall( 578 *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); 579 if (Awaitable.isInvalid()) 580 return ExprError(); 581 582 // Build 'operator co_await' call. 583 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 584 if (Awaitable.isInvalid()) 585 return ExprError(); 586 587 return BuildCoyieldExpr(Loc, Awaitable.get()); 588 } 589 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 590 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 591 if (!Coroutine) 592 return ExprError(); 593 594 if (E->getType()->isPlaceholderType()) { 595 ExprResult R = CheckPlaceholderExpr(E); 596 if (R.isInvalid()) return ExprError(); 597 E = R.get(); 598 } 599 600 if (E->getType()->isDependentType()) { 601 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 602 return Res; 603 } 604 605 // If the expression is a temporary, materialize it as an lvalue so that we 606 // can use it multiple times. 607 if (E->getValueKind() == VK_RValue) 608 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 609 610 // Build the await_ready, await_suspend, await_resume calls. 611 ReadySuspendResumeResult RSS = 612 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 613 if (RSS.IsInvalid) 614 return ExprError(); 615 616 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 617 RSS.Results[2], RSS.OpaqueValue); 618 619 return Res; 620 } 621 622 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { 623 if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { 624 CorrectDelayedTyposInExpr(E); 625 return StmtError(); 626 } 627 return BuildCoreturnStmt(Loc, E); 628 } 629 630 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, 631 bool IsImplicit) { 632 auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); 633 if (!FSI) 634 return StmtError(); 635 636 if (E && E->getType()->isPlaceholderType() && 637 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 638 ExprResult R = CheckPlaceholderExpr(E); 639 if (R.isInvalid()) return StmtError(); 640 E = R.get(); 641 } 642 643 // FIXME: If the operand is a reference to a variable that's about to go out 644 // of scope, we should treat the operand as an xvalue for this overload 645 // resolution. 646 VarDecl *Promise = FSI->CoroutinePromise; 647 ExprResult PC; 648 if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { 649 PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); 650 } else { 651 E = MakeFullDiscardedValueExpr(E).get(); 652 PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); 653 } 654 if (PC.isInvalid()) 655 return StmtError(); 656 657 Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); 658 659 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); 660 return Res; 661 } 662 663 /// Look up the std::nothrow object. 664 static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) { 665 NamespaceDecl *Std = S.getStdNamespace(); 666 assert(Std && "Should already be diagnosed"); 667 668 LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc, 669 Sema::LookupOrdinaryName); 670 if (!S.LookupQualifiedName(Result, Std)) { 671 // FIXME: <experimental/coroutine> should have been included already. 672 // If we require it to include <new> then this diagnostic is no longer 673 // needed. 674 S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found); 675 return nullptr; 676 } 677 678 // FIXME: Mark the variable as ODR used. This currently does not work 679 // likely due to the scope at in which this function is called. 680 auto *VD = Result.getAsSingle<VarDecl>(); 681 if (!VD) { 682 Result.suppressDiagnostics(); 683 // We found something weird. Complain about the first thing we found. 684 NamedDecl *Found = *Result.begin(); 685 S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow); 686 return nullptr; 687 } 688 689 ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc); 690 if (DR.isInvalid()) 691 return nullptr; 692 693 return DR.get(); 694 } 695 696 // Find an appropriate delete for the promise. 697 static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, 698 QualType PromiseType) { 699 FunctionDecl *OperatorDelete = nullptr; 700 701 DeclarationName DeleteName = 702 S.Context.DeclarationNames.getCXXOperatorName(OO_Delete); 703 704 auto *PointeeRD = PromiseType->getAsCXXRecordDecl(); 705 assert(PointeeRD && "PromiseType must be a CxxRecordDecl type"); 706 707 if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete)) 708 return nullptr; 709 710 if (!OperatorDelete) { 711 // Look for a global declaration. 712 const bool CanProvideSize = S.isCompleteType(Loc, PromiseType); 713 const bool Overaligned = false; 714 OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize, 715 Overaligned, DeleteName); 716 } 717 S.MarkFunctionReferenced(Loc, OperatorDelete); 718 return OperatorDelete; 719 } 720 721 722 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { 723 FunctionScopeInfo *Fn = getCurFunction(); 724 assert(Fn && Fn->CoroutinePromise && "not a coroutine"); 725 726 if (!Body) { 727 assert(FD->isInvalidDecl() && 728 "a null body is only allowed for invalid declarations"); 729 return; 730 } 731 732 if (isa<CoroutineBodyStmt>(Body)) { 733 // FIXME(EricWF): Nothing todo. the body is already a transformed coroutine 734 // body statement. 735 return; 736 } 737 738 // Coroutines [stmt.return]p1: 739 // A return statement shall not appear in a coroutine. 740 if (Fn->FirstReturnLoc.isValid()) { 741 assert(Fn->FirstCoroutineStmtLoc.isValid() && 742 "first coroutine location not set"); 743 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 744 Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 745 << Fn->getFirstCoroutineStmtKeyword(); 746 } 747 CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body); 748 if (Builder.isInvalid() || !Builder.buildStatements()) 749 return FD->setInvalidDecl(); 750 751 // Build body for the coroutine wrapper statement. 752 Body = CoroutineBodyStmt::Create(Context, Builder); 753 } 754 755 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, 756 sema::FunctionScopeInfo &Fn, 757 Stmt *Body) 758 : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), 759 IsPromiseDependentType( 760 !Fn.CoroutinePromise || 761 Fn.CoroutinePromise->getType()->isDependentType()) { 762 this->Body = Body; 763 if (!IsPromiseDependentType) { 764 PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); 765 assert(PromiseRecordDecl && "Type should have already been checked"); 766 } 767 this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend(); 768 } 769 770 bool CoroutineStmtBuilder::buildStatements() { 771 assert(this->IsValid && "coroutine already invalid"); 772 this->IsValid = makeReturnObject() && makeParamMoves(); 773 if (this->IsValid && !IsPromiseDependentType) 774 buildDependentStatements(); 775 return this->IsValid; 776 } 777 778 bool CoroutineStmtBuilder::buildDependentStatements() { 779 assert(this->IsValid && "coroutine already invalid"); 780 assert(!this->IsPromiseDependentType && 781 "coroutine cannot have a dependent promise type"); 782 this->IsValid = makeOnException() && makeOnFallthrough() && 783 makeReturnOnAllocFailure() && makeNewAndDeleteExpr(); 784 return this->IsValid; 785 } 786 787 bool CoroutineStmtBuilder::makePromiseStmt() { 788 // Form a declaration statement for the promise declaration, so that AST 789 // visitors can more easily find it. 790 StmtResult PromiseStmt = 791 S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); 792 if (PromiseStmt.isInvalid()) 793 return false; 794 795 this->Promise = PromiseStmt.get(); 796 return true; 797 } 798 799 bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() { 800 if (Fn.hasInvalidCoroutineSuspends()) 801 return false; 802 this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first); 803 this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second); 804 return true; 805 } 806 807 static bool diagReturnOnAllocFailure(Sema &S, Expr *E, 808 CXXRecordDecl *PromiseRecordDecl, 809 FunctionScopeInfo &Fn) { 810 auto Loc = E->getExprLoc(); 811 if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) { 812 auto *Decl = DeclRef->getDecl(); 813 if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) { 814 if (Method->isStatic()) 815 return true; 816 else 817 Loc = Decl->getLocation(); 818 } 819 } 820 821 S.Diag( 822 Loc, 823 diag::err_coroutine_promise_get_return_object_on_allocation_failure) 824 << PromiseRecordDecl; 825 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 826 << Fn.getFirstCoroutineStmtKeyword(); 827 return false; 828 } 829 830 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() { 831 assert(!IsPromiseDependentType && 832 "cannot make statement while the promise type is dependent"); 833 834 // [dcl.fct.def.coroutine]/8 835 // The unqualified-id get_return_object_on_allocation_failure is looked up in 836 // the scope of class P by class member access lookup (3.4.5). ... 837 // If an allocation function returns nullptr, ... the coroutine return value 838 // is obtained by a call to ... get_return_object_on_allocation_failure(). 839 840 DeclarationName DN = 841 S.PP.getIdentifierInfo("get_return_object_on_allocation_failure"); 842 LookupResult Found(S, DN, Loc, Sema::LookupMemberName); 843 if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) 844 return true; 845 846 CXXScopeSpec SS; 847 ExprResult DeclNameExpr = 848 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 849 if (DeclNameExpr.isInvalid()) 850 return false; 851 852 if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)) 853 return false; 854 855 ExprResult ReturnObjectOnAllocationFailure = 856 S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc); 857 if (ReturnObjectOnAllocationFailure.isInvalid()) 858 return false; 859 860 // FIXME: ActOnReturnStmt expects a scope that is inside of the function, due 861 // to CheckJumpOutOfSEHFinally(*this, ReturnLoc, *CurScope->getFnParent()); 862 // S.getCurScope()->getFnParent() == nullptr at ActOnFinishFunctionBody when 863 // CoroutineBodyStmt is built. Figure it out and fix it. 864 // Use BuildReturnStmt here to unbreak sanitized tests. (Gor:3/27/2017) 865 StmtResult ReturnStmt = 866 S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get()); 867 if (ReturnStmt.isInvalid()) 868 return false; 869 870 this->ReturnStmtOnAllocFailure = ReturnStmt.get(); 871 return true; 872 } 873 874 bool CoroutineStmtBuilder::makeNewAndDeleteExpr() { 875 // Form and check allocation and deallocation calls. 876 assert(!IsPromiseDependentType && 877 "cannot make statement while the promise type is dependent"); 878 QualType PromiseType = Fn.CoroutinePromise->getType(); 879 880 if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) 881 return false; 882 883 const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr; 884 885 // FIXME: Add support for stateful allocators. 886 887 FunctionDecl *OperatorNew = nullptr; 888 FunctionDecl *OperatorDelete = nullptr; 889 FunctionDecl *UnusedResult = nullptr; 890 bool PassAlignment = false; 891 SmallVector<Expr *, 1> PlacementArgs; 892 893 S.FindAllocationFunctions(Loc, SourceRange(), 894 /*UseGlobal*/ false, PromiseType, 895 /*isArray*/ false, PassAlignment, PlacementArgs, 896 OperatorNew, UnusedResult); 897 898 bool IsGlobalOverload = 899 OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext()); 900 // If we didn't find a class-local new declaration and non-throwing new 901 // was is required then we need to lookup the non-throwing global operator 902 // instead. 903 if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) { 904 auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc); 905 if (!StdNoThrow) 906 return false; 907 PlacementArgs = {StdNoThrow}; 908 OperatorNew = nullptr; 909 S.FindAllocationFunctions(Loc, SourceRange(), 910 /*UseGlobal*/ true, PromiseType, 911 /*isArray*/ false, PassAlignment, PlacementArgs, 912 OperatorNew, UnusedResult); 913 } 914 915 assert(OperatorNew && "expected definition of operator new to be found"); 916 917 if (RequiresNoThrowAlloc) { 918 const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>(); 919 if (!FT->isNothrow(S.Context, /*ResultIfDependent*/ false)) { 920 S.Diag(OperatorNew->getLocation(), 921 diag::err_coroutine_promise_new_requires_nothrow) 922 << OperatorNew; 923 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 924 << OperatorNew; 925 return false; 926 } 927 } 928 929 if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr) 930 return false; 931 932 Expr *FramePtr = 933 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 934 935 Expr *FrameSize = 936 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); 937 938 // Make new call. 939 940 ExprResult NewRef = 941 S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); 942 if (NewRef.isInvalid()) 943 return false; 944 945 SmallVector<Expr *, 2> NewArgs(1, FrameSize); 946 for (auto Arg : PlacementArgs) 947 NewArgs.push_back(Arg); 948 949 ExprResult NewExpr = 950 S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc); 951 NewExpr = S.ActOnFinishFullExpr(NewExpr.get()); 952 if (NewExpr.isInvalid()) 953 return false; 954 955 // Make delete call. 956 957 QualType OpDeleteQualType = OperatorDelete->getType(); 958 959 ExprResult DeleteRef = 960 S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); 961 if (DeleteRef.isInvalid()) 962 return false; 963 964 Expr *CoroFree = 965 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); 966 967 SmallVector<Expr *, 2> DeleteArgs{CoroFree}; 968 969 // Check if we need to pass the size. 970 const auto *OpDeleteType = 971 OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); 972 if (OpDeleteType->getNumParams() > 1) 973 DeleteArgs.push_back(FrameSize); 974 975 ExprResult DeleteExpr = 976 S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); 977 DeleteExpr = S.ActOnFinishFullExpr(DeleteExpr.get()); 978 if (DeleteExpr.isInvalid()) 979 return false; 980 981 this->Allocate = NewExpr.get(); 982 this->Deallocate = DeleteExpr.get(); 983 984 return true; 985 } 986 987 bool CoroutineStmtBuilder::makeOnFallthrough() { 988 assert(!IsPromiseDependentType && 989 "cannot make statement while the promise type is dependent"); 990 991 // [dcl.fct.def.coroutine]/4 992 // The unqualified-ids 'return_void' and 'return_value' are looked up in 993 // the scope of class P. If both are found, the program is ill-formed. 994 const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc); 995 const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc); 996 997 StmtResult Fallthrough; 998 if (HasRVoid && HasRValue) { 999 // FIXME Improve this diagnostic 1000 S.Diag(FD.getLocation(), diag::err_coroutine_promise_return_ill_formed) 1001 << PromiseRecordDecl; 1002 return false; 1003 } else if (HasRVoid) { 1004 // If the unqualified-id return_void is found, flowing off the end of a 1005 // coroutine is equivalent to a co_return with no operand. Otherwise, 1006 // flowing off the end of a coroutine results in undefined behavior. 1007 Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, 1008 /*IsImplicit*/false); 1009 Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); 1010 if (Fallthrough.isInvalid()) 1011 return false; 1012 } 1013 1014 this->OnFallthrough = Fallthrough.get(); 1015 return true; 1016 } 1017 1018 bool CoroutineStmtBuilder::makeOnException() { 1019 // Try to form 'p.unhandled_exception();' 1020 assert(!IsPromiseDependentType && 1021 "cannot make statement while the promise type is dependent"); 1022 1023 const bool RequireUnhandledException = S.getLangOpts().CXXExceptions; 1024 1025 if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) { 1026 auto DiagID = 1027 RequireUnhandledException 1028 ? diag::err_coroutine_promise_unhandled_exception_required 1029 : diag:: 1030 warn_coroutine_promise_unhandled_exception_required_with_exceptions; 1031 S.Diag(Loc, DiagID) << PromiseRecordDecl; 1032 return !RequireUnhandledException; 1033 } 1034 1035 // If exceptions are disabled, don't try to build OnException. 1036 if (!S.getLangOpts().CXXExceptions) 1037 return true; 1038 1039 ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, 1040 "unhandled_exception", None); 1041 UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc); 1042 if (UnhandledException.isInvalid()) 1043 return false; 1044 1045 this->OnException = UnhandledException.get(); 1046 return true; 1047 } 1048 1049 bool CoroutineStmtBuilder::makeReturnObject() { 1050 1051 // Build implicit 'p.get_return_object()' expression and form initialization 1052 // of return type from it. 1053 ExprResult ReturnObject = 1054 buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); 1055 if (ReturnObject.isInvalid()) 1056 return false; 1057 QualType RetType = FD.getReturnType(); 1058 if (!RetType->isDependentType()) { 1059 InitializedEntity Entity = 1060 InitializedEntity::InitializeResult(Loc, RetType, false); 1061 ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType, 1062 ReturnObject.get()); 1063 if (ReturnObject.isInvalid()) 1064 return false; 1065 } 1066 ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc); 1067 if (ReturnObject.isInvalid()) 1068 return false; 1069 1070 this->ReturnValue = ReturnObject.get(); 1071 return true; 1072 } 1073 1074 bool CoroutineStmtBuilder::makeParamMoves() { 1075 // FIXME: Perform move-initialization of parameters into frame-local copies. 1076 return true; 1077 } 1078 1079 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { 1080 CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); 1081 if (!Res) 1082 return StmtError(); 1083 return Res; 1084 } 1085