1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CoroutineStmtBuilder.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Initialization.h" 20 #include "clang/Sema/Overload.h" 21 #include "clang/Sema/SemaInternal.h" 22 23 using namespace clang; 24 using namespace sema; 25 26 static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 27 SourceLocation Loc, bool &Res) { 28 DeclarationName DN = S.PP.getIdentifierInfo(Name); 29 LookupResult LR(S, DN, Loc, Sema::LookupMemberName); 30 // Suppress diagnostics when a private member is selected. The same warnings 31 // will be produced again when building the call. 32 LR.suppressDiagnostics(); 33 Res = S.LookupQualifiedName(LR, RD); 34 return LR; 35 } 36 37 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, 38 SourceLocation Loc) { 39 bool Res; 40 lookupMember(S, Name, RD, Loc, Res); 41 return Res; 42 } 43 44 /// Look up the std::coroutine_traits<...>::promise_type for the given 45 /// function type. 46 static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD, 47 SourceLocation KwLoc) { 48 const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>(); 49 const SourceLocation FuncLoc = FD->getLocation(); 50 // FIXME: Cache std::coroutine_traits once we've found it. 
51 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 52 if (!StdExp) { 53 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 54 << "std::experimental::coroutine_traits"; 55 return QualType(); 56 } 57 58 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 59 FuncLoc, Sema::LookupOrdinaryName); 60 if (!S.LookupQualifiedName(Result, StdExp)) { 61 S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) 62 << "std::experimental::coroutine_traits"; 63 return QualType(); 64 } 65 66 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 67 if (!CoroTraits) { 68 Result.suppressDiagnostics(); 69 // We found something weird. Complain about the first thing we found. 70 NamedDecl *Found = *Result.begin(); 71 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 72 return QualType(); 73 } 74 75 // Form template argument list for coroutine_traits<R, P1, P2, ...> according 76 // to [dcl.fct.def.coroutine]3 77 TemplateArgumentListInfo Args(KwLoc, KwLoc); 78 auto AddArg = [&](QualType T) { 79 Args.addArgument(TemplateArgumentLoc( 80 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); 81 }; 82 AddArg(FnType->getReturnType()); 83 // If the function is a non-static member function, add the type 84 // of the implicit object parameter before the formal parameters. 85 if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) { 86 if (MD->isInstance()) { 87 // [over.match.funcs]4 88 // For non-static member functions, the type of the implicit object 89 // parameter is 90 // -- "lvalue reference to cv X" for functions declared without a 91 // ref-qualifier or with the & ref-qualifier 92 // -- "rvalue reference to cv X" for functions declared with the && 93 // ref-qualifier 94 QualType T = 95 MD->getThisType(S.Context)->getAs<PointerType>()->getPointeeType(); 96 T = FnType->getRefQualifier() == RQ_RValue 97 ? 
S.Context.getRValueReferenceType(T) 98 : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true); 99 AddArg(T); 100 } 101 } 102 for (QualType T : FnType->getParamTypes()) 103 AddArg(T); 104 105 // Build the template-id. 106 QualType CoroTrait = 107 S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); 108 if (CoroTrait.isNull()) 109 return QualType(); 110 if (S.RequireCompleteType(KwLoc, CoroTrait, 111 diag::err_coroutine_type_missing_specialization)) 112 return QualType(); 113 114 auto *RD = CoroTrait->getAsCXXRecordDecl(); 115 assert(RD && "specialization of class template is not a class?"); 116 117 // Look up the ::promise_type member. 118 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, 119 Sema::LookupOrdinaryName); 120 S.LookupQualifiedName(R, RD); 121 auto *Promise = R.getAsSingle<TypeDecl>(); 122 if (!Promise) { 123 S.Diag(FuncLoc, 124 diag::err_implied_std_coroutine_traits_promise_type_not_found) 125 << RD; 126 return QualType(); 127 } 128 // The promise type is required to be a class type. 129 QualType PromiseType = S.Context.getTypeDeclType(Promise); 130 131 auto buildElaboratedType = [&]() { 132 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); 133 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 134 CoroTrait.getTypePtr()); 135 return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 136 }; 137 138 if (!PromiseType->getAsCXXRecordDecl()) { 139 S.Diag(FuncLoc, 140 diag::err_implied_std_coroutine_traits_promise_type_not_class) 141 << buildElaboratedType(); 142 return QualType(); 143 } 144 if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), 145 diag::err_coroutine_promise_type_incomplete)) 146 return QualType(); 147 148 return PromiseType; 149 } 150 151 /// Look up the std::experimental::coroutine_handle<PromiseType>. 
152 static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, 153 SourceLocation Loc) { 154 if (PromiseType.isNull()) 155 return QualType(); 156 157 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 158 assert(StdExp && "Should already be diagnosed"); 159 160 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"), 161 Loc, Sema::LookupOrdinaryName); 162 if (!S.LookupQualifiedName(Result, StdExp)) { 163 S.Diag(Loc, diag::err_implied_coroutine_type_not_found) 164 << "std::experimental::coroutine_handle"; 165 return QualType(); 166 } 167 168 ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); 169 if (!CoroHandle) { 170 Result.suppressDiagnostics(); 171 // We found something weird. Complain about the first thing we found. 172 NamedDecl *Found = *Result.begin(); 173 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); 174 return QualType(); 175 } 176 177 // Form template argument list for coroutine_handle<Promise>. 178 TemplateArgumentListInfo Args(Loc, Loc); 179 Args.addArgument(TemplateArgumentLoc( 180 TemplateArgument(PromiseType), 181 S.Context.getTrivialTypeSourceInfo(PromiseType, Loc))); 182 183 // Build the template-id. 184 QualType CoroHandleType = 185 S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args); 186 if (CoroHandleType.isNull()) 187 return QualType(); 188 if (S.RequireCompleteType(Loc, CoroHandleType, 189 diag::err_coroutine_type_missing_specialization)) 190 return QualType(); 191 192 return CoroHandleType; 193 } 194 195 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, 196 StringRef Keyword) { 197 // 'co_await' and 'co_yield' are not permitted in unevaluated operands. 198 if (S.isUnevaluatedContext()) { 199 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 200 return false; 201 } 202 203 // Any other usage must be within a function. 
204 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 205 if (!FD) { 206 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 207 ? diag::err_coroutine_objc_method 208 : diag::err_coroutine_outside_function) << Keyword; 209 return false; 210 } 211 212 // An enumeration for mapping the diagnostic type to the correct diagnostic 213 // selection index. 214 enum InvalidFuncDiag { 215 DiagCtor = 0, 216 DiagDtor, 217 DiagCopyAssign, 218 DiagMoveAssign, 219 DiagMain, 220 DiagConstexpr, 221 DiagAutoRet, 222 DiagVarargs, 223 }; 224 bool Diagnosed = false; 225 auto DiagInvalid = [&](InvalidFuncDiag ID) { 226 S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword; 227 Diagnosed = true; 228 return false; 229 }; 230 231 // Diagnose when a constructor, destructor, copy/move assignment operator, 232 // or the function 'main' are declared as a coroutine. 233 auto *MD = dyn_cast<CXXMethodDecl>(FD); 234 if (MD && isa<CXXConstructorDecl>(MD)) 235 return DiagInvalid(DiagCtor); 236 else if (MD && isa<CXXDestructorDecl>(MD)) 237 return DiagInvalid(DiagDtor); 238 else if (MD && MD->isCopyAssignmentOperator()) 239 return DiagInvalid(DiagCopyAssign); 240 else if (MD && MD->isMoveAssignmentOperator()) 241 return DiagInvalid(DiagMoveAssign); 242 else if (FD->isMain()) 243 return DiagInvalid(DiagMain); 244 245 // Emit a diagnostics for each of the following conditions which is not met. 
246 if (FD->isConstexpr()) 247 DiagInvalid(DiagConstexpr); 248 if (FD->getReturnType()->isUndeducedType()) 249 DiagInvalid(DiagAutoRet); 250 if (FD->isVariadic()) 251 DiagInvalid(DiagVarargs); 252 253 return !Diagnosed; 254 } 255 256 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, 257 SourceLocation Loc) { 258 DeclarationName OpName = 259 SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); 260 LookupResult Operators(SemaRef, OpName, SourceLocation(), 261 Sema::LookupOperatorName); 262 SemaRef.LookupName(Operators, S); 263 264 assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); 265 const auto &Functions = Operators.asUnresolvedSet(); 266 bool IsOverloaded = 267 Functions.size() > 1 || 268 (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); 269 Expr *CoawaitOp = UnresolvedLookupExpr::Create( 270 SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), 271 DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, 272 Functions.begin(), Functions.end()); 273 assert(CoawaitOp); 274 return CoawaitOp; 275 } 276 277 /// Build a call to 'operator co_await' if there is a suitable operator for 278 /// the given expression. 
279 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, 280 Expr *E, 281 UnresolvedLookupExpr *Lookup) { 282 UnresolvedSet<16> Functions; 283 Functions.append(Lookup->decls_begin(), Lookup->decls_end()); 284 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 285 } 286 287 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 288 SourceLocation Loc, Expr *E) { 289 ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); 290 if (R.isInvalid()) 291 return ExprError(); 292 return buildOperatorCoawaitCall(SemaRef, Loc, E, 293 cast<UnresolvedLookupExpr>(R.get())); 294 } 295 296 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, 297 MultiExprArg CallArgs) { 298 StringRef Name = S.Context.BuiltinInfo.getName(Id); 299 LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); 300 S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); 301 302 auto *BuiltInDecl = R.getAsSingle<FunctionDecl>(); 303 assert(BuiltInDecl && "failed to find builtin declaration"); 304 305 ExprResult DeclRef = 306 S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc); 307 assert(DeclRef.isUsable() && "Builtin reference cannot fail"); 308 309 ExprResult Call = 310 S.ActOnCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc); 311 312 assert(!Call.isInvalid() && "Call to builtin cannot fail!"); 313 return Call.get(); 314 } 315 316 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType, 317 SourceLocation Loc) { 318 QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc); 319 if (CoroHandleType.isNull()) 320 return ExprError(); 321 322 DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType); 323 LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc, 324 Sema::LookupOrdinaryName); 325 if (!S.LookupQualifiedName(Found, LookupCtx)) { 326 S.Diag(Loc, diag::err_coroutine_handle_missing_member) 327 << 
"from_address"; 328 return ExprError(); 329 } 330 331 Expr *FramePtr = 332 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 333 334 CXXScopeSpec SS; 335 ExprResult FromAddr = 336 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 337 if (FromAddr.isInvalid()) 338 return ExprError(); 339 340 return S.ActOnCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc); 341 } 342 343 struct ReadySuspendResumeResult { 344 enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume }; 345 Expr *Results[3]; 346 OpaqueValueExpr *OpaqueValue; 347 bool IsInvalid; 348 }; 349 350 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 351 StringRef Name, MultiExprArg Args) { 352 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 353 354 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 355 CXXScopeSpec SS; 356 ExprResult Result = S.BuildMemberReferenceExpr( 357 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 358 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 359 /*Scope=*/nullptr); 360 if (Result.isInvalid()) 361 return ExprError(); 362 363 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 364 } 365 366 // See if return type is coroutine-handle and if so, invoke builtin coro-resume 367 // on its address. This is to enable experimental support for coroutine-handle 368 // returning await_suspend that results in a guranteed tail call to the target 369 // coroutine. 370 static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, 371 SourceLocation Loc) { 372 if (RetType->isReferenceType()) 373 return nullptr; 374 Type const *T = RetType.getTypePtr(); 375 if (!T->isClassType() && !T->isStructureType()) 376 return nullptr; 377 378 // FIXME: Add convertability check to coroutine_handle<>. Possibly via 379 // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) 
which is at the moment 380 // a private function in SemaExprCXX.cpp 381 382 ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None); 383 if (AddressExpr.isInvalid()) 384 return nullptr; 385 386 Expr *JustAddress = AddressExpr.get(); 387 // FIXME: Check that the type of AddressExpr is void* 388 return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume, 389 JustAddress); 390 } 391 392 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 393 /// expression. 394 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, 395 SourceLocation Loc, Expr *E) { 396 OpaqueValueExpr *Operand = new (S.Context) 397 OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 398 399 // Assume invalid until we see otherwise. 400 ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; 401 402 ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); 403 if (CoroHandleRes.isInvalid()) 404 return Calls; 405 Expr *CoroHandle = CoroHandleRes.get(); 406 407 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 408 MultiExprArg Args[] = {None, CoroHandle, None}; 409 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 410 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); 411 if (Result.isInvalid()) 412 return Calls; 413 Calls.Results[I] = Result.get(); 414 } 415 416 // Assume the calls are valid; all further checking should make them invalid. 417 Calls.IsInvalid = false; 418 419 using ACT = ReadySuspendResumeResult::AwaitCallType; 420 CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]); 421 if (!AwaitReady->getType()->isDependentType()) { 422 // [expr.await]p3 [...] 423 // — await-ready is the expression e.await_ready(), contextually converted 424 // to bool. 
425 ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady); 426 if (Conv.isInvalid()) { 427 S.Diag(AwaitReady->getDirectCallee()->getLocStart(), 428 diag::note_await_ready_no_bool_conversion); 429 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 430 << AwaitReady->getDirectCallee() << E->getSourceRange(); 431 Calls.IsInvalid = true; 432 } 433 Calls.Results[ACT::ACT_Ready] = Conv.get(); 434 } 435 CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]); 436 if (!AwaitSuspend->getType()->isDependentType()) { 437 // [expr.await]p3 [...] 438 // - await-suspend is the expression e.await_suspend(h), which shall be 439 // a prvalue of type void or bool. 440 QualType RetType = AwaitSuspend->getCallReturnType(S.Context); 441 442 // Experimental support for coroutine_handle returning await_suspend. 443 if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc)) 444 Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; 445 else { 446 // non-class prvalues always have cv-unqualified types 447 if (RetType->isReferenceType() || 448 (!RetType->isBooleanType() && !RetType->isVoidType())) { 449 S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), 450 diag::err_await_suspend_invalid_return_type) 451 << RetType; 452 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 453 << AwaitSuspend->getDirectCallee(); 454 Calls.IsInvalid = true; 455 } 456 } 457 } 458 459 return Calls; 460 } 461 462 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, 463 SourceLocation Loc, StringRef Name, 464 MultiExprArg Args) { 465 466 // Form a reference to the promise. 
467 ExprResult PromiseRef = S.BuildDeclRefExpr( 468 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 469 if (PromiseRef.isInvalid()) 470 return ExprError(); 471 472 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 473 } 474 475 VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { 476 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 477 auto *FD = cast<FunctionDecl>(CurContext); 478 bool IsThisDependentType = [&] { 479 if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD)) 480 return MD->isInstance() && MD->getThisType(Context)->isDependentType(); 481 else 482 return false; 483 }(); 484 485 QualType T = FD->getType()->isDependentType() || IsThisDependentType 486 ? Context.DependentTy 487 : lookupPromiseType(*this, FD, Loc); 488 if (T.isNull()) 489 return nullptr; 490 491 auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), 492 &PP.getIdentifierTable().get("__promise"), T, 493 Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 494 CheckVariableDeclarationType(VD); 495 if (VD->isInvalidDecl()) 496 return nullptr; 497 498 auto *ScopeInfo = getCurFunction(); 499 // Build a list of arguments, based on the coroutine functions arguments, 500 // that will be passed to the promise type's constructor. 501 llvm::SmallVector<Expr *, 4> CtorArgExprs; 502 auto &Moves = ScopeInfo->CoroutineParameterMoves; 503 for (auto *PD : FD->parameters()) { 504 if (PD->getType()->isDependentType()) 505 continue; 506 507 auto RefExpr = ExprEmpty(); 508 auto Move = Moves.find(PD); 509 if (Move != Moves.end()) { 510 // If a reference to the function parameter exists in the coroutine 511 // frame, use that reference. 
512 auto *MoveDecl = 513 cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl()); 514 RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(), 515 ExprValueKind::VK_LValue, FD->getLocation()); 516 } else { 517 // If the function parameter doesn't exist in the coroutine frame, it 518 // must be a scalar value. Use it directly. 519 assert(!PD->getType()->getAsCXXRecordDecl() && 520 "Non-scalar types should have been moved and inserted into the " 521 "parameter moves map"); 522 RefExpr = 523 BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), 524 ExprValueKind::VK_LValue, FD->getLocation()); 525 } 526 527 if (RefExpr.isInvalid()) 528 return nullptr; 529 CtorArgExprs.push_back(RefExpr.get()); 530 } 531 532 // Create an initialization sequence for the promise type using the 533 // constructor arguments, wrapped in a parenthesized list expression. 534 Expr *PLE = new (Context) ParenListExpr(Context, FD->getLocation(), 535 CtorArgExprs, FD->getLocation()); 536 InitializedEntity Entity = InitializedEntity::InitializeVariable(VD); 537 InitializationKind Kind = InitializationKind::CreateForInit( 538 VD->getLocation(), /*DirectInit=*/true, PLE); 539 InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs, 540 /*TopLevelOfInitList=*/false, 541 /*TreatUnavailableAsInvalid=*/false); 542 543 // Attempt to initialize the promise type with the arguments. 544 // If that fails, fall back to the promise type's default constructor. 545 if (InitSeq) { 546 ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs); 547 if (Result.isInvalid()) { 548 VD->setInvalidDecl(); 549 } else if (Result.get()) { 550 VD->setInit(MaybeCreateExprWithCleanups(Result.get())); 551 VD->setInitStyle(VarDecl::CallInit); 552 CheckCompleteVariableDeclaration(VD); 553 } 554 } else 555 ActOnUninitializedDecl(VD); 556 557 FD->addDecl(VD); 558 return VD; 559 } 560 561 /// Check that this is a context in which a coroutine suspension can appear. 
562 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, 563 StringRef Keyword, 564 bool IsImplicit = false) { 565 if (!isValidCoroutineContext(S, Loc, Keyword)) 566 return nullptr; 567 568 assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); 569 570 auto *ScopeInfo = S.getCurFunction(); 571 assert(ScopeInfo && "missing function scope for function"); 572 573 if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) 574 ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); 575 576 if (ScopeInfo->CoroutinePromise) 577 return ScopeInfo; 578 579 if (!S.buildCoroutineParameterMoves(Loc)) 580 return nullptr; 581 582 ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); 583 if (!ScopeInfo->CoroutinePromise) 584 return nullptr; 585 586 return ScopeInfo; 587 } 588 589 bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc, 590 StringRef Keyword) { 591 if (!checkCoroutineContext(*this, KWLoc, Keyword)) 592 return false; 593 auto *ScopeInfo = getCurFunction(); 594 assert(ScopeInfo->CoroutinePromise); 595 596 // If we have existing coroutine statements then we have already built 597 // the initial and final suspend points. 
598 if (!ScopeInfo->NeedsCoroutineSuspends) 599 return true; 600 601 ScopeInfo->setNeedsCoroutineSuspends(false); 602 603 auto *Fn = cast<FunctionDecl>(CurContext); 604 SourceLocation Loc = Fn->getLocation(); 605 // Build the initial suspend point 606 auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { 607 ExprResult Suspend = 608 buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None); 609 if (Suspend.isInvalid()) 610 return StmtError(); 611 Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get()); 612 if (Suspend.isInvalid()) 613 return StmtError(); 614 Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(), 615 /*IsImplicit*/ true); 616 Suspend = ActOnFinishFullExpr(Suspend.get()); 617 if (Suspend.isInvalid()) { 618 Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required) 619 << ((Name == "initial_suspend") ? 0 : 1); 620 Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; 621 return StmtError(); 622 } 623 return cast<Stmt>(Suspend.get()); 624 }; 625 626 StmtResult InitSuspend = buildSuspends("initial_suspend"); 627 if (InitSuspend.isInvalid()) 628 return true; 629 630 StmtResult FinalSuspend = buildSuspends("final_suspend"); 631 if (FinalSuspend.isInvalid()) 632 return true; 633 634 ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); 635 636 return true; 637 } 638 639 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 640 if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) { 641 CorrectDelayedTyposInExpr(E); 642 return ExprError(); 643 } 644 645 if (E->getType()->isPlaceholderType()) { 646 ExprResult R = CheckPlaceholderExpr(E); 647 if (R.isInvalid()) return ExprError(); 648 E = R.get(); 649 } 650 ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); 651 if (Lookup.isInvalid()) 652 return ExprError(); 653 return BuildUnresolvedCoawaitExpr(Loc, E, 654 cast<UnresolvedLookupExpr>(Lookup.get())); 655 } 656 657 ExprResult 
Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, 658 UnresolvedLookupExpr *Lookup) { 659 auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); 660 if (!FSI) 661 return ExprError(); 662 663 if (E->getType()->isPlaceholderType()) { 664 ExprResult R = CheckPlaceholderExpr(E); 665 if (R.isInvalid()) 666 return ExprError(); 667 E = R.get(); 668 } 669 670 auto *Promise = FSI->CoroutinePromise; 671 if (Promise->getType()->isDependentType()) { 672 Expr *Res = 673 new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); 674 return Res; 675 } 676 677 auto *RD = Promise->getType()->getAsCXXRecordDecl(); 678 if (lookupMember(*this, "await_transform", RD, Loc)) { 679 ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); 680 if (R.isInvalid()) { 681 Diag(Loc, 682 diag::note_coroutine_promise_implicit_await_transform_required_here) 683 << E->getSourceRange(); 684 return ExprError(); 685 } 686 E = R.get(); 687 } 688 ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); 689 if (Awaitable.isInvalid()) 690 return ExprError(); 691 692 return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); 693 } 694 695 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, 696 bool IsImplicit) { 697 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); 698 if (!Coroutine) 699 return ExprError(); 700 701 if (E->getType()->isPlaceholderType()) { 702 ExprResult R = CheckPlaceholderExpr(E); 703 if (R.isInvalid()) return ExprError(); 704 E = R.get(); 705 } 706 707 if (E->getType()->isDependentType()) { 708 Expr *Res = new (Context) 709 CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); 710 return Res; 711 } 712 713 // If the expression is a temporary, materialize it as an lvalue so that we 714 // can use it multiple times. 715 if (E->getValueKind() == VK_RValue) 716 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 717 718 // Build the await_ready, await_suspend, await_resume calls. 
719 ReadySuspendResumeResult RSS = 720 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 721 if (RSS.IsInvalid) 722 return ExprError(); 723 724 Expr *Res = 725 new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 726 RSS.Results[2], RSS.OpaqueValue, IsImplicit); 727 728 return Res; 729 } 730 731 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 732 if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) { 733 CorrectDelayedTyposInExpr(E); 734 return ExprError(); 735 } 736 737 // Build yield_value call. 738 ExprResult Awaitable = buildPromiseCall( 739 *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); 740 if (Awaitable.isInvalid()) 741 return ExprError(); 742 743 // Build 'operator co_await' call. 744 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 745 if (Awaitable.isInvalid()) 746 return ExprError(); 747 748 return BuildCoyieldExpr(Loc, Awaitable.get()); 749 } 750 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 751 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 752 if (!Coroutine) 753 return ExprError(); 754 755 if (E->getType()->isPlaceholderType()) { 756 ExprResult R = CheckPlaceholderExpr(E); 757 if (R.isInvalid()) return ExprError(); 758 E = R.get(); 759 } 760 761 if (E->getType()->isDependentType()) { 762 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 763 return Res; 764 } 765 766 // If the expression is a temporary, materialize it as an lvalue so that we 767 // can use it multiple times. 768 if (E->getValueKind() == VK_RValue) 769 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 770 771 // Build the await_ready, await_suspend, await_resume calls. 
772 ReadySuspendResumeResult RSS = 773 buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); 774 if (RSS.IsInvalid) 775 return ExprError(); 776 777 Expr *Res = 778 new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 779 RSS.Results[2], RSS.OpaqueValue); 780 781 return Res; 782 } 783 784 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { 785 if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) { 786 CorrectDelayedTyposInExpr(E); 787 return StmtError(); 788 } 789 return BuildCoreturnStmt(Loc, E); 790 } 791 792 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, 793 bool IsImplicit) { 794 auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); 795 if (!FSI) 796 return StmtError(); 797 798 if (E && E->getType()->isPlaceholderType() && 799 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 800 ExprResult R = CheckPlaceholderExpr(E); 801 if (R.isInvalid()) return StmtError(); 802 E = R.get(); 803 } 804 805 // FIXME: If the operand is a reference to a variable that's about to go out 806 // of scope, we should treat the operand as an xvalue for this overload 807 // resolution. 808 VarDecl *Promise = FSI->CoroutinePromise; 809 ExprResult PC; 810 if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { 811 PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); 812 } else { 813 E = MakeFullDiscardedValueExpr(E).get(); 814 PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); 815 } 816 if (PC.isInvalid()) 817 return StmtError(); 818 819 Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); 820 821 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); 822 return Res; 823 } 824 825 /// Look up the std::nothrow object. 
static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
  // Returns a DeclRefExpr naming the variable std::nothrow, or nullptr after
  // emitting a diagnostic if the name is missing or is not a single variable.
  NamespaceDecl *Std = S.getStdNamespace();
  assert(Std && "Should already be diagnosed");

  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
                      Sema::LookupOrdinaryName);
  if (!S.LookupQualifiedName(Result, Std)) {
    // FIXME: <experimental/coroutine> should have been included already.
    // If we require it to include <new> then this diagnostic is no longer
    // needed.
    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
    return nullptr;
  }

  auto *VD = Result.getAsSingle<VarDecl>();
  if (!VD) {
    Result.suppressDiagnostics();
    // We found something weird. Complain about the first thing we found.
    NamedDecl *Found = *Result.begin();
    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
    return nullptr;
  }

  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
  if (DR.isInvalid())
    return nullptr;

  return DR.get();
}

