1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "clang/Sema/SemaInternal.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Initialization.h" 20 #include "clang/Sema/Overload.h" 21 using namespace clang; 22 using namespace sema; 23 24 /// Look up the std::coroutine_traits<...>::promise_type for the given 25 /// function type. 26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, 27 SourceLocation Loc) { 28 // FIXME: Cache std::coroutine_traits once we've found it. 29 NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); 30 if (!StdExp) { 31 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 32 return QualType(); 33 } 34 35 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 36 Loc, Sema::LookupOrdinaryName); 37 if (!S.LookupQualifiedName(Result, StdExp)) { 38 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 39 return QualType(); 40 } 41 42 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 43 if (!CoroTraits) { 44 Result.suppressDiagnostics(); 45 // We found something weird. Complain about the first thing we found. 46 NamedDecl *Found = *Result.begin(); 47 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 48 return QualType(); 49 } 50 51 // Form template argument list for coroutine_traits<R, P1, P2, ...>. 52 TemplateArgumentListInfo Args(Loc, Loc); 53 Args.addArgument(TemplateArgumentLoc( 54 TemplateArgument(FnType->getReturnType()), 55 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); 56 // FIXME: If the function is a non-static member function, add the type 57 // of the implicit object parameter before the formal parameters. 58 for (QualType T : FnType->getParamTypes()) 59 Args.addArgument(TemplateArgumentLoc( 60 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); 61 62 // Build the template-id. 63 QualType CoroTrait = 64 S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); 65 if (CoroTrait.isNull()) 66 return QualType(); 67 if (S.RequireCompleteType(Loc, CoroTrait, 68 diag::err_coroutine_traits_missing_specialization)) 69 return QualType(); 70 71 CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); 72 assert(RD && "specialization of class template is not a class?"); 73 74 // Look up the ::promise_type member. 75 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, 76 Sema::LookupOrdinaryName); 77 S.LookupQualifiedName(R, RD); 78 auto *Promise = R.getAsSingle<TypeDecl>(); 79 if (!Promise) { 80 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) 81 << RD; 82 return QualType(); 83 } 84 85 // The promise type is required to be a class type. 86 QualType PromiseType = S.Context.getTypeDeclType(Promise); 87 if (!PromiseType->getAsCXXRecordDecl()) { 88 // Use the fully-qualified name of the type. 89 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); 90 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 91 CoroTrait.getTypePtr()); 92 PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 93 94 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) 95 << PromiseType; 96 return QualType(); 97 } 98 99 return PromiseType; 100 } 101 102 /// Check that this is a context in which a coroutine suspension can appear. 103 static FunctionScopeInfo * 104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { 105 // 'co_await' and 'co_yield' are not permitted in unevaluated operands. 106 if (S.isUnevaluatedContext()) { 107 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 108 return nullptr; 109 } 110 111 // Any other usage must be within a function. 112 // FIXME: Reject a coroutine with a deduced return type. 113 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 114 if (!FD) { 115 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 116 ? diag::err_coroutine_objc_method 117 : diag::err_coroutine_outside_function) << Keyword; 118 } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) { 119 // Coroutines TS [special]/6: 120 // A special member function shall not be a coroutine. 121 // 122 // FIXME: We assume that this really means that a coroutine cannot 123 // be a constructor or destructor. 124 S.Diag(Loc, diag::err_coroutine_ctor_dtor) 125 << isa<CXXDestructorDecl>(FD) << Keyword; 126 } else if (FD->isConstexpr()) { 127 S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword; 128 } else if (FD->isVariadic()) { 129 S.Diag(Loc, diag::err_coroutine_varargs) << Keyword; 130 } else if (FD->isMain()) { 131 S.Diag(FD->getLocStart(), diag::err_coroutine_main); 132 S.Diag(Loc, diag::note_declared_coroutine_here) 133 << (Keyword == "co_await" ? 0 : 134 Keyword == "co_yield" ? 1 : 2); 135 } else { 136 auto *ScopeInfo = S.getCurFunction(); 137 assert(ScopeInfo && "missing function scope for function"); 138 139 // If we don't have a promise variable, build one now. 140 if (!ScopeInfo->CoroutinePromise) { 141 QualType T = 142 FD->getType()->isDependentType() 143 ? S.Context.DependentTy 144 : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(), 145 Loc); 146 if (T.isNull()) 147 return nullptr; 148 149 // Create and default-initialize the promise. 150 ScopeInfo->CoroutinePromise = 151 VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), 152 &S.PP.getIdentifierTable().get("__promise"), T, 153 S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 154 S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); 155 if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) 156 S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false); 157 } 158 159 return ScopeInfo; 160 } 161 162 return nullptr; 163 } 164 165 /// Build a call to 'operator co_await' if there is a suitable operator for 166 /// the given expression. 167 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 168 SourceLocation Loc, Expr *E) { 169 UnresolvedSet<16> Functions; 170 SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), 171 Functions); 172 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 173 } 174 175 struct ReadySuspendResumeResult { 176 bool IsInvalid; 177 Expr *Results[3]; 178 }; 179 180 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 181 StringRef Name, 182 MutableArrayRef<Expr *> Args) { 183 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 184 185 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 186 CXXScopeSpec SS; 187 ExprResult Result = S.BuildMemberReferenceExpr( 188 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 189 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 190 /*Scope=*/nullptr); 191 if (Result.isInvalid()) 192 return ExprError(); 193 194 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 195 } 196 197 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 198 /// expression. 199 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, 200 Expr *E) { 201 // Assume invalid until we see otherwise. 202 ReadySuspendResumeResult Calls = {true, {}}; 203 204 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 205 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 206 Expr *Operand = new (S.Context) OpaqueValueExpr( 207 Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 208 209 // FIXME: Pass coroutine handle to await_suspend. 210 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None); 211 if (Result.isInvalid()) 212 return Calls; 213 Calls.Results[I] = Result.get(); 214 } 215 216 Calls.IsInvalid = false; 217 return Calls; 218 } 219 220 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 221 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); 222 if (!Coroutine) { 223 CorrectDelayedTyposInExpr(E); 224 return ExprError(); 225 } 226 if (E->getType()->isPlaceholderType()) { 227 ExprResult R = CheckPlaceholderExpr(E); 228 if (R.isInvalid()) return ExprError(); 229 E = R.get(); 230 } 231 232 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); 233 if (Awaitable.isInvalid()) 234 return ExprError(); 235 236 return BuildCoawaitExpr(Loc, Awaitable.get()); 237 } 238 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { 239 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); 240 if (!Coroutine) 241 return ExprError(); 242 243 if (E->getType()->isPlaceholderType()) { 244 ExprResult R = CheckPlaceholderExpr(E); 245 if (R.isInvalid()) return ExprError(); 246 E = R.get(); 247 } 248 249 if (E->getType()->isDependentType()) { 250 Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); 251 Coroutine->CoroutineStmts.push_back(Res); 252 return Res; 253 } 254 255 // If the expression is a temporary, materialize it as an lvalue so that we 256 // can use it multiple times. 257 if (E->getValueKind() == VK_RValue) 258 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 259 260 // Build the await_ready, await_suspend, await_resume calls. 261 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); 262 if (RSS.IsInvalid) 263 return ExprError(); 264 265 Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 266 RSS.Results[2]); 267 Coroutine->CoroutineStmts.push_back(Res); 268 return Res; 269 } 270 271 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, 272 SourceLocation Loc, StringRef Name, 273 MutableArrayRef<Expr *> Args) { 274 assert(Coroutine->CoroutinePromise && "no promise for coroutine"); 275 276 // Form a reference to the promise. 277 auto *Promise = Coroutine->CoroutinePromise; 278 ExprResult PromiseRef = S.BuildDeclRefExpr( 279 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 280 if (PromiseRef.isInvalid()) 281 return ExprError(); 282 283 // Call 'yield_value', passing in E. 284 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 285 } 286 287 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 288 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 289 if (!Coroutine) { 290 CorrectDelayedTyposInExpr(E); 291 return ExprError(); 292 } 293 294 // Build yield_value call. 295 ExprResult Awaitable = 296 buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); 297 if (Awaitable.isInvalid()) 298 return ExprError(); 299 300 // Build 'operator co_await' call. 301 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 302 if (Awaitable.isInvalid()) 303 return ExprError(); 304 305 return BuildCoyieldExpr(Loc, Awaitable.get()); 306 } 307 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 308 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 309 if (!Coroutine) 310 return ExprError(); 311 312 if (E->getType()->isPlaceholderType()) { 313 ExprResult R = CheckPlaceholderExpr(E); 314 if (R.isInvalid()) return ExprError(); 315 E = R.get(); 316 } 317 318 if (E->getType()->isDependentType()) { 319 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 320 Coroutine->CoroutineStmts.push_back(Res); 321 return Res; 322 } 323 324 // If the expression is a temporary, materialize it as an lvalue so that we 325 // can use it multiple times. 326 if (E->getValueKind() == VK_RValue) 327 E = CreateMaterializeTemporaryExpr(E->getType(), E, true); 328 329 // Build the await_ready, await_suspend, await_resume calls. 330 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); 331 if (RSS.IsInvalid) 332 return ExprError(); 333 334 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 335 RSS.Results[2]); 336 Coroutine->CoroutineStmts.push_back(Res); 337 return Res; 338 } 339 340 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { 341 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); 342 if (!Coroutine) { 343 CorrectDelayedTyposInExpr(E); 344 return StmtError(); 345 } 346 return BuildCoreturnStmt(Loc, E); 347 } 348 349 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { 350 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); 351 if (!Coroutine) 352 return StmtError(); 353 354 if (E && E->getType()->isPlaceholderType() && 355 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 356 ExprResult R = CheckPlaceholderExpr(E); 357 if (R.isInvalid()) return StmtError(); 358 E = R.get(); 359 } 360 361 // FIXME: If the operand is a reference to a variable that's about to go out 362 // of scope, we should treat the operand as an xvalue for this overload 363 // resolution. 364 ExprResult PC; 365 if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { 366 PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); 367 } else { 368 E = MakeFullDiscardedValueExpr(E).get(); 369 PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); 370 } 371 if (PC.isInvalid()) 372 return StmtError(); 373 374 Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); 375 376 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); 377 Coroutine->CoroutineStmts.push_back(Res); 378 return Res; 379 } 380 381 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { 382 FunctionScopeInfo *Fn = getCurFunction(); 383 assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); 384 385 // Coroutines [stmt.return]p1: 386 // A return statement shall not appear in a coroutine. 387 if (Fn->FirstReturnLoc.isValid()) { 388 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 389 auto *First = Fn->CoroutineStmts[0]; 390 Diag(First->getLocStart(), diag::note_declared_coroutine_here) 391 << (isa<CoawaitExpr>(First) ? 0 : 392 isa<CoyieldExpr>(First) ? 1 : 2); 393 } 394 395 bool AnyCoawaits = false; 396 bool AnyCoyields = false; 397 for (auto *CoroutineStmt : Fn->CoroutineStmts) { 398 AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt); 399 AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt); 400 } 401 402 if (!AnyCoawaits && !AnyCoyields) 403 Diag(Fn->CoroutineStmts.front()->getLocStart(), 404 diag::ext_coroutine_without_co_await_co_yield); 405 406 SourceLocation Loc = FD->getLocation(); 407 408 // Form a declaration statement for the promise declaration, so that AST 409 // visitors can more easily find it. 410 StmtResult PromiseStmt = 411 ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); 412 if (PromiseStmt.isInvalid()) 413 return FD->setInvalidDecl(); 414 415 // Form and check implicit 'co_await p.initial_suspend();' statement. 416 ExprResult InitialSuspend = 417 buildPromiseCall(*this, Fn, Loc, "initial_suspend", None); 418 // FIXME: Support operator co_await here. 419 if (!InitialSuspend.isInvalid()) 420 InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); 421 InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); 422 if (InitialSuspend.isInvalid()) 423 return FD->setInvalidDecl(); 424 425 // Form and check implicit 'co_await p.final_suspend();' statement. 426 ExprResult FinalSuspend = 427 buildPromiseCall(*this, Fn, Loc, "final_suspend", None); 428 // FIXME: Support operator co_await here. 429 if (!FinalSuspend.isInvalid()) 430 FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); 431 FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); 432 if (FinalSuspend.isInvalid()) 433 return FD->setInvalidDecl(); 434 435 // FIXME: Perform analysis of set_exception call. 436 437 // FIXME: Try to form 'p.return_void();' expression statement to handle 438 // control flowing off the end of the coroutine. 439 440 // Build implicit 'p.get_return_object()' expression and form initialization 441 // of return type from it. 442 ExprResult ReturnObject = 443 buildPromiseCall(*this, Fn, Loc, "get_return_object", None); 444 if (ReturnObject.isInvalid()) 445 return FD->setInvalidDecl(); 446 QualType RetType = FD->getReturnType(); 447 if (!RetType->isDependentType()) { 448 InitializedEntity Entity = 449 InitializedEntity::InitializeResult(Loc, RetType, false); 450 ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, 451 ReturnObject.get()); 452 if (ReturnObject.isInvalid()) 453 return FD->setInvalidDecl(); 454 } 455 ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); 456 if (ReturnObject.isInvalid()) 457 return FD->setInvalidDecl(); 458 459 // FIXME: Perform move-initialization of parameters into frame-local copies. 460 SmallVector<Expr*, 16> ParamMoves; 461 462 // Build body for the coroutine wrapper statement. 463 Body = new (Context) CoroutineBodyStmt( 464 Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), 465 /*SetException*/nullptr, /*Fallthrough*/nullptr, 466 ReturnObject.get(), ParamMoves); 467 } 468