1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/Tools/PDLL/AST/Context.h" 13 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 14 #include "mlir/Tools/PDLL/AST/Nodes.h" 15 #include "mlir/Tools/PDLL/AST/Types.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/SaveAndRestore.h" 20 #include "llvm/Support/ScopedPrinter.h" 21 #include <string> 22 23 using namespace mlir; 24 using namespace mlir::pdll; 25 26 //===----------------------------------------------------------------------===// 27 // Parser 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 class Parser { 32 public: 33 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 34 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 35 curToken(lexer.lexToken()), curDeclScope(nullptr), 36 valueTy(ast::ValueType::get(ctx)), 37 valueRangeTy(ast::ValueRangeType::get(ctx)), 38 typeTy(ast::TypeType::get(ctx)), 39 typeRangeTy(ast::TypeRangeType::get(ctx)) {} 40 41 /// Try to parse a new module. Returns nullptr in the case of failure. 42 FailureOr<ast::Module *> parseModule(); 43 44 private: 45 /// The current context of the parser. It allows for the parser to know a bit 46 /// about the construct it is nested within during parsing. This is used 47 /// specifically to provide additional verification during parsing, e.g. to 48 /// prevent using rewrites within a match context, matcher constraints within 49 /// a rewrite section, etc. 50 enum class ParserContext { 51 /// The parser is in the global context. 52 Global, 53 /// The parser is currently within a Constraint, which disallows all types 54 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 55 Constraint, 56 /// The parser is currently within the matcher portion of a Pattern, which 57 /// is allows a terminal operation rewrite statement but no other rewrite 58 /// transformations. 59 PatternMatch, 60 /// The parser is currently within a Rewrite, which disallows calls to 61 /// constraints, requires operation expressions to have names, etc. 62 Rewrite, 63 }; 64 65 //===--------------------------------------------------------------------===// 66 // Parsing 67 //===--------------------------------------------------------------------===// 68 69 /// Push a new decl scope onto the lexer. 70 ast::DeclScope *pushDeclScope() { 71 ast::DeclScope *newScope = 72 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 73 return (curDeclScope = newScope); 74 } 75 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 76 77 /// Pop the last decl scope from the lexer. 78 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 79 80 /// Parse the body of an AST module. 81 LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); 82 83 /// Try to convert the given expression to `type`. Returns failure and emits 84 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 85 /// invoked to attach notes to the emitted error diagnostic. On success, 86 /// `expr` is updated to the expression used to convert to `type`. 87 LogicalResult convertExpressionTo( 88 ast::Expr *&expr, ast::Type type, 89 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 90 91 /// Given an operation expression, convert it to a Value or ValueRange 92 /// typed expression. 93 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 94 95 //===--------------------------------------------------------------------===// 96 // Directives 97 98 LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); 99 LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); 100 101 //===--------------------------------------------------------------------===// 102 // Decls 103 104 /// This structure contains the set of pattern metadata that may be parsed. 105 struct ParsedPatternMetadata { 106 Optional<uint16_t> benefit; 107 bool hasBoundedRecursion = false; 108 }; 109 110 FailureOr<ast::Decl *> parseTopLevelDecl(); 111 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 112 113 /// Parse an argument variable as part of the signature of a 114 /// UserConstraintDecl or UserRewriteDecl. 115 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 116 117 /// Parse a result variable as part of the signature of a UserConstraintDecl 118 /// or UserRewriteDecl. 119 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 120 121 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 122 /// defined in a non-global context. 123 FailureOr<ast::UserConstraintDecl *> 124 parseUserConstraintDecl(bool isInline = false); 125 126 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 127 /// non-global context, such as within a Pattern/Constraint/etc. 128 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 129 130 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 131 /// PDLL constructs. 132 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 133 const ast::Name &name, bool isInline, 134 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 135 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 136 137 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 138 /// defined in a non-global context. 139 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 140 141 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 142 /// non-global context, such as within a Pattern/Rewrite/etc. 143 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 144 145 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 146 /// PDLL constructs. 147 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 148 const ast::Name &name, bool isInline, 149 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 150 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 151 152 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 153 /// effectively the same syntax, and only differ on slight semantics (given 154 /// the different parsing contexts). 155 template <typename T, typename ParseUserPDLLDeclFnT> 156 FailureOr<T *> parseUserConstraintOrRewriteDecl( 157 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 158 StringRef anonymousNamePrefix, bool isInline); 159 160 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 161 /// These decls have effectively the same syntax. 162 template <typename T> 163 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 164 const ast::Name &name, bool isInline, 165 ArrayRef<ast::VariableDecl *> arguments, 166 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 167 168 /// Parse the functional signature (i.e. the arguments and results) of a 169 /// UserConstraintDecl or UserRewriteDecl. 170 LogicalResult parseUserConstraintOrRewriteSignature( 171 SmallVectorImpl<ast::VariableDecl *> &arguments, 172 SmallVectorImpl<ast::VariableDecl *> &results, 173 ast::DeclScope *&argumentScope, ast::Type &resultType); 174 175 /// Validate the return (which if present is specified by bodyIt) of a 176 /// UserConstraintDecl or UserRewriteDecl. 177 LogicalResult validateUserConstraintOrRewriteReturn( 178 StringRef declType, ast::CompoundStmt *body, 179 ArrayRef<ast::Stmt *>::iterator bodyIt, 180 ArrayRef<ast::Stmt *>::iterator bodyE, 181 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 182 183 FailureOr<ast::CompoundStmt *> 184 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 185 bool expectTerminalSemicolon = true); 186 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 187 FailureOr<ast::Decl *> parsePatternDecl(); 188 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 189 190 /// Check to see if a decl has already been defined with the given name, if 191 /// one has emit and error and return failure. Returns success otherwise. 192 LogicalResult checkDefineNamedDecl(const ast::Name &name); 193 194 /// Try to define a variable decl with the given components, returns the 195 /// variable on success. 196 FailureOr<ast::VariableDecl *> 197 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 198 ast::Expr *initExpr, 199 ArrayRef<ast::ConstraintRef> constraints); 200 FailureOr<ast::VariableDecl *> 201 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 202 ArrayRef<ast::ConstraintRef> constraints); 203 204 /// Parse the constraint reference list for a variable decl. 205 LogicalResult parseVariableDeclConstraintList( 206 SmallVectorImpl<ast::ConstraintRef> &constraints); 207 208 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 209 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 210 211 /// Try to parse a single reference to a constraint. `typeConstraint` is the 212 /// location of a previously parsed type constraint for the entity that will 213 /// be constrained by the parsed constraint. `existingConstraints` are any 214 /// existing constraints that have already been parsed for the same entity 215 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 216 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 217 FailureOr<ast::ConstraintRef> 218 parseConstraint(Optional<SMRange> &typeConstraint, 219 ArrayRef<ast::ConstraintRef> existingConstraints, 220 bool allowInlineTypeConstraints); 221 222 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 223 /// argument or result variable. The constraints for these variables do not 224 /// allow inline type constraints, and only permit a single constraint. 225 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 226 227 //===--------------------------------------------------------------------===// 228 // Exprs 229 230 FailureOr<ast::Expr *> parseExpr(); 231 232 /// Identifier expressions. 233 FailureOr<ast::Expr *> parseAttributeExpr(); 234 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 235 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 236 FailureOr<ast::Expr *> parseIdentifierExpr(); 237 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 238 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 239 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 240 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 241 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 242 FailureOr<ast::Expr *> parseOperationExpr(); 243 FailureOr<ast::Expr *> parseTupleExpr(); 244 FailureOr<ast::Expr *> parseTypeExpr(); 245 FailureOr<ast::Expr *> parseUnderscoreExpr(); 246 247 //===--------------------------------------------------------------------===// 248 // Stmts 249 250 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 251 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 252 FailureOr<ast::EraseStmt *> parseEraseStmt(); 253 FailureOr<ast::LetStmt *> parseLetStmt(); 254 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 255 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 256 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 257 258 //===--------------------------------------------------------------------===// 259 // Creation+Analysis 260 //===--------------------------------------------------------------------===// 261 262 //===--------------------------------------------------------------------===// 263 // Decls 264 265 /// Try to extract a callable from the given AST node. Returns nullptr on 266 /// failure. 267 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 268 269 /// Try to create a pattern decl with the given components, returning the 270 /// Pattern on success. 271 FailureOr<ast::PatternDecl *> 272 createPatternDecl(SMRange loc, const ast::Name *name, 273 const ParsedPatternMetadata &metadata, 274 ast::CompoundStmt *body); 275 276 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 277 /// of results, defined as part of the signature. 278 ast::Type 279 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 280 281 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 282 template <typename T> 283 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 284 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 285 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 286 ast::CompoundStmt *body); 287 288 /// Try to create a variable decl with the given components, returning the 289 /// Variable on success. 290 FailureOr<ast::VariableDecl *> 291 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 292 ArrayRef<ast::ConstraintRef> constraints); 293 294 /// Create a variable for an argument or result defined as part of the 295 /// signature of a UserConstraintDecl/UserRewriteDecl. 296 FailureOr<ast::VariableDecl *> 297 createArgOrResultVariableDecl(StringRef name, SMRange loc, 298 const ast::ConstraintRef &constraint); 299 300 /// Validate the constraints used to constraint a variable decl. 301 /// `inferredType` is the type of the variable inferred by the constraints 302 /// within the list, and is updated to the most refined type as determined by 303 /// the constraints. Returns success if the constraint list is valid, failure 304 /// otherwise. 305 LogicalResult 306 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 307 ast::Type &inferredType); 308 /// Validate a single reference to a constraint. `inferredType` contains the 309 /// currently inferred variabled type and is refined within the type defined 310 /// by the constraint. Returns success if the constraint is valid, failure 311 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 312 /// defined constraints) may be used with the variable. 313 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 314 ast::Type &inferredType, 315 bool allowNonCoreConstraints = true); 316 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 317 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 318 319 //===--------------------------------------------------------------------===// 320 // Exprs 321 322 FailureOr<ast::CallExpr *> 323 createCallExpr(SMRange loc, ast::Expr *parentExpr, 324 MutableArrayRef<ast::Expr *> arguments); 325 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 326 FailureOr<ast::DeclRefExpr *> 327 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 328 ArrayRef<ast::ConstraintRef> constraints); 329 FailureOr<ast::MemberAccessExpr *> 330 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 331 332 /// Validate the member access `name` into the given parent expression. On 333 /// success, this also returns the type of the member accessed. 334 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 335 StringRef name, SMRange loc); 336 FailureOr<ast::OperationExpr *> 337 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 338 MutableArrayRef<ast::Expr *> operands, 339 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 340 MutableArrayRef<ast::Expr *> results); 341 LogicalResult 342 validateOperationOperands(SMRange loc, Optional<StringRef> name, 343 MutableArrayRef<ast::Expr *> operands); 344 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 345 MutableArrayRef<ast::Expr *> results); 346 LogicalResult 347 validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name, 348 MutableArrayRef<ast::Expr *> values, 349 ast::Type singleTy, ast::Type rangeTy); 350 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 351 ArrayRef<ast::Expr *> elements, 352 ArrayRef<StringRef> elementNames); 353 354 //===--------------------------------------------------------------------===// 355 // Stmts 356 357 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 358 FailureOr<ast::ReplaceStmt *> 359 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 360 MutableArrayRef<ast::Expr *> replValues); 361 FailureOr<ast::RewriteStmt *> 362 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 363 ast::CompoundStmt *rewriteBody); 364 365 //===--------------------------------------------------------------------===// 366 // Lexer Utilities 367 //===--------------------------------------------------------------------===// 368 369 /// If the current token has the specified kind, consume it and return true. 370 /// If not, return false. 371 bool consumeIf(Token::Kind kind) { 372 if (curToken.isNot(kind)) 373 return false; 374 consumeToken(kind); 375 return true; 376 } 377 378 /// Advance the current lexer onto the next token. 379 void consumeToken() { 380 assert(curToken.isNot(Token::eof, Token::error) && 381 "shouldn't advance past EOF or errors"); 382 curToken = lexer.lexToken(); 383 } 384 385 /// Advance the current lexer onto the next token, asserting what the expected 386 /// current token is. This is preferred to the above method because it leads 387 /// to more self-documenting code with better checking. 388 void consumeToken(Token::Kind kind) { 389 assert(curToken.is(kind) && "consumed an unexpected token"); 390 consumeToken(); 391 } 392 393 /// Reset the lexer to the location at the given position. 394 void resetToken(SMRange tokLoc) { 395 lexer.resetPointer(tokLoc.Start.getPointer()); 396 curToken = lexer.lexToken(); 397 } 398 399 /// Consume the specified token if present and return success. On failure, 400 /// output a diagnostic and return failure. 401 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 402 if (curToken.getKind() != kind) 403 return emitError(curToken.getLoc(), msg); 404 consumeToken(); 405 return success(); 406 } 407 LogicalResult emitError(SMRange loc, const Twine &msg) { 408 lexer.emitError(loc, msg); 409 return failure(); 410 } 411 LogicalResult emitError(const Twine &msg) { 412 return emitError(curToken.getLoc(), msg); 413 } 414 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 415 const Twine ¬e) { 416 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 417 return failure(); 418 } 419 420 //===--------------------------------------------------------------------===// 421 // Fields 422 //===--------------------------------------------------------------------===// 423 424 /// The owning AST context. 425 ast::Context &ctx; 426 427 /// The lexer of this parser. 428 Lexer lexer; 429 430 /// The current token within the lexer. 431 Token curToken; 432 433 /// The most recently defined decl scope. 434 ast::DeclScope *curDeclScope; 435 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 436 437 /// The current context of the parser. 438 ParserContext parserContext = ParserContext::Global; 439 440 /// Cached types to simplify verification and expression creation. 441 ast::Type valueTy, valueRangeTy; 442 ast::Type typeTy, typeRangeTy; 443 444 /// A counter used when naming anonymous constraints and rewrites. 445 unsigned anonymousDeclNameCounter = 0; 446 }; 447 } // namespace 448 449 FailureOr<ast::Module *> Parser::parseModule() { 450 SMLoc moduleLoc = curToken.getStartLoc(); 451 pushDeclScope(); 452 453 // Parse the top-level decls of the module. 454 SmallVector<ast::Decl *> decls; 455 if (failed(parseModuleBody(decls))) 456 return popDeclScope(), failure(); 457 458 popDeclScope(); 459 return ast::Module::create(ctx, moduleLoc, decls); 460 } 461 462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { 463 while (curToken.isNot(Token::eof)) { 464 if (curToken.is(Token::directive)) { 465 if (failed(parseDirective(decls))) 466 return failure(); 467 continue; 468 } 469 470 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 471 if (failed(decl)) 472 return failure(); 473 decls.push_back(*decl); 474 } 475 return success(); 476 } 477 478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 479 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 480 valueRangeTy); 481 } 482 483 LogicalResult Parser::convertExpressionTo( 484 ast::Expr *&expr, ast::Type type, 485 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 486 ast::Type exprType = expr->getType(); 487 if (exprType == type) 488 return success(); 489 490 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 491 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 492 expr->getLoc(), llvm::formatv("unable to convert expression of type " 493 "`{0}` to the expected type of " 494 "`{1}`", 495 exprType, type)); 496 if (noteAttachFn) 497 noteAttachFn(*diag); 498 return diag; 499 }; 500 501 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 502 // Two operation types are compatible if they have the same name, or if the 503 // expected type is more general. 504 if (auto opType = type.dyn_cast<ast::OperationType>()) { 505 if (opType.getName()) 506 return emitConvertError(); 507 return success(); 508 } 509 510 // An operation can always convert to a ValueRange. 511 if (type == valueRangeTy) { 512 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 513 valueRangeTy); 514 return success(); 515 } 516 517 // Allow conversion to a single value by constraining the result range. 518 if (type == valueTy) { 519 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 520 valueTy); 521 return success(); 522 } 523 return emitConvertError(); 524 } 525 526 // FIXME: Decide how to allow/support converting a single result to multiple, 527 // and multiple to a single result. For now, we just allow Single->Range, 528 // but this isn't something really supported in the PDL dialect. We should 529 // figure out some way to support both. 530 if ((exprType == valueTy || exprType == valueRangeTy) && 531 (type == valueTy || type == valueRangeTy)) 532 return success(); 533 if ((exprType == typeTy || exprType == typeRangeTy) && 534 (type == typeTy || type == typeRangeTy)) 535 return success(); 536 537 // Handle tuple types. 538 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 539 auto tupleType = type.dyn_cast<ast::TupleType>(); 540 if (!tupleType || tupleType.size() != exprTupleType.size()) 541 return emitConvertError(); 542 543 // Build a new tuple expression using each of the elements of the current 544 // tuple. 545 SmallVector<ast::Expr *> newExprs; 546 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 547 newExprs.push_back(ast::MemberAccessExpr::create( 548 ctx, expr->getLoc(), expr, llvm::to_string(i), 549 exprTupleType.getElementTypes()[i])); 550 551 auto diagFn = [&](ast::Diagnostic &diag) { 552 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 553 i, exprTupleType)); 554 if (noteAttachFn) 555 noteAttachFn(diag); 556 }; 557 if (failed(convertExpressionTo(newExprs.back(), 558 tupleType.getElementTypes()[i], diagFn))) 559 return failure(); 560 } 561 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 562 tupleType.getElementNames()); 563 return success(); 564 } 565 566 return emitConvertError(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // Directives 571 572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { 573 StringRef directive = curToken.getSpelling(); 574 if (directive == "#include") 575 return parseInclude(decls); 576 577 return emitError("unknown directive `" + directive + "`"); 578 } 579 580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { 581 SMRange loc = curToken.getLoc(); 582 consumeToken(Token::directive); 583 584 // Parse the file being included. 585 if (!curToken.isString()) 586 return emitError(loc, 587 "expected string file name after `include` directive"); 588 SMRange fileLoc = curToken.getLoc(); 589 std::string filenameStr = curToken.getStringValue(); 590 StringRef filename = filenameStr; 591 consumeToken(); 592 593 // Check the type of include. If ending with `.pdll`, this is another pdl file 594 // to be parsed along with the current module. 595 if (filename.endswith(".pdll")) { 596 if (failed(lexer.pushInclude(filename))) 597 return emitError(fileLoc, 598 "unable to open include file `" + filename + "`"); 599 600 // If we added the include successfully, parse it into the current module. 601 // Make sure to save the current token so that we can restore it when we 602 // finish parsing the nested file. 603 Token oldToken = curToken; 604 curToken = lexer.lexToken(); 605 LogicalResult result = parseModuleBody(decls); 606 curToken = oldToken; 607 return result; 608 } 609 610 return emitError(fileLoc, "expected include filename to end with `.pdll`"); 611 } 612 613 //===----------------------------------------------------------------------===// 614 // Decls 615 616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 617 FailureOr<ast::Decl *> decl; 618 switch (curToken.getKind()) { 619 case Token::kw_Constraint: 620 decl = parseUserConstraintDecl(); 621 break; 622 case Token::kw_Pattern: 623 decl = parsePatternDecl(); 624 break; 625 case Token::kw_Rewrite: 626 decl = parseUserRewriteDecl(); 627 break; 628 default: 629 return emitError("expected top-level declaration, such as a `Pattern`"); 630 } 631 if (failed(decl)) 632 return failure(); 633 634 // If the decl has a name, add it to the current scope. 635 if (const ast::Name *name = (*decl)->getName()) { 636 if (failed(checkDefineNamedDecl(*name))) 637 return failure(); 638 curDeclScope->add(*decl); 639 } 640 return decl; 641 } 642 643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 644 std::string attrNameStr; 645 if (curToken.isString()) 646 attrNameStr = curToken.getStringValue(); 647 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 648 attrNameStr = curToken.getSpelling().str(); 649 else 650 return emitError("expected identifier or string attribute name"); 651 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 652 consumeToken(); 653 654 // Check for a value of the attribute. 655 ast::Expr *attrValue = nullptr; 656 if (consumeIf(Token::equal)) { 657 FailureOr<ast::Expr *> attrExpr = parseExpr(); 658 if (failed(attrExpr)) 659 return failure(); 660 attrValue = *attrExpr; 661 } else { 662 // If there isn't a concrete value, create an expression representing a 663 // UnitAttr. 664 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 665 } 666 667 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 668 } 669 670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 671 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 672 bool expectTerminalSemicolon) { 673 consumeToken(Token::equal_arrow); 674 675 // Parse the single statement of the lambda body. 676 SMLoc bodyStartLoc = curToken.getStartLoc(); 677 pushDeclScope(); 678 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 679 bool failedToParse = 680 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 681 popDeclScope(); 682 if (failedToParse) 683 return failure(); 684 685 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 686 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 687 } 688 689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 690 // Ensure that the argument is named. 691 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 692 return emitError("expected identifier argument name"); 693 694 // Parse the argument similarly to a normal variable. 695 StringRef name = curToken.getSpelling(); 696 SMRange nameLoc = curToken.getLoc(); 697 consumeToken(); 698 699 if (failed( 700 parseToken(Token::colon, "expected `:` before argument constraint"))) 701 return failure(); 702 703 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 704 if (failed(cst)) 705 return failure(); 706 707 return createArgOrResultVariableDecl(name, nameLoc, *cst); 708 } 709 710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 711 // Check to see if this result is named. 712 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 713 // Check to see if this name actually refers to a Constraint. 714 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 715 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 716 // If yes, and this is a Rewrite, give a nice error message as non-Core 717 // constraints are not supported on Rewrite results. 718 if (parserContext == ParserContext::Rewrite) { 719 return emitError( 720 "`Rewrite` results are only permitted to use core constraints, " 721 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 722 } 723 724 // Otherwise, parse this as an unnamed result variable. 725 } else { 726 // If it wasn't a constraint, parse the result similarly to a variable. If 727 // there is already an existing decl, we will emit an error when defining 728 // this variable later. 729 StringRef name = curToken.getSpelling(); 730 SMRange nameLoc = curToken.getLoc(); 731 consumeToken(); 732 733 if (failed(parseToken(Token::colon, 734 "expected `:` before result constraint"))) 735 return failure(); 736 737 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 738 if (failed(cst)) 739 return failure(); 740 741 return createArgOrResultVariableDecl(name, nameLoc, *cst); 742 } 743 } 744 745 // If it isn't named, we parse the constraint directly and create an unnamed 746 // result variable. 747 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 748 if (failed(cst)) 749 return failure(); 750 751 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 752 } 753 754 FailureOr<ast::UserConstraintDecl *> 755 Parser::parseUserConstraintDecl(bool isInline) { 756 // Constraints and rewrites have very similar formats, dispatch to a shared 757 // interface for parsing. 758 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 759 [&](auto &&...args) { 760 return this->parseUserPDLLConstraintDecl(args...); 761 }, 762 ParserContext::Constraint, "constraint", isInline); 763 } 764 765 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 766 FailureOr<ast::UserConstraintDecl *> decl = 767 parseUserConstraintDecl(/*isInline=*/true); 768 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 769 return failure(); 770 771 curDeclScope->add(*decl); 772 return decl; 773 } 774 775 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 776 const ast::Name &name, bool isInline, 777 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 778 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 779 // Push the argument scope back onto the list, so that the body can 780 // reference arguments. 781 pushDeclScope(argumentScope); 782 783 // Parse the body of the constraint. The body is either defined as a compound 784 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 785 ast::CompoundStmt *body; 786 if (curToken.is(Token::equal_arrow)) { 787 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 788 [&](ast::Stmt *&stmt) -> LogicalResult { 789 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 790 if (!stmtExpr) { 791 return emitError(stmt->getLoc(), 792 "expected `Constraint` lambda body to contain a " 793 "single expression"); 794 } 795 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 796 return success(); 797 }, 798 /*expectTerminalSemicolon=*/!isInline); 799 if (failed(bodyResult)) 800 return failure(); 801 body = *bodyResult; 802 } else { 803 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 804 if (failed(bodyResult)) 805 return failure(); 806 body = *bodyResult; 807 808 // Verify the structure of the body. 809 auto bodyIt = body->begin(), bodyE = body->end(); 810 for (; bodyIt != bodyE; ++bodyIt) 811 if (isa<ast::ReturnStmt>(*bodyIt)) 812 break; 813 if (failed(validateUserConstraintOrRewriteReturn( 814 "Constraint", body, bodyIt, bodyE, results, resultType))) 815 return failure(); 816 } 817 popDeclScope(); 818 819 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 820 name, arguments, results, resultType, body); 821 } 822 823 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 824 // Constraints and rewrites have very similar formats, dispatch to a shared 825 // interface for parsing. 826 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 827 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 828 ParserContext::Rewrite, "rewrite", isInline); 829 } 830 831 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 832 FailureOr<ast::UserRewriteDecl *> decl = 833 parseUserRewriteDecl(/*isInline=*/true); 834 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 835 return failure(); 836 837 curDeclScope->add(*decl); 838 return decl; 839 } 840 841 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 842 const ast::Name &name, bool isInline, 843 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 844 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 845 // Push the argument scope back onto the list, so that the body can 846 // reference arguments. 847 curDeclScope = argumentScope; 848 ast::CompoundStmt *body; 849 if (curToken.is(Token::equal_arrow)) { 850 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 851 [&](ast::Stmt *&statement) -> LogicalResult { 852 if (isa<ast::OpRewriteStmt>(statement)) 853 return success(); 854 855 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 856 if (!statementExpr) { 857 return emitError( 858 statement->getLoc(), 859 "expected `Rewrite` lambda body to contain a single expression " 860 "or an operation rewrite statement; such as `erase`, " 861 "`replace`, or `rewrite`"); 862 } 863 statement = 864 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 865 return success(); 866 }, 867 /*expectTerminalSemicolon=*/!isInline); 868 if (failed(bodyResult)) 869 return failure(); 870 body = *bodyResult; 871 } else { 872 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 873 if (failed(bodyResult)) 874 return failure(); 875 body = *bodyResult; 876 } 877 popDeclScope(); 878 879 // Verify the structure of the body. 880 auto bodyIt = body->begin(), bodyE = body->end(); 881 for (; bodyIt != bodyE; ++bodyIt) 882 if (isa<ast::ReturnStmt>(*bodyIt)) 883 break; 884 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 885 bodyE, results, resultType))) 886 return failure(); 887 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 888 name, arguments, results, resultType, body); 889 } 890 891 template <typename T, typename ParseUserPDLLDeclFnT> 892 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 893 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 894 StringRef anonymousNamePrefix, bool isInline) { 895 SMRange loc = curToken.getLoc(); 896 consumeToken(); 897 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 898 899 // Parse the name of the decl. 900 const ast::Name *name = nullptr; 901 if (curToken.isNot(Token::identifier)) { 902 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 903 // in C++, so being unnamed is fine. 904 if (!isInline) 905 return emitError("expected identifier name"); 906 907 // Create a unique anonymous name to use, as the name for this decl is not 908 // important. 909 std::string anonName = 910 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 911 anonymousDeclNameCounter++) 912 .str(); 913 name = &ast::Name::create(ctx, anonName, loc); 914 } else { 915 // If a name was provided, we can use it directly. 916 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 917 consumeToken(Token::identifier); 918 } 919 920 // Parse the functional signature of the decl. 921 SmallVector<ast::VariableDecl *> arguments, results; 922 ast::DeclScope *argumentScope; 923 ast::Type resultType; 924 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 925 argumentScope, resultType))) 926 return failure(); 927 928 // Check to see which type of constraint this is. If the constraint contains a 929 // compound body, this is a PDLL decl. 930 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 931 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 932 resultType); 933 934 // Otherwise, this is a native decl. 935 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 936 results, resultType); 937 } 938 939 template <typename T> 940 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 941 const ast::Name &name, bool isInline, 942 ArrayRef<ast::VariableDecl *> arguments, 943 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 944 // If followed by a string, the native code body has also been specified. 945 std::string codeStrStorage; 946 Optional<StringRef> optCodeStr; 947 if (curToken.isString()) { 948 codeStrStorage = curToken.getStringValue(); 949 optCodeStr = codeStrStorage; 950 consumeToken(); 951 } else if (isInline) { 952 return emitError(name.getLoc(), 953 "external declarations must be declared in global scope"); 954 } 955 if (failed(parseToken(Token::semicolon, 956 "expected `;` after native declaration"))) 957 return failure(); 958 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 959 } 960 961 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 962 SmallVectorImpl<ast::VariableDecl *> &arguments, 963 SmallVectorImpl<ast::VariableDecl *> &results, 964 ast::DeclScope *&argumentScope, ast::Type &resultType) { 965 // Parse the argument list of the decl. 966 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 967 return failure(); 968 969 argumentScope = pushDeclScope(); 970 if (curToken.isNot(Token::r_paren)) { 971 do { 972 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 973 if (failed(argument)) 974 return failure(); 975 arguments.emplace_back(*argument); 976 } while (consumeIf(Token::comma)); 977 } 978 popDeclScope(); 979 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 980 return failure(); 981 982 // Parse the results of the decl. 983 pushDeclScope(); 984 if (consumeIf(Token::arrow)) { 985 auto parseResultFn = [&]() -> LogicalResult { 986 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 987 if (failed(result)) 988 return failure(); 989 results.emplace_back(*result); 990 return success(); 991 }; 992 993 // Check for a list of results. 994 if (consumeIf(Token::l_paren)) { 995 do { 996 if (failed(parseResultFn())) 997 return failure(); 998 } while (consumeIf(Token::comma)); 999 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1000 return failure(); 1001 1002 // Otherwise, there is only one result. 1003 } else if (failed(parseResultFn())) { 1004 return failure(); 1005 } 1006 } 1007 popDeclScope(); 1008 1009 // Compute the result type of the decl. 1010 resultType = createUserConstraintRewriteResultType(results); 1011 1012 // Verify that results are only named if there are more than one. 1013 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1014 return emitError( 1015 results.front()->getLoc(), 1016 "cannot create a single-element tuple with an element label"); 1017 } 1018 return success(); 1019 } 1020 1021 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1022 StringRef declType, ast::CompoundStmt *body, 1023 ArrayRef<ast::Stmt *>::iterator bodyIt, 1024 ArrayRef<ast::Stmt *>::iterator bodyE, 1025 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1026 // Handle if a `return` was provided. 1027 if (bodyIt != bodyE) { 1028 // Emit an error if we have trailing statements after the return. 1029 if (std::next(bodyIt) != bodyE) { 1030 return emitError( 1031 (*std::next(bodyIt))->getLoc(), 1032 llvm::formatv("`return` terminated the `{0}` body, but found " 1033 "trailing statements afterwards", 1034 declType)); 1035 } 1036 1037 // Otherwise if a return wasn't provided, check that no results are 1038 // expected. 1039 } else if (!results.empty()) { 1040 return emitError( 1041 {body->getLoc().End, body->getLoc().End}, 1042 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1043 declType, resultType)); 1044 } 1045 return success(); 1046 } 1047 1048 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1049 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1050 if (isa<ast::OpRewriteStmt>(statement)) 1051 return success(); 1052 return emitError( 1053 statement->getLoc(), 1054 "expected Pattern lambda body to contain a single operation " 1055 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1056 }); 1057 } 1058 1059 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1060 SMRange loc = curToken.getLoc(); 1061 consumeToken(Token::kw_Pattern); 1062 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1063 ParserContext::PatternMatch); 1064 1065 // Check for an optional identifier for the pattern name. 1066 const ast::Name *name = nullptr; 1067 if (curToken.is(Token::identifier)) { 1068 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1069 consumeToken(Token::identifier); 1070 } 1071 1072 // Parse any pattern metadata. 1073 ParsedPatternMetadata metadata; 1074 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1075 return failure(); 1076 1077 // Parse the pattern body. 1078 ast::CompoundStmt *body; 1079 1080 // Handle a lambda body. 1081 if (curToken.is(Token::equal_arrow)) { 1082 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1083 if (failed(bodyResult)) 1084 return failure(); 1085 body = *bodyResult; 1086 } else { 1087 if (curToken.isNot(Token::l_brace)) 1088 return emitError("expected `{` or `=>` to start pattern body"); 1089 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1090 if (failed(bodyResult)) 1091 return failure(); 1092 body = *bodyResult; 1093 1094 // Verify the body of the pattern. 1095 auto bodyIt = body->begin(), bodyE = body->end(); 1096 for (; bodyIt != bodyE; ++bodyIt) { 1097 if (isa<ast::ReturnStmt>(*bodyIt)) { 1098 return emitError((*bodyIt)->getLoc(), 1099 "`return` statements are only permitted within a " 1100 "`Constraint` or `Rewrite` body"); 1101 } 1102 // Break when we've found the rewrite statement. 1103 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1104 break; 1105 } 1106 if (bodyIt == bodyE) { 1107 return emitError(loc, 1108 "expected Pattern body to terminate with an operation " 1109 "rewrite statement, such as `erase`"); 1110 } 1111 if (std::next(bodyIt) != bodyE) { 1112 return emitError((*std::next(bodyIt))->getLoc(), 1113 "Pattern body was terminated by an operation " 1114 "rewrite statement, but found trailing statements"); 1115 } 1116 } 1117 1118 return createPatternDecl(loc, name, metadata, body); 1119 } 1120 1121 LogicalResult 1122 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1123 Optional<SMRange> benefitLoc; 1124 Optional<SMRange> hasBoundedRecursionLoc; 1125 1126 do { 1127 if (curToken.isNot(Token::identifier)) 1128 return emitError("expected pattern metadata identifier"); 1129 StringRef metadataStr = curToken.getSpelling(); 1130 SMRange metadataLoc = curToken.getLoc(); 1131 consumeToken(Token::identifier); 1132 1133 // Parse the benefit metadata: benefit(<integer-value>) 1134 if (metadataStr == "benefit") { 1135 if (benefitLoc) { 1136 return emitErrorAndNote(metadataLoc, 1137 "pattern benefit has already been specified", 1138 *benefitLoc, "see previous definition here"); 1139 } 1140 if (failed(parseToken(Token::l_paren, 1141 "expected `(` before pattern benefit"))) 1142 return failure(); 1143 1144 uint16_t benefitValue = 0; 1145 if (curToken.isNot(Token::integer)) 1146 return emitError("expected integral pattern benefit"); 1147 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1148 return emitError( 1149 "expected pattern benefit to fit within a 16-bit integer"); 1150 consumeToken(Token::integer); 1151 1152 metadata.benefit = benefitValue; 1153 benefitLoc = metadataLoc; 1154 1155 if (failed( 1156 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1157 return failure(); 1158 continue; 1159 } 1160 1161 // Parse the bounded recursion metadata: recursion 1162 if (metadataStr == "recursion") { 1163 if (hasBoundedRecursionLoc) { 1164 return emitErrorAndNote( 1165 metadataLoc, 1166 "pattern recursion metadata has already been specified", 1167 *hasBoundedRecursionLoc, "see previous definition here"); 1168 } 1169 metadata.hasBoundedRecursion = true; 1170 hasBoundedRecursionLoc = metadataLoc; 1171 continue; 1172 } 1173 1174 return emitError(metadataLoc, "unknown pattern metadata"); 1175 } while (consumeIf(Token::comma)); 1176 1177 return success(); 1178 } 1179 1180 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1181 consumeToken(Token::less); 1182 1183 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1184 if (failed(typeExpr) || 1185 failed(parseToken(Token::greater, 1186 "expected `>` after variable type constraint"))) 1187 return failure(); 1188 return typeExpr; 1189 } 1190 1191 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1192 assert(curDeclScope && "defining decl outside of a decl scope"); 1193 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1194 return emitErrorAndNote( 1195 name.getLoc(), "`" + name.getName() + "` has already been defined", 1196 lastDecl->getName()->getLoc(), "see previous definition here"); 1197 } 1198 return success(); 1199 } 1200 1201 FailureOr<ast::VariableDecl *> 1202 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1203 ast::Expr *initExpr, 1204 ArrayRef<ast::ConstraintRef> constraints) { 1205 assert(curDeclScope && "defining variable outside of decl scope"); 1206 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1207 1208 // If the name of the variable indicates a special variable, we don't add it 1209 // to the scope. This variable is local to the definition point. 1210 if (name.empty() || name == "_") { 1211 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1212 constraints); 1213 } 1214 if (failed(checkDefineNamedDecl(nameDecl))) 1215 return failure(); 1216 1217 auto *varDecl = 1218 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1219 curDeclScope->add(varDecl); 1220 return varDecl; 1221 } 1222 1223 FailureOr<ast::VariableDecl *> 1224 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1225 ArrayRef<ast::ConstraintRef> constraints) { 1226 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1227 constraints); 1228 } 1229 1230 LogicalResult Parser::parseVariableDeclConstraintList( 1231 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1232 Optional<SMRange> typeConstraint; 1233 auto parseSingleConstraint = [&] { 1234 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1235 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); 1236 if (failed(constraint)) 1237 return failure(); 1238 constraints.push_back(*constraint); 1239 return success(); 1240 }; 1241 1242 // Check to see if this is a single constraint, or a list. 1243 if (!consumeIf(Token::l_square)) 1244 return parseSingleConstraint(); 1245 1246 do { 1247 if (failed(parseSingleConstraint())) 1248 return failure(); 1249 } while (consumeIf(Token::comma)); 1250 return parseToken(Token::r_square, "expected `]` after constraint list"); 1251 } 1252 1253 FailureOr<ast::ConstraintRef> 1254 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1255 ArrayRef<ast::ConstraintRef> existingConstraints, 1256 bool allowInlineTypeConstraints) { 1257 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1258 if (!allowInlineTypeConstraints) { 1259 return emitError( 1260 curToken.getLoc(), 1261 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1262 "permitted on arguments or results"); 1263 } 1264 if (typeConstraint) 1265 return emitErrorAndNote( 1266 curToken.getLoc(), 1267 "the type of this variable has already been constrained", 1268 *typeConstraint, "see previous constraint location here"); 1269 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1270 if (failed(constraintExpr)) 1271 return failure(); 1272 typeExpr = *constraintExpr; 1273 typeConstraint = typeExpr->getLoc(); 1274 return success(); 1275 }; 1276 1277 SMRange loc = curToken.getLoc(); 1278 switch (curToken.getKind()) { 1279 case Token::kw_Attr: { 1280 consumeToken(Token::kw_Attr); 1281 1282 // Check for a type constraint. 1283 ast::Expr *typeExpr = nullptr; 1284 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1285 return failure(); 1286 return ast::ConstraintRef( 1287 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1288 } 1289 case Token::kw_Op: { 1290 consumeToken(Token::kw_Op); 1291 1292 // Parse an optional operation name. If the name isn't provided, this refers 1293 // to "any" operation. 1294 FailureOr<ast::OpNameDecl *> opName = 1295 parseWrappedOperationName(/*allowEmptyName=*/true); 1296 if (failed(opName)) 1297 return failure(); 1298 1299 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1300 loc); 1301 } 1302 case Token::kw_Type: 1303 consumeToken(Token::kw_Type); 1304 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1305 case Token::kw_TypeRange: 1306 consumeToken(Token::kw_TypeRange); 1307 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1308 loc); 1309 case Token::kw_Value: { 1310 consumeToken(Token::kw_Value); 1311 1312 // Check for a type constraint. 1313 ast::Expr *typeExpr = nullptr; 1314 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1315 return failure(); 1316 1317 return ast::ConstraintRef( 1318 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1319 } 1320 case Token::kw_ValueRange: { 1321 consumeToken(Token::kw_ValueRange); 1322 1323 // Check for a type constraint. 1324 ast::Expr *typeExpr = nullptr; 1325 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1326 return failure(); 1327 1328 return ast::ConstraintRef( 1329 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1330 } 1331 1332 case Token::kw_Constraint: { 1333 // Handle an inline constraint. 1334 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1335 if (failed(decl)) 1336 return failure(); 1337 return ast::ConstraintRef(*decl, loc); 1338 } 1339 case Token::identifier: { 1340 StringRef constraintName = curToken.getSpelling(); 1341 consumeToken(Token::identifier); 1342 1343 // Lookup the referenced constraint. 1344 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1345 if (!cstDecl) { 1346 return emitError(loc, "unknown reference to constraint `" + 1347 constraintName + "`"); 1348 } 1349 1350 // Handle a reference to a proper constraint. 1351 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1352 return ast::ConstraintRef(cst, loc); 1353 1354 return emitErrorAndNote( 1355 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1356 "see the definition of `" + constraintName + "` here"); 1357 } 1358 default: 1359 break; 1360 } 1361 return emitError(loc, "expected identifier constraint"); 1362 } 1363 1364 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1365 Optional<SMRange> typeConstraint; 1366 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1367 /*allowInlineTypeConstraints=*/false); 1368 } 1369 1370 //===----------------------------------------------------------------------===// 1371 // Exprs 1372 1373 FailureOr<ast::Expr *> Parser::parseExpr() { 1374 if (curToken.is(Token::underscore)) 1375 return parseUnderscoreExpr(); 1376 1377 // Parse the LHS expression. 1378 FailureOr<ast::Expr *> lhsExpr; 1379 switch (curToken.getKind()) { 1380 case Token::kw_attr: 1381 lhsExpr = parseAttributeExpr(); 1382 break; 1383 case Token::kw_Constraint: 1384 lhsExpr = parseInlineConstraintLambdaExpr(); 1385 break; 1386 case Token::identifier: 1387 lhsExpr = parseIdentifierExpr(); 1388 break; 1389 case Token::kw_op: 1390 lhsExpr = parseOperationExpr(); 1391 break; 1392 case Token::kw_Rewrite: 1393 lhsExpr = parseInlineRewriteLambdaExpr(); 1394 break; 1395 case Token::kw_type: 1396 lhsExpr = parseTypeExpr(); 1397 break; 1398 case Token::l_paren: 1399 lhsExpr = parseTupleExpr(); 1400 break; 1401 default: 1402 return emitError("expected expression"); 1403 } 1404 if (failed(lhsExpr)) 1405 return failure(); 1406 1407 // Check for an operator expression. 1408 while (true) { 1409 switch (curToken.getKind()) { 1410 case Token::dot: 1411 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1412 break; 1413 case Token::l_paren: 1414 lhsExpr = parseCallExpr(*lhsExpr); 1415 break; 1416 default: 1417 return lhsExpr; 1418 } 1419 if (failed(lhsExpr)) 1420 return failure(); 1421 } 1422 } 1423 1424 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1425 SMRange loc = curToken.getLoc(); 1426 consumeToken(Token::kw_attr); 1427 1428 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1429 // identifier. 1430 if (!consumeIf(Token::less)) { 1431 resetToken(loc); 1432 return parseIdentifierExpr(); 1433 } 1434 1435 if (!curToken.isString()) 1436 return emitError("expected string literal containing MLIR attribute"); 1437 std::string attrExpr = curToken.getStringValue(); 1438 consumeToken(); 1439 1440 if (failed( 1441 parseToken(Token::greater, "expected `>` after attribute literal"))) 1442 return failure(); 1443 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1444 } 1445 1446 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1447 SMRange loc = curToken.getLoc(); 1448 consumeToken(Token::l_paren); 1449 1450 // Parse the arguments of the call. 1451 SmallVector<ast::Expr *> arguments; 1452 if (curToken.isNot(Token::r_paren)) { 1453 do { 1454 FailureOr<ast::Expr *> argument = parseExpr(); 1455 if (failed(argument)) 1456 return failure(); 1457 arguments.push_back(*argument); 1458 } while (consumeIf(Token::comma)); 1459 } 1460 loc.End = curToken.getEndLoc(); 1461 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1462 return failure(); 1463 1464 return createCallExpr(loc, parentExpr, arguments); 1465 } 1466 1467 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1468 ast::Decl *decl = curDeclScope->lookup(name); 1469 if (!decl) 1470 return emitError(loc, "undefined reference to `" + name + "`"); 1471 1472 return createDeclRefExpr(loc, decl); 1473 } 1474 1475 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1476 StringRef name = curToken.getSpelling(); 1477 SMRange nameLoc = curToken.getLoc(); 1478 consumeToken(); 1479 1480 // Check to see if this is a decl ref expression that defines a variable 1481 // inline. 1482 if (consumeIf(Token::colon)) { 1483 SmallVector<ast::ConstraintRef> constraints; 1484 if (failed(parseVariableDeclConstraintList(constraints))) 1485 return failure(); 1486 ast::Type type; 1487 if (failed(validateVariableConstraints(constraints, type))) 1488 return failure(); 1489 return createInlineVariableExpr(type, name, nameLoc, constraints); 1490 } 1491 1492 return parseDeclRefExpr(name, nameLoc); 1493 } 1494 1495 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1496 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1497 if (failed(decl)) 1498 return failure(); 1499 1500 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1501 ast::ConstraintType::get(ctx)); 1502 } 1503 1504 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1505 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1506 if (failed(decl)) 1507 return failure(); 1508 1509 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1510 ast::RewriteType::get(ctx)); 1511 } 1512 1513 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1514 SMRange loc = curToken.getLoc(); 1515 consumeToken(Token::dot); 1516 1517 // Parse the member name. 1518 Token memberNameTok = curToken; 1519 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1520 !memberNameTok.isKeyword()) 1521 return emitError(loc, "expected identifier or numeric member name"); 1522 StringRef memberName = memberNameTok.getSpelling(); 1523 consumeToken(); 1524 1525 return createMemberAccessExpr(parentExpr, memberName, loc); 1526 } 1527 1528 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1529 SMRange loc = curToken.getLoc(); 1530 1531 // Handle the case of an no operation name. 1532 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1533 if (allowEmptyName) 1534 return ast::OpNameDecl::create(ctx, SMRange()); 1535 return emitError("expected dialect namespace"); 1536 } 1537 StringRef name = curToken.getSpelling(); 1538 consumeToken(); 1539 1540 // Otherwise, this is a literal operation name. 1541 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1542 return failure(); 1543 1544 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1545 return emitError("expected operation name after dialect namespace"); 1546 1547 name = StringRef(name.data(), name.size() + 1); 1548 do { 1549 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1550 loc.End = curToken.getEndLoc(); 1551 consumeToken(); 1552 } while (curToken.isAny(Token::identifier, Token::dot) || 1553 curToken.isKeyword()); 1554 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1555 } 1556 1557 FailureOr<ast::OpNameDecl *> 1558 Parser::parseWrappedOperationName(bool allowEmptyName) { 1559 if (!consumeIf(Token::less)) 1560 return ast::OpNameDecl::create(ctx, SMRange()); 1561 1562 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1563 if (failed(opNameDecl)) 1564 return failure(); 1565 1566 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1567 return failure(); 1568 return opNameDecl; 1569 } 1570 1571 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1572 SMRange loc = curToken.getLoc(); 1573 consumeToken(Token::kw_op); 1574 1575 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1576 // identifier. 1577 if (curToken.isNot(Token::less)) { 1578 resetToken(loc); 1579 return parseIdentifierExpr(); 1580 } 1581 1582 // Parse the operation name. The name may be elided, in which case the 1583 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1584 // `Operation*`). Operation names within a rewrite context must be named. 1585 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1586 FailureOr<ast::OpNameDecl *> opNameDecl = 1587 parseWrappedOperationName(allowEmptyName); 1588 if (failed(opNameDecl)) 1589 return failure(); 1590 1591 // Check for the optional list of operands. 1592 SmallVector<ast::Expr *> operands; 1593 if (consumeIf(Token::l_paren)) { 1594 do { 1595 FailureOr<ast::Expr *> operand = parseExpr(); 1596 if (failed(operand)) 1597 return failure(); 1598 operands.push_back(*operand); 1599 } while (consumeIf(Token::comma)); 1600 1601 if (failed(parseToken(Token::r_paren, 1602 "expected `)` after operation operand list"))) 1603 return failure(); 1604 } 1605 1606 // Check for the optional list of attributes. 1607 SmallVector<ast::NamedAttributeDecl *> attributes; 1608 if (consumeIf(Token::l_brace)) { 1609 do { 1610 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1611 if (failed(decl)) 1612 return failure(); 1613 attributes.emplace_back(*decl); 1614 } while (consumeIf(Token::comma)); 1615 1616 if (failed(parseToken(Token::r_brace, 1617 "expected `}` after operation attribute list"))) 1618 return failure(); 1619 } 1620 1621 // Check for the optional list of result types. 1622 SmallVector<ast::Expr *> resultTypes; 1623 if (consumeIf(Token::arrow)) { 1624 if (failed(parseToken(Token::l_paren, 1625 "expected `(` before operation result type list"))) 1626 return failure(); 1627 1628 do { 1629 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1630 if (failed(resultTypeExpr)) 1631 return failure(); 1632 resultTypes.push_back(*resultTypeExpr); 1633 } while (consumeIf(Token::comma)); 1634 1635 if (failed(parseToken(Token::r_paren, 1636 "expected `)` after operation result type list"))) 1637 return failure(); 1638 } 1639 1640 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1641 resultTypes); 1642 } 1643 1644 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1645 SMRange loc = curToken.getLoc(); 1646 consumeToken(Token::l_paren); 1647 1648 DenseMap<StringRef, SMRange> usedNames; 1649 SmallVector<StringRef> elementNames; 1650 SmallVector<ast::Expr *> elements; 1651 if (curToken.isNot(Token::r_paren)) { 1652 do { 1653 // Check for the optional element name assignment before the value. 1654 StringRef elementName; 1655 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1656 Token elementNameTok = curToken; 1657 consumeToken(); 1658 1659 // The element name is only present if followed by an `=`. 1660 if (consumeIf(Token::equal)) { 1661 elementName = elementNameTok.getSpelling(); 1662 1663 // Check to see if this name is already used. 1664 auto elementNameIt = 1665 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1666 if (!elementNameIt.second) { 1667 return emitErrorAndNote( 1668 elementNameTok.getLoc(), 1669 llvm::formatv("duplicate tuple element label `{0}`", 1670 elementName), 1671 elementNameIt.first->getSecond(), 1672 "see previous label use here"); 1673 } 1674 } else { 1675 // Otherwise, we treat this as part of an expression so reset the 1676 // lexer. 1677 resetToken(elementNameTok.getLoc()); 1678 } 1679 } 1680 elementNames.push_back(elementName); 1681 1682 // Parse the tuple element value. 1683 FailureOr<ast::Expr *> element = parseExpr(); 1684 if (failed(element)) 1685 return failure(); 1686 elements.push_back(*element); 1687 } while (consumeIf(Token::comma)); 1688 } 1689 loc.End = curToken.getEndLoc(); 1690 if (failed( 1691 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1692 return failure(); 1693 return createTupleExpr(loc, elements, elementNames); 1694 } 1695 1696 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1697 SMRange loc = curToken.getLoc(); 1698 consumeToken(Token::kw_type); 1699 1700 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1701 // identifier. 1702 if (!consumeIf(Token::less)) { 1703 resetToken(loc); 1704 return parseIdentifierExpr(); 1705 } 1706 1707 if (!curToken.isString()) 1708 return emitError("expected string literal containing MLIR type"); 1709 std::string attrExpr = curToken.getStringValue(); 1710 consumeToken(); 1711 1712 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1713 return failure(); 1714 return ast::TypeExpr::create(ctx, loc, attrExpr); 1715 } 1716 1717 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 1718 StringRef name = curToken.getSpelling(); 1719 SMRange nameLoc = curToken.getLoc(); 1720 consumeToken(Token::underscore); 1721 1722 // Underscore expressions require a constraint list. 1723 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 1724 return failure(); 1725 1726 // Parse the constraints for the expression. 1727 SmallVector<ast::ConstraintRef> constraints; 1728 if (failed(parseVariableDeclConstraintList(constraints))) 1729 return failure(); 1730 1731 ast::Type type; 1732 if (failed(validateVariableConstraints(constraints, type))) 1733 return failure(); 1734 return createInlineVariableExpr(type, name, nameLoc, constraints); 1735 } 1736 1737 //===----------------------------------------------------------------------===// 1738 // Stmts 1739 1740 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 1741 FailureOr<ast::Stmt *> stmt; 1742 switch (curToken.getKind()) { 1743 case Token::kw_erase: 1744 stmt = parseEraseStmt(); 1745 break; 1746 case Token::kw_let: 1747 stmt = parseLetStmt(); 1748 break; 1749 case Token::kw_replace: 1750 stmt = parseReplaceStmt(); 1751 break; 1752 case Token::kw_return: 1753 stmt = parseReturnStmt(); 1754 break; 1755 case Token::kw_rewrite: 1756 stmt = parseRewriteStmt(); 1757 break; 1758 default: 1759 stmt = parseExpr(); 1760 break; 1761 } 1762 if (failed(stmt) || 1763 (expectTerminalSemicolon && 1764 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 1765 return failure(); 1766 return stmt; 1767 } 1768 1769 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 1770 SMLoc startLoc = curToken.getStartLoc(); 1771 consumeToken(Token::l_brace); 1772 1773 // Push a new block scope and parse any nested statements. 1774 pushDeclScope(); 1775 SmallVector<ast::Stmt *> statements; 1776 while (curToken.isNot(Token::r_brace)) { 1777 FailureOr<ast::Stmt *> statement = parseStmt(); 1778 if (failed(statement)) 1779 return popDeclScope(), failure(); 1780 statements.push_back(*statement); 1781 } 1782 popDeclScope(); 1783 1784 // Consume the end brace. 1785 SMRange location(startLoc, curToken.getEndLoc()); 1786 consumeToken(Token::r_brace); 1787 1788 return ast::CompoundStmt::create(ctx, location, statements); 1789 } 1790 1791 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 1792 if (parserContext == ParserContext::Constraint) 1793 return emitError("`erase` cannot be used within a Constraint"); 1794 SMRange loc = curToken.getLoc(); 1795 consumeToken(Token::kw_erase); 1796 1797 // Parse the root operation expression. 1798 FailureOr<ast::Expr *> rootOp = parseExpr(); 1799 if (failed(rootOp)) 1800 return failure(); 1801 1802 return createEraseStmt(loc, *rootOp); 1803 } 1804 1805 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 1806 SMRange loc = curToken.getLoc(); 1807 consumeToken(Token::kw_let); 1808 1809 // Parse the name of the new variable. 1810 SMRange varLoc = curToken.getLoc(); 1811 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 1812 // `_` is a reserved variable name. 1813 if (curToken.is(Token::underscore)) { 1814 return emitError(varLoc, 1815 "`_` may only be used to define \"inline\" variables"); 1816 } 1817 return emitError(varLoc, 1818 "expected identifier after `let` to name a new variable"); 1819 } 1820 StringRef varName = curToken.getSpelling(); 1821 consumeToken(); 1822 1823 // Parse the optional set of constraints. 1824 SmallVector<ast::ConstraintRef> constraints; 1825 if (consumeIf(Token::colon) && 1826 failed(parseVariableDeclConstraintList(constraints))) 1827 return failure(); 1828 1829 // Parse the optional initializer expression. 1830 ast::Expr *initializer = nullptr; 1831 if (consumeIf(Token::equal)) { 1832 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 1833 if (failed(initOrFailure)) 1834 return failure(); 1835 initializer = *initOrFailure; 1836 1837 // Check that the constraints are compatible with having an initializer, 1838 // e.g. type constraints cannot be used with initializers. 1839 for (ast::ConstraintRef constraint : constraints) { 1840 LogicalResult result = 1841 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 1842 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 1843 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 1844 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 1845 return this->emitError( 1846 constraint.referenceLoc, 1847 "type constraints are not permitted on variables with " 1848 "initializers"); 1849 } 1850 return success(); 1851 }) 1852 .Default(success()); 1853 if (failed(result)) 1854 return failure(); 1855 } 1856 } 1857 1858 FailureOr<ast::VariableDecl *> varDecl = 1859 createVariableDecl(varName, varLoc, initializer, constraints); 1860 if (failed(varDecl)) 1861 return failure(); 1862 return ast::LetStmt::create(ctx, loc, *varDecl); 1863 } 1864 1865 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 1866 if (parserContext == ParserContext::Constraint) 1867 return emitError("`replace` cannot be used within a Constraint"); 1868 SMRange loc = curToken.getLoc(); 1869 consumeToken(Token::kw_replace); 1870 1871 // Parse the root operation expression. 1872 FailureOr<ast::Expr *> rootOp = parseExpr(); 1873 if (failed(rootOp)) 1874 return failure(); 1875 1876 if (failed( 1877 parseToken(Token::kw_with, "expected `with` after root operation"))) 1878 return failure(); 1879 1880 // The replacement portion of this statement is within a rewrite context. 1881 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1882 ParserContext::Rewrite); 1883 1884 // Parse the replacement values. 1885 SmallVector<ast::Expr *> replValues; 1886 if (consumeIf(Token::l_paren)) { 1887 if (consumeIf(Token::r_paren)) { 1888 return emitError( 1889 loc, "expected at least one replacement value, consider using " 1890 "`erase` if no replacement values are desired"); 1891 } 1892 1893 do { 1894 FailureOr<ast::Expr *> replExpr = parseExpr(); 1895 if (failed(replExpr)) 1896 return failure(); 1897 replValues.emplace_back(*replExpr); 1898 } while (consumeIf(Token::comma)); 1899 1900 if (failed(parseToken(Token::r_paren, 1901 "expected `)` after replacement values"))) 1902 return failure(); 1903 } else { 1904 FailureOr<ast::Expr *> replExpr = parseExpr(); 1905 if (failed(replExpr)) 1906 return failure(); 1907 replValues.emplace_back(*replExpr); 1908 } 1909 1910 return createReplaceStmt(loc, *rootOp, replValues); 1911 } 1912 1913 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 1914 SMRange loc = curToken.getLoc(); 1915 consumeToken(Token::kw_return); 1916 1917 // Parse the result value. 1918 FailureOr<ast::Expr *> resultExpr = parseExpr(); 1919 if (failed(resultExpr)) 1920 return failure(); 1921 1922 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 1923 } 1924 1925 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 1926 if (parserContext == ParserContext::Constraint) 1927 return emitError("`rewrite` cannot be used within a Constraint"); 1928 SMRange loc = curToken.getLoc(); 1929 consumeToken(Token::kw_rewrite); 1930 1931 // Parse the root operation. 1932 FailureOr<ast::Expr *> rootOp = parseExpr(); 1933 if (failed(rootOp)) 1934 return failure(); 1935 1936 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 1937 return failure(); 1938 1939 if (curToken.isNot(Token::l_brace)) 1940 return emitError("expected `{` to start rewrite body"); 1941 1942 // The rewrite body of this statement is within a rewrite context. 1943 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1944 ParserContext::Rewrite); 1945 1946 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 1947 if (failed(rewriteBody)) 1948 return failure(); 1949 1950 // Verify the rewrite body. 1951 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 1952 if (isa<ast::ReturnStmt>(stmt)) { 1953 return emitError(stmt->getLoc(), 1954 "`return` statements are only permitted within a " 1955 "`Constraint` or `Rewrite` body"); 1956 } 1957 } 1958 1959 return createRewriteStmt(loc, *rootOp, *rewriteBody); 1960 } 1961 1962 //===----------------------------------------------------------------------===// 1963 // Creation+Analysis 1964 //===----------------------------------------------------------------------===// 1965 1966 //===----------------------------------------------------------------------===// 1967 // Decls 1968 1969 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 1970 // Unwrap reference expressions. 1971 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 1972 node = init->getDecl(); 1973 return dyn_cast<ast::CallableDecl>(node); 1974 } 1975 1976 FailureOr<ast::PatternDecl *> 1977 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 1978 const ParsedPatternMetadata &metadata, 1979 ast::CompoundStmt *body) { 1980 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 1981 metadata.hasBoundedRecursion, body); 1982 } 1983 1984 ast::Type Parser::createUserConstraintRewriteResultType( 1985 ArrayRef<ast::VariableDecl *> results) { 1986 // Single result decls use the type of the single result. 1987 if (results.size() == 1) 1988 return results[0]->getType(); 1989 1990 // Multiple results use a tuple type, with the types and names grabbed from 1991 // the result variable decls. 1992 auto resultTypes = llvm::map_range( 1993 results, [&](const auto *result) { return result->getType(); }); 1994 auto resultNames = llvm::map_range( 1995 results, [&](const auto *result) { return result->getName().getName(); }); 1996 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 1997 llvm::to_vector(resultNames)); 1998 } 1999 2000 template <typename T> 2001 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2002 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2003 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2004 ast::CompoundStmt *body) { 2005 if (!body->getChildren().empty()) { 2006 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2007 ast::Expr *resultExpr = retStmt->getResultExpr(); 2008 2009 // Process the result of the decl. If no explicit signature results 2010 // were provided, check for return type inference. Otherwise, check that 2011 // the return expression can be converted to the expected type. 2012 if (results.empty()) 2013 resultType = resultExpr->getType(); 2014 else if (failed(convertExpressionTo(resultExpr, resultType))) 2015 return failure(); 2016 else 2017 retStmt->setResultExpr(resultExpr); 2018 } 2019 } 2020 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2021 } 2022 2023 FailureOr<ast::VariableDecl *> 2024 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2025 ArrayRef<ast::ConstraintRef> constraints) { 2026 // The type of the variable, which is expected to be inferred by either a 2027 // constraint or an initializer expression. 2028 ast::Type type; 2029 if (failed(validateVariableConstraints(constraints, type))) 2030 return failure(); 2031 2032 if (initializer) { 2033 // Update the variable type based on the initializer, or try to convert the 2034 // initializer to the existing type. 2035 if (!type) 2036 type = initializer->getType(); 2037 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2038 type = mergedType; 2039 else if (failed(convertExpressionTo(initializer, type))) 2040 return failure(); 2041 2042 // Otherwise, if there is no initializer check that the type has already 2043 // been resolved from the constraint list. 2044 } else if (!type) { 2045 return emitErrorAndNote( 2046 loc, "unable to infer type for variable `" + name + "`", loc, 2047 "the type of a variable must be inferable from the constraint " 2048 "list or the initializer"); 2049 } 2050 2051 // Constraint types cannot be used when defining variables. 2052 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2053 return emitError( 2054 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2055 } 2056 2057 // Try to define a variable with the given name. 2058 FailureOr<ast::VariableDecl *> varDecl = 2059 defineVariableDecl(name, loc, type, initializer, constraints); 2060 if (failed(varDecl)) 2061 return failure(); 2062 2063 return *varDecl; 2064 } 2065 2066 FailureOr<ast::VariableDecl *> 2067 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2068 const ast::ConstraintRef &constraint) { 2069 // Constraint arguments may apply more complex constraints via the arguments. 2070 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2071 ast::Type argType; 2072 if (failed(validateVariableConstraint(constraint, argType, 2073 allowNonCoreConstraints))) 2074 return failure(); 2075 return defineVariableDecl(name, loc, argType, constraint); 2076 } 2077 2078 LogicalResult 2079 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2080 ast::Type &inferredType) { 2081 for (const ast::ConstraintRef &ref : constraints) 2082 if (failed(validateVariableConstraint(ref, inferredType))) 2083 return failure(); 2084 return success(); 2085 } 2086 2087 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2088 ast::Type &inferredType, 2089 bool allowNonCoreConstraints) { 2090 ast::Type constraintType; 2091 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2092 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2093 if (failed(validateTypeConstraintExpr(typeExpr))) 2094 return failure(); 2095 } 2096 constraintType = ast::AttributeType::get(ctx); 2097 } else if (const auto *cst = 2098 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2099 constraintType = ast::OperationType::get(ctx, cst->getName()); 2100 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2101 constraintType = typeTy; 2102 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2103 constraintType = typeRangeTy; 2104 } else if (const auto *cst = 2105 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2106 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2107 if (failed(validateTypeConstraintExpr(typeExpr))) 2108 return failure(); 2109 } 2110 constraintType = valueTy; 2111 } else if (const auto *cst = 2112 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2113 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2114 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2115 return failure(); 2116 } 2117 constraintType = valueRangeTy; 2118 } else if (const auto *cst = 2119 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2120 if (!allowNonCoreConstraints) { 2121 return emitError(ref.referenceLoc, 2122 "`Rewrite` arguments and results are only permitted to " 2123 "use core constraints, such as `Attr`, `Op`, `Type`, " 2124 "`TypeRange`, `Value`, `ValueRange`"); 2125 } 2126 2127 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2128 if (inputs.size() != 1) { 2129 return emitErrorAndNote(ref.referenceLoc, 2130 "`Constraint`s applied via a variable constraint " 2131 "list must take a single input, but got " + 2132 Twine(inputs.size()), 2133 cst->getLoc(), 2134 "see definition of constraint here"); 2135 } 2136 constraintType = inputs.front()->getType(); 2137 } else { 2138 llvm_unreachable("unknown constraint type"); 2139 } 2140 2141 // Check that the constraint type is compatible with the current inferred 2142 // type. 2143 if (!inferredType) { 2144 inferredType = constraintType; 2145 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2146 inferredType = mergedTy; 2147 } else { 2148 return emitError(ref.referenceLoc, 2149 llvm::formatv("constraint type `{0}` is incompatible " 2150 "with the previously inferred type `{1}`", 2151 constraintType, inferredType)); 2152 } 2153 return success(); 2154 } 2155 2156 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2157 ast::Type typeExprType = typeExpr->getType(); 2158 if (typeExprType != typeTy) { 2159 return emitError(typeExpr->getLoc(), 2160 "expected expression of `Type` in type constraint"); 2161 } 2162 return success(); 2163 } 2164 2165 LogicalResult 2166 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2167 ast::Type typeExprType = typeExpr->getType(); 2168 if (typeExprType != typeRangeTy) { 2169 return emitError(typeExpr->getLoc(), 2170 "expected expression of `TypeRange` in type constraint"); 2171 } 2172 return success(); 2173 } 2174 2175 //===----------------------------------------------------------------------===// 2176 // Exprs 2177 2178 FailureOr<ast::CallExpr *> 2179 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2180 MutableArrayRef<ast::Expr *> arguments) { 2181 ast::Type parentType = parentExpr->getType(); 2182 2183 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2184 if (!callableDecl) { 2185 return emitError(loc, 2186 llvm::formatv("expected a reference to a callable " 2187 "`Constraint` or `Rewrite`, but got: `{0}`", 2188 parentType)); 2189 } 2190 if (parserContext == ParserContext::Rewrite) { 2191 if (isa<ast::UserConstraintDecl>(callableDecl)) 2192 return emitError( 2193 loc, "unable to invoke `Constraint` within a rewrite section"); 2194 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2195 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2196 } 2197 2198 // Verify the arguments of the call. 2199 /// Handle size mismatch. 2200 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2201 if (callArgs.size() != arguments.size()) { 2202 return emitErrorAndNote( 2203 loc, 2204 llvm::formatv("invalid number of arguments for {0} call; expected " 2205 "{1}, but got {2}", 2206 callableDecl->getCallableType(), callArgs.size(), 2207 arguments.size()), 2208 callableDecl->getLoc(), 2209 llvm::formatv("see the definition of {0} here", 2210 callableDecl->getName()->getName())); 2211 } 2212 2213 /// Handle argument type mismatch. 2214 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2215 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2216 callableDecl->getName()->getName()), 2217 callableDecl->getLoc()); 2218 }; 2219 for (auto it : llvm::zip(callArgs, arguments)) { 2220 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2221 attachDiagFn))) 2222 return failure(); 2223 } 2224 2225 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2226 callableDecl->getResultType()); 2227 } 2228 2229 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2230 ast::Decl *decl) { 2231 // Check the type of decl being referenced. 2232 ast::Type declType; 2233 if (isa<ast::ConstraintDecl>(decl)) 2234 declType = ast::ConstraintType::get(ctx); 2235 else if (isa<ast::UserRewriteDecl>(decl)) 2236 declType = ast::RewriteType::get(ctx); 2237 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2238 declType = varDecl->getType(); 2239 else 2240 return emitError(loc, "invalid reference to `" + 2241 decl->getName()->getName() + "`"); 2242 2243 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2244 } 2245 2246 FailureOr<ast::DeclRefExpr *> 2247 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2248 ArrayRef<ast::ConstraintRef> constraints) { 2249 FailureOr<ast::VariableDecl *> decl = 2250 defineVariableDecl(name, loc, type, constraints); 2251 if (failed(decl)) 2252 return failure(); 2253 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2254 } 2255 2256 FailureOr<ast::MemberAccessExpr *> 2257 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2258 SMRange loc) { 2259 // Validate the member name for the given parent expression. 2260 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2261 if (failed(memberType)) 2262 return failure(); 2263 2264 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2265 } 2266 2267 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2268 StringRef name, SMRange loc) { 2269 ast::Type parentType = parentExpr->getType(); 2270 if (parentType.isa<ast::OperationType>()) { 2271 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2272 return valueRangeTy; 2273 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2274 // Handle indexed results. 2275 unsigned index = 0; 2276 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2277 index < tupleType.size()) { 2278 return tupleType.getElementTypes()[index]; 2279 } 2280 2281 // Handle named results. 2282 auto elementNames = tupleType.getElementNames(); 2283 const auto *it = llvm::find(elementNames, name); 2284 if (it != elementNames.end()) 2285 return tupleType.getElementTypes()[it - elementNames.begin()]; 2286 } 2287 return emitError( 2288 loc, 2289 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2290 name, parentType)); 2291 } 2292 2293 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2294 SMRange loc, const ast::OpNameDecl *name, 2295 MutableArrayRef<ast::Expr *> operands, 2296 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2297 MutableArrayRef<ast::Expr *> results) { 2298 Optional<StringRef> opNameRef = name->getName(); 2299 2300 // Verify the inputs operands. 2301 if (failed(validateOperationOperands(loc, opNameRef, operands))) 2302 return failure(); 2303 2304 // Verify the attribute list. 2305 for (ast::NamedAttributeDecl *attr : attributes) { 2306 // Check for an attribute type, or a type awaiting resolution. 2307 ast::Type attrType = attr->getValue()->getType(); 2308 if (!attrType.isa<ast::AttributeType>()) { 2309 return emitError( 2310 attr->getValue()->getLoc(), 2311 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2312 } 2313 } 2314 2315 // Verify the result types. 2316 if (failed(validateOperationResults(loc, opNameRef, results))) 2317 return failure(); 2318 2319 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2320 attributes); 2321 } 2322 2323 LogicalResult 2324 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2325 MutableArrayRef<ast::Expr *> operands) { 2326 return validateOperationOperandsOrResults(loc, name, operands, valueTy, 2327 valueRangeTy); 2328 } 2329 2330 LogicalResult 2331 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2332 MutableArrayRef<ast::Expr *> results) { 2333 return validateOperationOperandsOrResults(loc, name, results, typeTy, 2334 typeRangeTy); 2335 } 2336 2337 LogicalResult Parser::validateOperationOperandsOrResults( 2338 SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2339 ast::Type singleTy, ast::Type rangeTy) { 2340 // All operation types accept a single range parameter. 2341 if (values.size() == 1) { 2342 if (failed(convertExpressionTo(values[0], rangeTy))) 2343 return failure(); 2344 return success(); 2345 } 2346 2347 // Otherwise, accept the value groups as they have been defined and just 2348 // ensure they are one of the expected types. 2349 for (ast::Expr *&valueExpr : values) { 2350 ast::Type valueExprType = valueExpr->getType(); 2351 2352 // Check if this is one of the expected types. 2353 if (valueExprType == rangeTy || valueExprType == singleTy) 2354 continue; 2355 2356 // If the operand is an Operation, allow converting to a Value or 2357 // ValueRange. This situations arises quite often with nested operation 2358 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2359 if (singleTy == valueTy) { 2360 if (valueExprType.isa<ast::OperationType>()) { 2361 valueExpr = convertOpToValue(valueExpr); 2362 continue; 2363 } 2364 } 2365 2366 return emitError( 2367 valueExpr->getLoc(), 2368 llvm::formatv( 2369 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2370 singleTy, rangeTy, valueExprType)); 2371 } 2372 return success(); 2373 } 2374 2375 FailureOr<ast::TupleExpr *> 2376 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2377 ArrayRef<StringRef> elementNames) { 2378 for (const ast::Expr *element : elements) { 2379 ast::Type eleTy = element->getType(); 2380 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2381 return emitError( 2382 element->getLoc(), 2383 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2384 } 2385 } 2386 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2387 } 2388 2389 //===----------------------------------------------------------------------===// 2390 // Stmts 2391 2392 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2393 ast::Expr *rootOp) { 2394 // Check that root is an Operation. 2395 ast::Type rootType = rootOp->getType(); 2396 if (!rootType.isa<ast::OperationType>()) 2397 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2398 2399 return ast::EraseStmt::create(ctx, loc, rootOp); 2400 } 2401 2402 FailureOr<ast::ReplaceStmt *> 2403 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2404 MutableArrayRef<ast::Expr *> replValues) { 2405 // Check that root is an Operation. 2406 ast::Type rootType = rootOp->getType(); 2407 if (!rootType.isa<ast::OperationType>()) { 2408 return emitError( 2409 rootOp->getLoc(), 2410 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2411 } 2412 2413 // If there are multiple replacement values, we implicitly convert any Op 2414 // expressions to the value form. 2415 bool shouldConvertOpToValues = replValues.size() > 1; 2416 for (ast::Expr *&replExpr : replValues) { 2417 ast::Type replType = replExpr->getType(); 2418 2419 // Check that replExpr is an Operation, Value, or ValueRange. 2420 if (replType.isa<ast::OperationType>()) { 2421 if (shouldConvertOpToValues) 2422 replExpr = convertOpToValue(replExpr); 2423 continue; 2424 } 2425 2426 if (replType != valueTy && replType != valueRangeTy) { 2427 return emitError(replExpr->getLoc(), 2428 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2429 "expression, but got `{0}`", 2430 replType)); 2431 } 2432 } 2433 2434 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2435 } 2436 2437 FailureOr<ast::RewriteStmt *> 2438 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2439 ast::CompoundStmt *rewriteBody) { 2440 // Check that root is an Operation. 2441 ast::Type rootType = rootOp->getType(); 2442 if (!rootType.isa<ast::OperationType>()) { 2443 return emitError( 2444 rootOp->getLoc(), 2445 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2446 } 2447 2448 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2449 } 2450 2451 //===----------------------------------------------------------------------===// 2452 // Parser 2453 //===----------------------------------------------------------------------===// 2454 2455 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 2456 llvm::SourceMgr &sourceMgr) { 2457 Parser parser(ctx, sourceMgr); 2458 return parser.parseModule(); 2459 } 2460