// Find an appropriate delete for the promise.
//
// First asks FindDeallocationFunction for an 'operator delete' in the promise
// class. If that reports failure (returns true), nullptr is returned. If it
// succeeds without finding a class-specific operator, the usual global
// deallocation function is used; the size-aware form is allowed only when the
// promise type is complete. The chosen function is marked as referenced.
static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
                                          QualType PromiseType) {
  FunctionDecl *OperatorDelete = nullptr;

  DeclarationName DeleteName =
      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);

  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");

  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
    return nullptr;

  if (!OperatorDelete) {
    // Look for a global declaration.
    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
    const bool Overaligned = false;
    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
                                                     Overaligned, DeleteName);
  }
  // NOTE(review): OperatorDelete is passed on without a null check; this
  // assumes FindUsualDeallocationFunction never returns nullptr. Confirm
  // against its contract, or guard this call.
  S.MarkFunctionReferenced(Loc, OperatorDelete);
  return OperatorDelete;
}


// Called once the body of a function that used coroutine keywords has been
// parsed. It replaces Body with a CoroutineBodyStmt assembled by
// CoroutineStmtBuilder, or marks FD invalid if that cannot be done. A plain
// 'return' in the body is diagnosed as an error.
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
  FunctionScopeInfo *Fn = getCurFunction();
  assert(Fn && Fn->isCoroutine() && "not a coroutine");
  if (!Body) {
    assert(FD->isInvalidDecl() &&
           "a null body is only allowed for invalid declarations");
    return;
  }
  // We have a function that uses coroutine keywords, but we failed to build
  // the promise type.
  if (!Fn->CoroutinePromise)
    return FD->setInvalidDecl();

  if (isa<CoroutineBodyStmt>(Body)) {
    // Nothing to do; the body is already a transformed coroutine body
    // statement.
    return;
  }

  // Coroutines [stmt.return]p1:
  //   A return statement shall not appear in a coroutine.
  if (Fn->FirstReturnLoc.isValid()) {
    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
                   "first coroutine location not set");
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
            << Fn->getFirstCoroutineStmtKeyword();
  }
  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
  if (Builder.isInvalid() || !Builder.buildStatements())
    return FD->setInvalidDecl();

  // Build body for the coroutine wrapper statement.
  Body = CoroutineBodyStmt::Create(Context, Builder);
}

