1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "clang/Sema/SemaInternal.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Overload.h" 20 using namespace clang; 21 using namespace sema; 22 23 /// Look up the std::coroutine_traits<...>::promise_type for the given 24 /// function type. 25 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, 26 SourceLocation Loc) { 27 // FIXME: Cache std::coroutine_traits once we've found it. 28 NamespaceDecl *Std = S.getStdNamespace(); 29 if (!Std) { 30 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 31 return QualType(); 32 } 33 34 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 35 Loc, Sema::LookupOrdinaryName); 36 if (!S.LookupQualifiedName(Result, Std)) { 37 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 38 return QualType(); 39 } 40 41 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 42 if (!CoroTraits) { 43 Result.suppressDiagnostics(); 44 // We found something weird. Complain about the first thing we found. 45 NamedDecl *Found = *Result.begin(); 46 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 47 return QualType(); 48 } 49 50 // Form template argument list for coroutine_traits<R, P1, P2, ...>. 51 TemplateArgumentListInfo Args(Loc, Loc); 52 Args.addArgument(TemplateArgumentLoc( 53 TemplateArgument(FnType->getReturnType()), 54 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); 55 for (QualType T : FnType->getParamTypes()) 56 Args.addArgument(TemplateArgumentLoc( 57 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); 58 59 // Build the template-id. 60 QualType CoroTrait = 61 S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); 62 if (CoroTrait.isNull()) 63 return QualType(); 64 if (S.RequireCompleteType(Loc, CoroTrait, 65 diag::err_coroutine_traits_missing_specialization)) 66 return QualType(); 67 68 CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); 69 assert(RD && "specialization of class template is not a class?"); 70 71 // Look up the ::promise_type member. 72 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, 73 Sema::LookupOrdinaryName); 74 S.LookupQualifiedName(R, RD); 75 auto *Promise = R.getAsSingle<TypeDecl>(); 76 if (!Promise) { 77 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) 78 << RD; 79 return QualType(); 80 } 81 82 // The promise type is required to be a class type. 83 QualType PromiseType = S.Context.getTypeDeclType(Promise); 84 if (!PromiseType->getAsCXXRecordDecl()) { 85 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) 86 << PromiseType; 87 return QualType(); 88 } 89 90 return PromiseType; 91 } 92 93 /// Check that this is a context in which a coroutine suspension can appear. 94 static FunctionScopeInfo * 95 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { 96 // 'co_await' and 'co_yield' are permitted in unevaluated operands. 97 // FIXME: Not in 'noexcept'. 98 if (S.isUnevaluatedContext()) 99 return nullptr; 100 101 // Any other usage must be within a function. 102 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 103 if (!FD) { 104 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 105 ? diag::err_coroutine_objc_method 106 : diag::err_coroutine_outside_function) << Keyword; 107 } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) { 108 // Coroutines TS [special]/6: 109 // A special member function shall not be a coroutine. 110 // 111 // FIXME: We assume that this really means that a coroutine cannot 112 // be a constructor or destructor. 113 S.Diag(Loc, diag::err_coroutine_ctor_dtor) 114 << isa<CXXDestructorDecl>(FD) << Keyword; 115 } else if (FD->isConstexpr()) { 116 S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword; 117 } else if (FD->isVariadic()) { 118 S.Diag(Loc, diag::err_coroutine_varargs) << Keyword; 119 } else { 120 auto *ScopeInfo = S.getCurFunction(); 121 assert(ScopeInfo && "missing function scope for function"); 122 123 // If we don't have a promise variable, build one now. 124 if (!ScopeInfo->CoroutinePromise && !FD->getType()->isDependentType()) { 125 QualType T = 126 lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(), Loc); 127 if (T.isNull()) 128 return nullptr; 129 130 // Create and default-initialize the promise. 131 ScopeInfo->CoroutinePromise = 132 VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), 133 &S.PP.getIdentifierTable().get("__promise"), T, 134 S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 135 S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); 136 if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) 137 S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false); 138 } 139 140 return ScopeInfo; 141 } 142 143 return nullptr; 144 } 145 146 /// Build a call to 'operator co_await' if there is a suitable operator for 147 /// the given expression. 148 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 149 SourceLocation Loc, Expr *E) { 150 UnresolvedSet<16> Functions; 151 SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), 152 Functions); 153 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 154 } 155 156 struct ReadySuspendResumeResult { 157 bool IsInvalid; 158 Expr *Results[3]; 159 }; 160 161 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 162 /// expression. 163 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, 164 Expr *E) { 165 // Assume invalid until we see otherwise. 166 ReadySuspendResumeResult Calls = {true, {}}; 167 168 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 169 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 170 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Funcs[I]), Loc); 171 172 Expr *Operand = new (S.Context) OpaqueValueExpr( 173 Loc, E->getType(), E->getValueKind(), E->getObjectKind(), E); 174 175 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 176 CXXScopeSpec SS; 177 ExprResult Result = S.BuildMemberReferenceExpr( 178 Operand, Operand->getType(), Loc, /*IsPtr=*/false, SS, 179 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 180 /*Scope=*/nullptr); 181 if (Result.isInvalid()) 182 return Calls; 183 184 // FIXME: Pass coroutine handle to await_suspend. 185 Result = S.ActOnCallExpr(nullptr, Result.get(), Loc, None, Loc, nullptr); 186 if (Result.isInvalid()) 187 return Calls; 188 Calls.Results[I] = Result.get(); 189 } 190 191 Calls.IsInvalid = false; 192 return Calls; 193 } 194 195 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 196 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); 197 if (Awaitable.isInvalid()) 198 return ExprError(); 199 return BuildCoawaitExpr(Loc, Awaitable.get()); 200 } 201 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { 202 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); 203 204 if (E->getType()->isDependentType()) { 205 Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); 206 if (Coroutine) 207 Coroutine->CoroutineStmts.push_back(Res); 208 return Res; 209 } 210 211 if (E->getType()->isPlaceholderType()) { 212 ExprResult R = CheckPlaceholderExpr(E); 213 if (R.isInvalid()) return ExprError(); 214 E = R.get(); 215 } 216 217 // FIXME: If E is a prvalue, create a temporary. 218 // FIXME: If E is an xvalue, convert to lvalue. 219 220 // Build the await_ready, await_suspend, await_resume calls. 221 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); 222 if (RSS.IsInvalid) 223 return ExprError(); 224 225 Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 226 RSS.Results[2]); 227 if (Coroutine) 228 Coroutine->CoroutineStmts.push_back(Res); 229 return Res; 230 } 231 232 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 233 // FIXME: Build yield_value call. 234 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); 235 if (Awaitable.isInvalid()) 236 return ExprError(); 237 return BuildCoyieldExpr(Loc, Awaitable.get()); 238 } 239 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 240 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 241 242 // FIXME: Build await_* calls. 243 Expr *Res = new (Context) CoyieldExpr(Loc, Context.VoidTy, E); 244 if (Coroutine) 245 Coroutine->CoroutineStmts.push_back(Res); 246 return Res; 247 } 248 249 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { 250 return BuildCoreturnStmt(Loc, E); 251 } 252 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { 253 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); 254 255 // FIXME: Build return_* calls. 256 Stmt *Res = new (Context) CoreturnStmt(Loc, E); 257 if (Coroutine) 258 Coroutine->CoroutineStmts.push_back(Res); 259 return Res; 260 } 261 262 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *Body) { 263 FunctionScopeInfo *Fn = getCurFunction(); 264 assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); 265 266 // Coroutines [stmt.return]p1: 267 // A return statement shall not appear in a coroutine. 268 if (Fn->FirstReturnLoc.isValid()) { 269 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 270 auto *First = Fn->CoroutineStmts[0]; 271 Diag(First->getLocStart(), diag::note_declared_coroutine_here) 272 << (isa<CoawaitExpr>(First) ? 0 : 273 isa<CoyieldExpr>(First) ? 1 : 2); 274 } 275 276 bool AnyCoawaits = false; 277 bool AnyCoyields = false; 278 for (auto *CoroutineStmt : Fn->CoroutineStmts) { 279 AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt); 280 AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt); 281 } 282 283 if (!AnyCoawaits && !AnyCoyields) 284 Diag(Fn->CoroutineStmts.front()->getLocStart(), 285 diag::ext_coroutine_without_co_await_co_yield); 286 287 // FIXME: Perform analysis of initial and final suspend, 288 // and set_exception call. 289 } 290