// Copies the parameter-move statements recorded in the function scope into
// ParamMovesVector and points ParamMoves at it. It also records whether the
// promise type is dependent; a missing promise counts as dependent. The
// promise declaration statement and the initial/final suspend points are
// built eagerly, and IsValid reflects only those two steps.
CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
                                           sema::FunctionScopeInfo &Fn,
                                           Stmt *Body)
    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
      IsPromiseDependentType(
          !Fn.CoroutinePromise ||
          Fn.CoroutinePromise->getType()->isDependentType()) {
  this->Body = Body;

  for (auto KV : Fn.CoroutineParameterMoves)
    this->ParamMovesVector.push_back(KV.second);
  this->ParamMoves = this->ParamMovesVector;

  if (!IsPromiseDependentType) {
    PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
    assert(PromiseRecordDecl && "Type should have already been checked");
  }
  this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
}

// Builds the return-object statement. When the promise type is not dependent,
// the remaining statements are built as well (buildDependentStatements
// updates IsValid). Returns the resulting validity.
bool CoroutineStmtBuilder::buildStatements() {
  assert(this->IsValid && "coroutine already invalid");
  this->IsValid = makeReturnObject();
  if (this->IsValid && !IsPromiseDependentType)
    buildDependentStatements();
  return this->IsValid;
}

// Builds the statements that need a concrete (non-dependent) promise type.
// The && chain stops at the first builder that fails.
bool CoroutineStmtBuilder::buildDependentStatements() {
  assert(this->IsValid && "coroutine already invalid");
  assert(!this->IsPromiseDependentType &&
         "coroutine cannot have a dependent promise type");
  this->IsValid = makeOnException() && makeOnFallthrough() &&
                  makeGroDeclAndReturnStmt() && makeReturnOnAllocFailure() &&
                  makeNewAndDeleteExpr();
  return this->IsValid;
}

bool CoroutineStmtBuilder::makePromiseStmt() {
  // Form a declaration statement for the promise declaration, so that AST
  // visitors can more easily find it.
  StmtResult PromiseStmt =
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
  if (PromiseStmt.isInvalid())
    return false;

  this->Promise = PromiseStmt.get();
  return true;
}

// Takes the initial/final suspend expressions stored in the function scope
// (built by ActOnCoroutineBodyStart); fails if they were recorded as invalid.
bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
  if (Fn.hasInvalidCoroutineSuspends())
    return false;
  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
  return true;
}

// Checks the expression E that refers to
// get_return_object_on_allocation_failure. Returns true, with no diagnostic,
// when E is a DeclRefExpr naming a static member function. Otherwise it emits
// an error and returns false. The error is placed at the method declaration
// when E names a non-static method, and at E's location otherwise; a note
// points at the first coroutine statement.
static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
                                     CXXRecordDecl *PromiseRecordDecl,
                                     FunctionScopeInfo &Fn) {
  // NOTE(review): E is dereferenced here, before the dyn_cast_or_null below,
  // so a null E would crash here and the _or_null variant cannot help.
  auto Loc = E->getExprLoc();
  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
    auto *Decl = DeclRef->getDecl();
    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
      if (Method->isStatic())
        return true;
      else
        Loc = Decl->getLocation();
    }
  }

  S.Diag(
      Loc,
      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
      << PromiseRecordDecl;
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
      << Fn.getFirstCoroutineStmtKeyword();
  return false;
}

bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");

  // [dcl.fct.def.coroutine]/8
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
  // the scope of class P by class member access lookup (3.4.5). ...
  // If an allocation function returns nullptr, ... the coroutine return value
  // is obtained by a call to ... get_return_object_on_allocation_failure().
1007 1008 DeclarationName DN = 1009 S.PP.getIdentifierInfo("get_return_object_on_allocation_failure"); 1010 LookupResult Found(S, DN, Loc, Sema::LookupMemberName); 1011 if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) 1012 return true; 1013 1014 CXXScopeSpec SS; 1015 ExprResult DeclNameExpr = 1016 S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); 1017 if (DeclNameExpr.isInvalid()) 1018 return false; 1019 1020 if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)) 1021 return false; 1022 1023 ExprResult ReturnObjectOnAllocationFailure = 1024 S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc); 1025 if (ReturnObjectOnAllocationFailure.isInvalid()) 1026 return false; 1027 1028 StmtResult ReturnStmt = 1029 S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get()); 1030 if (ReturnStmt.isInvalid()) { 1031 S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here) 1032 << DN; 1033 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1034 << Fn.getFirstCoroutineStmtKeyword(); 1035 return false; 1036 } 1037 1038 this->ReturnStmtOnAllocFailure = ReturnStmt.get(); 1039 return true; 1040 } 1041 1042 bool CoroutineStmtBuilder::makeNewAndDeleteExpr() { 1043 // Form and check allocation and deallocation calls. 1044 assert(!IsPromiseDependentType && 1045 "cannot make statement while the promise type is dependent"); 1046 QualType PromiseType = Fn.CoroutinePromise->getType(); 1047 1048 if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) 1049 return false; 1050 1051 const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr; 1052 1053 // FIXME: Add support for stateful allocators. 
1054 1055 FunctionDecl *OperatorNew = nullptr; 1056 FunctionDecl *OperatorDelete = nullptr; 1057 FunctionDecl *UnusedResult = nullptr; 1058 bool PassAlignment = false; 1059 SmallVector<Expr *, 1> PlacementArgs; 1060 1061 S.FindAllocationFunctions(Loc, SourceRange(), 1062 /*UseGlobal*/ false, PromiseType, 1063 /*isArray*/ false, PassAlignment, PlacementArgs, 1064 OperatorNew, UnusedResult); 1065 1066 bool IsGlobalOverload = 1067 OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext()); 1068 // If we didn't find a class-local new declaration and non-throwing new 1069 // was is required then we need to lookup the non-throwing global operator 1070 // instead. 1071 if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) { 1072 auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc); 1073 if (!StdNoThrow) 1074 return false; 1075 PlacementArgs = {StdNoThrow}; 1076 OperatorNew = nullptr; 1077 S.FindAllocationFunctions(Loc, SourceRange(), 1078 /*UseGlobal*/ true, PromiseType, 1079 /*isArray*/ false, PassAlignment, PlacementArgs, 1080 OperatorNew, UnusedResult); 1081 } 1082 1083 assert(OperatorNew && "expected definition of operator new to be found"); 1084 1085 if (RequiresNoThrowAlloc) { 1086 const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>(); 1087 if (!FT->isNothrow(S.Context, /*ResultIfDependent*/ false)) { 1088 S.Diag(OperatorNew->getLocation(), 1089 diag::err_coroutine_promise_new_requires_nothrow) 1090 << OperatorNew; 1091 S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) 1092 << OperatorNew; 1093 return false; 1094 } 1095 } 1096 1097 if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr) 1098 return false; 1099 1100 Expr *FramePtr = 1101 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); 1102 1103 Expr *FrameSize = 1104 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); 1105 1106 // Make new call. 
1107 1108 ExprResult NewRef = 1109 S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); 1110 if (NewRef.isInvalid()) 1111 return false; 1112 1113 SmallVector<Expr *, 2> NewArgs(1, FrameSize); 1114 for (auto Arg : PlacementArgs) 1115 NewArgs.push_back(Arg); 1116 1117 ExprResult NewExpr = 1118 S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc); 1119 NewExpr = S.ActOnFinishFullExpr(NewExpr.get()); 1120 if (NewExpr.isInvalid()) 1121 return false; 1122 1123 // Make delete call. 1124 1125 QualType OpDeleteQualType = OperatorDelete->getType(); 1126 1127 ExprResult DeleteRef = 1128 S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); 1129 if (DeleteRef.isInvalid()) 1130 return false; 1131 1132 Expr *CoroFree = 1133 buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); 1134 1135 SmallVector<Expr *, 2> DeleteArgs{CoroFree}; 1136 1137 // Check if we need to pass the size. 1138 const auto *OpDeleteType = 1139 OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); 1140 if (OpDeleteType->getNumParams() > 1) 1141 DeleteArgs.push_back(FrameSize); 1142 1143 ExprResult DeleteExpr = 1144 S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); 1145 DeleteExpr = S.ActOnFinishFullExpr(DeleteExpr.get()); 1146 if (DeleteExpr.isInvalid()) 1147 return false; 1148 1149 this->Allocate = NewExpr.get(); 1150 this->Deallocate = DeleteExpr.get(); 1151 1152 return true; 1153 } 1154 1155 bool CoroutineStmtBuilder::makeOnFallthrough() { 1156 assert(!IsPromiseDependentType && 1157 "cannot make statement while the promise type is dependent"); 1158 1159 // [dcl.fct.def.coroutine]/4 1160 // The unqualified-ids 'return_void' and 'return_value' are looked up in 1161 // the scope of class P. If both are found, the program is ill-formed. 
1162 bool HasRVoid, HasRValue; 1163 LookupResult LRVoid = 1164 lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid); 1165 LookupResult LRValue = 1166 lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue); 1167 1168 StmtResult Fallthrough; 1169 if (HasRVoid && HasRValue) { 1170 // FIXME Improve this diagnostic 1171 S.Diag(FD.getLocation(), 1172 diag::err_coroutine_promise_incompatible_return_functions) 1173 << PromiseRecordDecl; 1174 S.Diag(LRVoid.getRepresentativeDecl()->getLocation(), 1175 diag::note_member_first_declared_here) 1176 << LRVoid.getLookupName(); 1177 S.Diag(LRValue.getRepresentativeDecl()->getLocation(), 1178 diag::note_member_first_declared_here) 1179 << LRValue.getLookupName(); 1180 return false; 1181 } else if (!HasRVoid && !HasRValue) { 1182 // FIXME: The PDTS currently specifies this case as UB, not ill-formed. 1183 // However we still diagnose this as an error since until the PDTS is fixed. 1184 S.Diag(FD.getLocation(), 1185 diag::err_coroutine_promise_requires_return_function) 1186 << PromiseRecordDecl; 1187 S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here) 1188 << PromiseRecordDecl; 1189 return false; 1190 } else if (HasRVoid) { 1191 // If the unqualified-id return_void is found, flowing off the end of a 1192 // coroutine is equivalent to a co_return with no operand. Otherwise, 1193 // flowing off the end of a coroutine results in undefined behavior. 
1194 Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, 1195 /*IsImplicit*/false); 1196 Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); 1197 if (Fallthrough.isInvalid()) 1198 return false; 1199 } 1200 1201 this->OnFallthrough = Fallthrough.get(); 1202 return true; 1203 } 1204 1205 bool CoroutineStmtBuilder::makeOnException() { 1206 // Try to form 'p.unhandled_exception();' 1207 assert(!IsPromiseDependentType && 1208 "cannot make statement while the promise type is dependent"); 1209 1210 const bool RequireUnhandledException = S.getLangOpts().CXXExceptions; 1211 1212 if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) { 1213 auto DiagID = 1214 RequireUnhandledException 1215 ? diag::err_coroutine_promise_unhandled_exception_required 1216 : diag:: 1217 warn_coroutine_promise_unhandled_exception_required_with_exceptions; 1218 S.Diag(Loc, DiagID) << PromiseRecordDecl; 1219 S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here) 1220 << PromiseRecordDecl; 1221 return !RequireUnhandledException; 1222 } 1223 1224 // If exceptions are disabled, don't try to build OnException. 1225 if (!S.getLangOpts().CXXExceptions) 1226 return true; 1227 1228 ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, 1229 "unhandled_exception", None); 1230 UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc); 1231 if (UnhandledException.isInvalid()) 1232 return false; 1233 1234 // Since the body of the coroutine will be wrapped in try-catch, it will 1235 // be incompatible with SEH __try if present in a function. 
1236 if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) { 1237 S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions); 1238 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1239 << Fn.getFirstCoroutineStmtKeyword(); 1240 return false; 1241 } 1242 1243 this->OnException = UnhandledException.get(); 1244 return true; 1245 } 1246 1247 bool CoroutineStmtBuilder::makeReturnObject() { 1248 // Build implicit 'p.get_return_object()' expression and form initialization 1249 // of return type from it. 1250 ExprResult ReturnObject = 1251 buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); 1252 if (ReturnObject.isInvalid()) 1253 return false; 1254 1255 this->ReturnValue = ReturnObject.get(); 1256 return true; 1257 } 1258 1259 static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) { 1260 if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) { 1261 auto *MethodDecl = MbrRef->getMethodDecl(); 1262 S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here) 1263 << MethodDecl; 1264 } 1265 S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) 1266 << Fn.getFirstCoroutineStmtKeyword(); 1267 } 1268 1269 bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() { 1270 assert(!IsPromiseDependentType && 1271 "cannot make statement while the promise type is dependent"); 1272 assert(this->ReturnValue && "ReturnValue must be already formed"); 1273 1274 QualType const GroType = this->ReturnValue->getType(); 1275 assert(!GroType->isDependentType() && 1276 "get_return_object type must no longer be dependent"); 1277 1278 QualType const FnRetType = FD.getReturnType(); 1279 assert(!FnRetType->isDependentType() && 1280 "get_return_object type must no longer be dependent"); 1281 1282 if (FnRetType->isVoidType()) { 1283 ExprResult Res = S.ActOnFinishFullExpr(this->ReturnValue, Loc); 1284 if (Res.isInvalid()) 1285 return false; 1286 1287 this->ResultDecl = Res.get(); 1288 return true; 1289 } 
1290 1291 if (GroType->isVoidType()) { 1292 // Trigger a nice error message. 1293 InitializedEntity Entity = 1294 InitializedEntity::InitializeResult(Loc, FnRetType, false); 1295 S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue); 1296 noteMemberDeclaredHere(S, ReturnValue, Fn); 1297 return false; 1298 } 1299 1300 auto *GroDecl = VarDecl::Create( 1301 S.Context, &FD, FD.getLocation(), FD.getLocation(), 1302 &S.PP.getIdentifierTable().get("__coro_gro"), GroType, 1303 S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None); 1304 1305 S.CheckVariableDeclarationType(GroDecl); 1306 if (GroDecl->isInvalidDecl()) 1307 return false; 1308 1309 InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl); 1310 ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType, 1311 this->ReturnValue); 1312 if (Res.isInvalid()) 1313 return false; 1314 1315 Res = S.ActOnFinishFullExpr(Res.get()); 1316 if (Res.isInvalid()) 1317 return false; 1318 1319 if (GroType == FnRetType) { 1320 GroDecl->setNRVOVariable(true); 1321 } 1322 1323 S.AddInitializerToDecl(GroDecl, Res.get(), 1324 /*DirectInit=*/false); 1325 1326 S.FinalizeDeclaration(GroDecl); 1327 1328 // Form a declaration statement for the return declaration, so that AST 1329 // visitors can more easily find it. 1330 StmtResult GroDeclStmt = 1331 S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc); 1332 if (GroDeclStmt.isInvalid()) 1333 return false; 1334 1335 this->ResultDecl = GroDeclStmt.get(); 1336 1337 ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc); 1338 if (declRef.isInvalid()) 1339 return false; 1340 1341 StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get()); 1342 if (ReturnStmt.isInvalid()) { 1343 noteMemberDeclaredHere(S, ReturnValue, Fn); 1344 return false; 1345 } 1346 1347 this->ReturnStmt = ReturnStmt.get(); 1348 return true; 1349 } 1350 1351 // Create a static_cast\<T&&>(expr). 
1352 static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) { 1353 if (T.isNull()) 1354 T = E->getType(); 1355 QualType TargetType = S.BuildReferenceType( 1356 T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName()); 1357 SourceLocation ExprLoc = E->getLocStart(); 1358 TypeSourceInfo *TargetLoc = 1359 S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc); 1360 1361 return S 1362 .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E, 1363 SourceRange(ExprLoc, ExprLoc), E->getSourceRange()) 1364 .get(); 1365 } 1366 1367 /// \brief Build a variable declaration for move parameter. 1368 static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type, 1369 IdentifierInfo *II) { 1370 TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc); 1371 VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, 1372 TInfo, SC_None); 1373 Decl->setImplicit(); 1374 return Decl; 1375 } 1376 1377 // Build statements that move coroutine function parameters to the coroutine 1378 // frame, and store them on the function scope info. 1379 bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) { 1380 assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); 1381 auto *FD = cast<FunctionDecl>(CurContext); 1382 1383 auto *ScopeInfo = getCurFunction(); 1384 assert(ScopeInfo->CoroutineParameterMoves.empty() && 1385 "Should not build parameter moves twice"); 1386 1387 for (auto *PD : FD->parameters()) { 1388 if (PD->getType()->isDependentType()) 1389 continue; 1390 1391 // No need to copy scalars, LLVM will take care of them. 1392 if (PD->getType()->getAsCXXRecordDecl()) { 1393 ExprResult PDRefExpr = BuildDeclRefExpr( 1394 PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope? 
1395 if (PDRefExpr.isInvalid()) 1396 return false; 1397 1398 Expr *CExpr = castForMoving(*this, PDRefExpr.get()); 1399 1400 auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); 1401 AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); 1402 1403 // Convert decl to a statement. 1404 StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); 1405 if (Stmt.isInvalid()) 1406 return false; 1407 1408 ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); 1409 } 1410 } 1411 return true; 1412 } 1413 1414 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { 1415 CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); 1416 if (!Res) 1417 return StmtError(); 1418 return Res; 1419 } 1420