1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/Tools/PDLL/AST/Context.h" 13 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 14 #include "mlir/Tools/PDLL/AST/Nodes.h" 15 #include "mlir/Tools/PDLL/AST/Types.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/SaveAndRestore.h" 20 #include "llvm/Support/ScopedPrinter.h" 21 #include <string> 22 23 using namespace mlir; 24 using namespace mlir::pdll; 25 26 //===----------------------------------------------------------------------===// 27 // Parser 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 class Parser { 32 public: 33 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 34 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 35 curToken(lexer.lexToken()), curDeclScope(nullptr), 36 valueTy(ast::ValueType::get(ctx)), 37 valueRangeTy(ast::ValueRangeType::get(ctx)), 38 typeTy(ast::TypeType::get(ctx)), 39 typeRangeTy(ast::TypeRangeType::get(ctx)) {} 40 41 /// Try to parse a new module. Returns nullptr in the case of failure. 42 FailureOr<ast::Module *> parseModule(); 43 44 private: 45 /// The current context of the parser. It allows for the parser to know a bit 46 /// about the construct it is nested within during parsing. This is used 47 /// specifically to provide additional verification during parsing, e.g. to 48 /// prevent using rewrites within a match context, matcher constraints within 49 /// a rewrite section, etc. 50 enum class ParserContext { 51 /// The parser is in the global context. 52 Global, 53 /// The parser is currently within the matcher portion of a Pattern, which 54 /// is allows a terminal operation rewrite statement but no other rewrite 55 /// transformations. 56 PatternMatch, 57 /// The parser is currently within a Rewrite, which disallows calls to 58 /// constraints, requires operation expressions to have names, etc. 59 Rewrite, 60 }; 61 62 //===--------------------------------------------------------------------===// 63 // Parsing 64 //===--------------------------------------------------------------------===// 65 66 /// Push a new decl scope onto the lexer. 67 ast::DeclScope *pushDeclScope() { 68 ast::DeclScope *newScope = 69 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 70 return (curDeclScope = newScope); 71 } 72 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 73 74 /// Pop the last decl scope from the lexer. 75 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 76 77 /// Parse the body of an AST module. 78 LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); 79 80 /// Try to convert the given expression to `type`. Returns failure and emits 81 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 82 /// invoked to attach notes to the emitted error diagnostic. On success, 83 /// `expr` is updated to the expression used to convert to `type`. 84 LogicalResult convertExpressionTo( 85 ast::Expr *&expr, ast::Type type, 86 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 87 88 /// Given an operation expression, convert it to a Value or ValueRange 89 /// typed expression. 90 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 91 92 //===--------------------------------------------------------------------===// 93 // Directives 94 95 LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); 96 LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); 97 98 //===--------------------------------------------------------------------===// 99 // Decls 100 101 /// This structure contains the set of pattern metadata that may be parsed. 102 struct ParsedPatternMetadata { 103 Optional<uint16_t> benefit; 104 bool hasBoundedRecursion = false; 105 }; 106 107 FailureOr<ast::Decl *> parseTopLevelDecl(); 108 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 109 FailureOr<ast::Decl *> parsePatternDecl(); 110 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 111 112 /// Check to see if a decl has already been defined with the given name, if 113 /// one has emit and error and return failure. Returns success otherwise. 114 LogicalResult checkDefineNamedDecl(const ast::Name &name); 115 116 /// Try to define a variable decl with the given components, returns the 117 /// variable on success. 118 FailureOr<ast::VariableDecl *> 119 defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, 120 ast::Expr *initExpr, 121 ArrayRef<ast::ConstraintRef> constraints); 122 FailureOr<ast::VariableDecl *> 123 defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, 124 ArrayRef<ast::ConstraintRef> constraints); 125 126 /// Parse the constraint reference list for a variable decl. 127 LogicalResult parseVariableDeclConstraintList( 128 SmallVectorImpl<ast::ConstraintRef> &constraints); 129 130 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 131 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 132 133 /// Try to parse a single reference to a constraint. `typeConstraint` is the 134 /// location of a previously parsed type constraint for the entity that will 135 /// be constrained by the parsed constraint. `existingConstraints` are any 136 /// existing constraints that have already been parsed for the same entity 137 /// that will be constrained by this constraint. 138 FailureOr<ast::ConstraintRef> 139 parseConstraint(Optional<llvm::SMRange> &typeConstraint, 140 ArrayRef<ast::ConstraintRef> existingConstraints); 141 142 //===--------------------------------------------------------------------===// 143 // Exprs 144 145 FailureOr<ast::Expr *> parseExpr(); 146 147 /// Identifier expressions. 148 FailureOr<ast::Expr *> parseAttributeExpr(); 149 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc); 150 FailureOr<ast::Expr *> parseIdentifierExpr(); 151 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 152 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 153 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 154 FailureOr<ast::Expr *> parseOperationExpr(); 155 FailureOr<ast::Expr *> parseTupleExpr(); 156 FailureOr<ast::Expr *> parseTypeExpr(); 157 FailureOr<ast::Expr *> parseUnderscoreExpr(); 158 159 //===--------------------------------------------------------------------===// 160 // Stmts 161 162 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 163 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 164 FailureOr<ast::EraseStmt *> parseEraseStmt(); 165 FailureOr<ast::LetStmt *> parseLetStmt(); 166 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 167 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 168 169 //===--------------------------------------------------------------------===// 170 // Creation+Analysis 171 //===--------------------------------------------------------------------===// 172 173 //===--------------------------------------------------------------------===// 174 // Decls 175 176 /// Try to create a pattern decl with the given components, returning the 177 /// Pattern on success. 178 FailureOr<ast::PatternDecl *> 179 createPatternDecl(llvm::SMRange loc, const ast::Name *name, 180 const ParsedPatternMetadata &metadata, 181 ast::CompoundStmt *body); 182 183 /// Try to create a variable decl with the given components, returning the 184 /// Variable on success. 185 FailureOr<ast::VariableDecl *> 186 createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer, 187 ArrayRef<ast::ConstraintRef> constraints); 188 189 /// Validate the constraints used to constraint a variable decl. 190 /// `inferredType` is the type of the variable inferred by the constraints 191 /// within the list, and is updated to the most refined type as determined by 192 /// the constraints. Returns success if the constraint list is valid, failure 193 /// otherwise. 194 LogicalResult 195 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 196 ast::Type &inferredType); 197 /// Validate a single reference to a constraint. `inferredType` contains the 198 /// currently inferred variabled type and is refined within the type defined 199 /// by the constraint. Returns success if the constraint is valid, failure 200 /// otherwise. 201 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 202 ast::Type &inferredType); 203 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 204 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 205 206 //===--------------------------------------------------------------------===// 207 // Exprs 208 209 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc, 210 ast::Decl *decl); 211 FailureOr<ast::DeclRefExpr *> 212 createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc, 213 ArrayRef<ast::ConstraintRef> constraints); 214 FailureOr<ast::MemberAccessExpr *> 215 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 216 llvm::SMRange loc); 217 218 /// Validate the member access `name` into the given parent expression. On 219 /// success, this also returns the type of the member accessed. 220 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 221 StringRef name, llvm::SMRange loc); 222 FailureOr<ast::OperationExpr *> 223 createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name, 224 MutableArrayRef<ast::Expr *> operands, 225 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 226 MutableArrayRef<ast::Expr *> results); 227 LogicalResult 228 validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name, 229 MutableArrayRef<ast::Expr *> operands); 230 LogicalResult validateOperationResults(llvm::SMRange loc, 231 Optional<StringRef> name, 232 MutableArrayRef<ast::Expr *> results); 233 LogicalResult 234 validateOperationOperandsOrResults(llvm::SMRange loc, 235 Optional<StringRef> name, 236 MutableArrayRef<ast::Expr *> values, 237 ast::Type singleTy, ast::Type rangeTy); 238 FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc, 239 ArrayRef<ast::Expr *> elements, 240 ArrayRef<StringRef> elementNames); 241 242 //===--------------------------------------------------------------------===// 243 // Stmts 244 245 FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc, 246 ast::Expr *rootOp); 247 FailureOr<ast::ReplaceStmt *> 248 createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp, 249 MutableArrayRef<ast::Expr *> replValues); 250 FailureOr<ast::RewriteStmt *> 251 createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, 252 ast::CompoundStmt *rewriteBody); 253 254 //===--------------------------------------------------------------------===// 255 // Lexer Utilities 256 //===--------------------------------------------------------------------===// 257 258 /// If the current token has the specified kind, consume it and return true. 259 /// If not, return false. 260 bool consumeIf(Token::Kind kind) { 261 if (curToken.isNot(kind)) 262 return false; 263 consumeToken(kind); 264 return true; 265 } 266 267 /// Advance the current lexer onto the next token. 268 void consumeToken() { 269 assert(curToken.isNot(Token::eof, Token::error) && 270 "shouldn't advance past EOF or errors"); 271 curToken = lexer.lexToken(); 272 } 273 274 /// Advance the current lexer onto the next token, asserting what the expected 275 /// current token is. This is preferred to the above method because it leads 276 /// to more self-documenting code with better checking. 277 void consumeToken(Token::Kind kind) { 278 assert(curToken.is(kind) && "consumed an unexpected token"); 279 consumeToken(); 280 } 281 282 /// Reset the lexer to the location at the given position. 283 void resetToken(llvm::SMRange tokLoc) { 284 lexer.resetPointer(tokLoc.Start.getPointer()); 285 curToken = lexer.lexToken(); 286 } 287 288 /// Consume the specified token if present and return success. On failure, 289 /// output a diagnostic and return failure. 290 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 291 if (curToken.getKind() != kind) 292 return emitError(curToken.getLoc(), msg); 293 consumeToken(); 294 return success(); 295 } 296 LogicalResult emitError(llvm::SMRange loc, const Twine &msg) { 297 lexer.emitError(loc, msg); 298 return failure(); 299 } 300 LogicalResult emitError(const Twine &msg) { 301 return emitError(curToken.getLoc(), msg); 302 } 303 LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg, 304 llvm::SMRange noteLoc, const Twine ¬e) { 305 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 306 return failure(); 307 } 308 309 //===--------------------------------------------------------------------===// 310 // Fields 311 //===--------------------------------------------------------------------===// 312 313 /// The owning AST context. 314 ast::Context &ctx; 315 316 /// The lexer of this parser. 317 Lexer lexer; 318 319 /// The current token within the lexer. 320 Token curToken; 321 322 /// The most recently defined decl scope. 323 ast::DeclScope *curDeclScope; 324 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 325 326 /// The current context of the parser. 327 ParserContext parserContext = ParserContext::Global; 328 329 /// Cached types to simplify verification and expression creation. 330 ast::Type valueTy, valueRangeTy; 331 ast::Type typeTy, typeRangeTy; 332 }; 333 } // namespace 334 335 FailureOr<ast::Module *> Parser::parseModule() { 336 llvm::SMLoc moduleLoc = curToken.getStartLoc(); 337 pushDeclScope(); 338 339 // Parse the top-level decls of the module. 340 SmallVector<ast::Decl *> decls; 341 if (failed(parseModuleBody(decls))) 342 return popDeclScope(), failure(); 343 344 popDeclScope(); 345 return ast::Module::create(ctx, moduleLoc, decls); 346 } 347 348 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { 349 while (curToken.isNot(Token::eof)) { 350 if (curToken.is(Token::directive)) { 351 if (failed(parseDirective(decls))) 352 return failure(); 353 continue; 354 } 355 356 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 357 if (failed(decl)) 358 return failure(); 359 decls.push_back(*decl); 360 } 361 return success(); 362 } 363 364 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 365 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 366 valueRangeTy); 367 } 368 369 LogicalResult Parser::convertExpressionTo( 370 ast::Expr *&expr, ast::Type type, 371 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 372 ast::Type exprType = expr->getType(); 373 if (exprType == type) 374 return success(); 375 376 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 377 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 378 expr->getLoc(), llvm::formatv("unable to convert expression of type " 379 "`{0}` to the expected type of " 380 "`{1}`", 381 exprType, type)); 382 if (noteAttachFn) 383 noteAttachFn(*diag); 384 return diag; 385 }; 386 387 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 388 // Two operation types are compatible if they have the same name, or if the 389 // expected type is more general. 390 if (auto opType = type.dyn_cast<ast::OperationType>()) { 391 if (opType.getName()) 392 return emitConvertError(); 393 return success(); 394 } 395 396 // An operation can always convert to a ValueRange. 397 if (type == valueRangeTy) { 398 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 399 valueRangeTy); 400 return success(); 401 } 402 403 // Allow conversion to a single value by constraining the result range. 404 if (type == valueTy) { 405 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 406 valueTy); 407 return success(); 408 } 409 return emitConvertError(); 410 } 411 412 // FIXME: Decide how to allow/support converting a single result to multiple, 413 // and multiple to a single result. For now, we just allow Single->Range, 414 // but this isn't something really supported in the PDL dialect. We should 415 // figure out some way to support both. 416 if ((exprType == valueTy || exprType == valueRangeTy) && 417 (type == valueTy || type == valueRangeTy)) 418 return success(); 419 if ((exprType == typeTy || exprType == typeRangeTy) && 420 (type == typeTy || type == typeRangeTy)) 421 return success(); 422 423 // Handle tuple types. 424 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 425 auto tupleType = type.dyn_cast<ast::TupleType>(); 426 if (!tupleType || tupleType.size() != exprTupleType.size()) 427 return emitConvertError(); 428 429 // Build a new tuple expression using each of the elements of the current 430 // tuple. 431 SmallVector<ast::Expr *> newExprs; 432 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 433 newExprs.push_back(ast::MemberAccessExpr::create( 434 ctx, expr->getLoc(), expr, llvm::to_string(i), 435 exprTupleType.getElementTypes()[i])); 436 437 auto diagFn = [&](ast::Diagnostic &diag) { 438 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 439 i, exprTupleType)); 440 if (noteAttachFn) 441 noteAttachFn(diag); 442 }; 443 if (failed(convertExpressionTo(newExprs.back(), 444 tupleType.getElementTypes()[i], diagFn))) 445 return failure(); 446 } 447 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 448 tupleType.getElementNames()); 449 return success(); 450 } 451 452 return emitConvertError(); 453 } 454 455 //===----------------------------------------------------------------------===// 456 // Directives 457 458 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { 459 StringRef directive = curToken.getSpelling(); 460 if (directive == "#include") 461 return parseInclude(decls); 462 463 return emitError("unknown directive `" + directive + "`"); 464 } 465 466 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { 467 llvm::SMRange loc = curToken.getLoc(); 468 consumeToken(Token::directive); 469 470 // Parse the file being included. 471 if (!curToken.isString()) 472 return emitError(loc, 473 "expected string file name after `include` directive"); 474 llvm::SMRange fileLoc = curToken.getLoc(); 475 std::string filenameStr = curToken.getStringValue(); 476 StringRef filename = filenameStr; 477 consumeToken(); 478 479 // Check the type of include. If ending with `.pdll`, this is another pdl file 480 // to be parsed along with the current module. 481 if (filename.endswith(".pdll")) { 482 if (failed(lexer.pushInclude(filename))) 483 return emitError(fileLoc, 484 "unable to open include file `" + filename + "`"); 485 486 // If we added the include successfully, parse it into the current module. 487 // Make sure to save the current token so that we can restore it when we 488 // finish parsing the nested file. 489 Token oldToken = curToken; 490 curToken = lexer.lexToken(); 491 LogicalResult result = parseModuleBody(decls); 492 curToken = oldToken; 493 return result; 494 } 495 496 return emitError(fileLoc, "expected include filename to end with `.pdll`"); 497 } 498 499 //===----------------------------------------------------------------------===// 500 // Decls 501 502 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 503 FailureOr<ast::Decl *> decl; 504 switch (curToken.getKind()) { 505 case Token::kw_Pattern: 506 decl = parsePatternDecl(); 507 break; 508 default: 509 return emitError("expected top-level declaration, such as a `Pattern`"); 510 } 511 if (failed(decl)) 512 return failure(); 513 514 // If the decl has a name, add it to the current scope. 515 if (const ast::Name *name = (*decl)->getName()) { 516 if (failed(checkDefineNamedDecl(*name))) 517 return failure(); 518 curDeclScope->add(*decl); 519 } 520 return decl; 521 } 522 523 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 524 std::string attrNameStr; 525 if (curToken.isString()) 526 attrNameStr = curToken.getStringValue(); 527 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 528 attrNameStr = curToken.getSpelling().str(); 529 else 530 return emitError("expected identifier or string attribute name"); 531 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 532 consumeToken(); 533 534 // Check for a value of the attribute. 535 ast::Expr *attrValue = nullptr; 536 if (consumeIf(Token::equal)) { 537 FailureOr<ast::Expr *> attrExpr = parseExpr(); 538 if (failed(attrExpr)) 539 return failure(); 540 attrValue = *attrExpr; 541 } else { 542 // If there isn't a concrete value, create an expression representing a 543 // UnitAttr. 544 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 545 } 546 547 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 548 } 549 550 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 551 llvm::SMRange loc = curToken.getLoc(); 552 consumeToken(Token::kw_Pattern); 553 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 554 ParserContext::PatternMatch); 555 556 // Check for an optional identifier for the pattern name. 557 const ast::Name *name = nullptr; 558 if (curToken.is(Token::identifier)) { 559 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 560 consumeToken(Token::identifier); 561 } 562 563 // Parse any pattern metadata. 564 ParsedPatternMetadata metadata; 565 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 566 return failure(); 567 568 // Parse the pattern body. 569 ast::CompoundStmt *body; 570 571 if (curToken.isNot(Token::l_brace)) 572 return emitError("expected `{` to start pattern body"); 573 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 574 if (failed(bodyResult)) 575 return failure(); 576 body = *bodyResult; 577 578 // Verify the body of the pattern. 579 auto bodyIt = body->begin(), bodyE = body->end(); 580 for (; bodyIt != bodyE; ++bodyIt) { 581 // Break when we've found the rewrite statement. 582 if (isa<ast::OpRewriteStmt>(*bodyIt)) 583 break; 584 } 585 if (bodyIt == bodyE) { 586 return emitError(loc, 587 "expected Pattern body to terminate with an operation " 588 "rewrite statement, such as `erase`"); 589 } 590 if (std::next(bodyIt) != bodyE) { 591 return emitError((*std::next(bodyIt))->getLoc(), 592 "Pattern body was terminated by an operation " 593 "rewrite statement, but found trailing statements"); 594 } 595 596 return createPatternDecl(loc, name, metadata, body); 597 } 598 599 LogicalResult 600 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 601 Optional<llvm::SMRange> benefitLoc; 602 Optional<llvm::SMRange> hasBoundedRecursionLoc; 603 604 do { 605 if (curToken.isNot(Token::identifier)) 606 return emitError("expected pattern metadata identifier"); 607 StringRef metadataStr = curToken.getSpelling(); 608 llvm::SMRange metadataLoc = curToken.getLoc(); 609 consumeToken(Token::identifier); 610 611 // Parse the benefit metadata: benefit(<integer-value>) 612 if (metadataStr == "benefit") { 613 if (benefitLoc) { 614 return emitErrorAndNote(metadataLoc, 615 "pattern benefit has already been specified", 616 *benefitLoc, "see previous definition here"); 617 } 618 if (failed(parseToken(Token::l_paren, 619 "expected `(` before pattern benefit"))) 620 return failure(); 621 622 uint16_t benefitValue = 0; 623 if (curToken.isNot(Token::integer)) 624 return emitError("expected integral pattern benefit"); 625 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 626 return emitError( 627 "expected pattern benefit to fit within a 16-bit integer"); 628 consumeToken(Token::integer); 629 630 metadata.benefit = benefitValue; 631 benefitLoc = metadataLoc; 632 633 if (failed( 634 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 635 return failure(); 636 continue; 637 } 638 639 // Parse the bounded recursion metadata: recursion 640 if (metadataStr == "recursion") { 641 if (hasBoundedRecursionLoc) { 642 return emitErrorAndNote( 643 metadataLoc, 644 "pattern recursion metadata has already been specified", 645 *hasBoundedRecursionLoc, "see previous definition here"); 646 } 647 metadata.hasBoundedRecursion = true; 648 hasBoundedRecursionLoc = metadataLoc; 649 continue; 650 } 651 652 return emitError(metadataLoc, "unknown pattern metadata"); 653 } while (consumeIf(Token::comma)); 654 655 return success(); 656 } 657 658 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 659 consumeToken(Token::less); 660 661 FailureOr<ast::Expr *> typeExpr = parseExpr(); 662 if (failed(typeExpr) || 663 failed(parseToken(Token::greater, 664 "expected `>` after variable type constraint"))) 665 return failure(); 666 return typeExpr; 667 } 668 669 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 670 assert(curDeclScope && "defining decl outside of a decl scope"); 671 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 672 return emitErrorAndNote( 673 name.getLoc(), "`" + name.getName() + "` has already been defined", 674 lastDecl->getName()->getLoc(), "see previous definition here"); 675 } 676 return success(); 677 } 678 679 FailureOr<ast::VariableDecl *> 680 Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, 681 ast::Type type, ast::Expr *initExpr, 682 ArrayRef<ast::ConstraintRef> constraints) { 683 assert(curDeclScope && "defining variable outside of decl scope"); 684 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 685 686 // If the name of the variable indicates a special variable, we don't add it 687 // to the scope. This variable is local to the definition point. 688 if (name.empty() || name == "_") { 689 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 690 constraints); 691 } 692 if (failed(checkDefineNamedDecl(nameDecl))) 693 return failure(); 694 695 auto *varDecl = 696 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 697 curDeclScope->add(varDecl); 698 return varDecl; 699 } 700 701 FailureOr<ast::VariableDecl *> 702 Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, 703 ast::Type type, 704 ArrayRef<ast::ConstraintRef> constraints) { 705 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 706 constraints); 707 } 708 709 LogicalResult Parser::parseVariableDeclConstraintList( 710 SmallVectorImpl<ast::ConstraintRef> &constraints) { 711 Optional<llvm::SMRange> typeConstraint; 712 auto parseSingleConstraint = [&] { 713 FailureOr<ast::ConstraintRef> constraint = 714 parseConstraint(typeConstraint, constraints); 715 if (failed(constraint)) 716 return failure(); 717 constraints.push_back(*constraint); 718 return success(); 719 }; 720 721 // Check to see if this is a single constraint, or a list. 722 if (!consumeIf(Token::l_square)) 723 return parseSingleConstraint(); 724 725 do { 726 if (failed(parseSingleConstraint())) 727 return failure(); 728 } while (consumeIf(Token::comma)); 729 return parseToken(Token::r_square, "expected `]` after constraint list"); 730 } 731 732 FailureOr<ast::ConstraintRef> 733 Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint, 734 ArrayRef<ast::ConstraintRef> existingConstraints) { 735 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 736 if (typeConstraint) 737 return emitErrorAndNote( 738 curToken.getLoc(), 739 "the type of this variable has already been constrained", 740 *typeConstraint, "see previous constraint location here"); 741 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 742 if (failed(constraintExpr)) 743 return failure(); 744 typeExpr = *constraintExpr; 745 typeConstraint = typeExpr->getLoc(); 746 return success(); 747 }; 748 749 llvm::SMRange loc = curToken.getLoc(); 750 switch (curToken.getKind()) { 751 case Token::kw_Attr: { 752 consumeToken(Token::kw_Attr); 753 754 // Check for a type constraint. 755 ast::Expr *typeExpr = nullptr; 756 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 757 return failure(); 758 return ast::ConstraintRef( 759 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 760 } 761 case Token::kw_Op: { 762 consumeToken(Token::kw_Op); 763 764 // Parse an optional operation name. If the name isn't provided, this refers 765 // to "any" operation. 766 FailureOr<ast::OpNameDecl *> opName = 767 parseWrappedOperationName(/*allowEmptyName=*/true); 768 if (failed(opName)) 769 return failure(); 770 771 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 772 loc); 773 } 774 case Token::kw_Type: 775 consumeToken(Token::kw_Type); 776 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 777 case Token::kw_TypeRange: 778 consumeToken(Token::kw_TypeRange); 779 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 780 loc); 781 case Token::kw_Value: { 782 consumeToken(Token::kw_Value); 783 784 // Check for a type constraint. 785 ast::Expr *typeExpr = nullptr; 786 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 787 return failure(); 788 789 return ast::ConstraintRef( 790 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 791 } 792 case Token::kw_ValueRange: { 793 consumeToken(Token::kw_ValueRange); 794 795 // Check for a type constraint. 796 ast::Expr *typeExpr = nullptr; 797 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 798 return failure(); 799 800 return ast::ConstraintRef( 801 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 802 } 803 case Token::identifier: { 804 StringRef constraintName = curToken.getSpelling(); 805 consumeToken(Token::identifier); 806 807 // Lookup the referenced constraint. 808 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 809 if (!cstDecl) { 810 return emitError(loc, "unknown reference to constraint `" + 811 constraintName + "`"); 812 } 813 814 // Handle a reference to a proper constraint. 815 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 816 return ast::ConstraintRef(cst, loc); 817 818 return emitErrorAndNote( 819 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 820 "see the definition of `" + constraintName + "` here"); 821 } 822 default: 823 break; 824 } 825 return emitError(loc, "expected identifier constraint"); 826 } 827 828 //===----------------------------------------------------------------------===// 829 // Exprs 830 831 FailureOr<ast::Expr *> Parser::parseExpr() { 832 if (curToken.is(Token::underscore)) 833 return parseUnderscoreExpr(); 834 835 // Parse the LHS expression. 836 FailureOr<ast::Expr *> lhsExpr; 837 switch (curToken.getKind()) { 838 case Token::kw_attr: 839 lhsExpr = parseAttributeExpr(); 840 break; 841 case Token::identifier: 842 lhsExpr = parseIdentifierExpr(); 843 break; 844 case Token::kw_op: 845 lhsExpr = parseOperationExpr(); 846 break; 847 case Token::kw_type: 848 lhsExpr = parseTypeExpr(); 849 break; 850 case Token::l_paren: 851 lhsExpr = parseTupleExpr(); 852 break; 853 default: 854 return emitError("expected expression"); 855 } 856 if (failed(lhsExpr)) 857 return failure(); 858 859 // Check for an operator expression. 860 while (true) { 861 switch (curToken.getKind()) { 862 case Token::dot: 863 lhsExpr = parseMemberAccessExpr(*lhsExpr); 864 break; 865 default: 866 return lhsExpr; 867 } 868 if (failed(lhsExpr)) 869 return failure(); 870 } 871 } 872 873 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 874 llvm::SMRange loc = curToken.getLoc(); 875 consumeToken(Token::kw_attr); 876 877 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 878 // identifier. 879 if (!consumeIf(Token::less)) { 880 resetToken(loc); 881 return parseIdentifierExpr(); 882 } 883 884 if (!curToken.isString()) 885 return emitError("expected string literal containing MLIR attribute"); 886 std::string attrExpr = curToken.getStringValue(); 887 consumeToken(); 888 889 if (failed( 890 parseToken(Token::greater, "expected `>` after attribute literal"))) 891 return failure(); 892 return ast::AttributeExpr::create(ctx, loc, attrExpr); 893 } 894 895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, 896 llvm::SMRange loc) { 897 ast::Decl *decl = curDeclScope->lookup(name); 898 if (!decl) 899 return emitError(loc, "undefined reference to `" + name + "`"); 900 901 return createDeclRefExpr(loc, decl); 902 } 903 904 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 905 StringRef name = curToken.getSpelling(); 906 llvm::SMRange nameLoc = curToken.getLoc(); 907 consumeToken(); 908 909 // Check to see if this is a decl ref expression that defines a variable 910 // inline. 911 if (consumeIf(Token::colon)) { 912 SmallVector<ast::ConstraintRef> constraints; 913 if (failed(parseVariableDeclConstraintList(constraints))) 914 return failure(); 915 ast::Type type; 916 if (failed(validateVariableConstraints(constraints, type))) 917 return failure(); 918 return createInlineVariableExpr(type, name, nameLoc, constraints); 919 } 920 921 return parseDeclRefExpr(name, nameLoc); 922 } 923 924 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 925 llvm::SMRange loc = curToken.getLoc(); 926 consumeToken(Token::dot); 927 928 // Parse the member name. 929 Token memberNameTok = curToken; 930 if (memberNameTok.isNot(Token::identifier, Token::integer) && 931 !memberNameTok.isKeyword()) 932 return emitError(loc, "expected identifier or numeric member name"); 933 StringRef memberName = memberNameTok.getSpelling(); 934 consumeToken(); 935 936 return createMemberAccessExpr(parentExpr, memberName, loc); 937 } 938 939 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 940 llvm::SMRange loc = curToken.getLoc(); 941 942 // Handle the case of an no operation name. 943 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 944 if (allowEmptyName) 945 return ast::OpNameDecl::create(ctx, llvm::SMRange()); 946 return emitError("expected dialect namespace"); 947 } 948 StringRef name = curToken.getSpelling(); 949 consumeToken(); 950 951 // Otherwise, this is a literal operation name. 952 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 953 return failure(); 954 955 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 956 return emitError("expected operation name after dialect namespace"); 957 958 name = StringRef(name.data(), name.size() + 1); 959 do { 960 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 961 loc.End = curToken.getEndLoc(); 962 consumeToken(); 963 } while (curToken.isAny(Token::identifier, Token::dot) || 964 curToken.isKeyword()); 965 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 966 } 967 968 FailureOr<ast::OpNameDecl *> 969 Parser::parseWrappedOperationName(bool allowEmptyName) { 970 if (!consumeIf(Token::less)) 971 return ast::OpNameDecl::create(ctx, llvm::SMRange()); 972 973 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 974 if (failed(opNameDecl)) 975 return failure(); 976 977 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 978 return failure(); 979 return opNameDecl; 980 } 981 982 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 983 llvm::SMRange loc = curToken.getLoc(); 984 consumeToken(Token::kw_op); 985 986 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 987 // identifier. 988 if (curToken.isNot(Token::less)) { 989 resetToken(loc); 990 return parseIdentifierExpr(); 991 } 992 993 // Parse the operation name. The name may be elided, in which case the 994 // operation refers to "any" operation(i.e. a difference between `MyOp` and 995 // `Operation*`). Operation names within a rewrite context must be named. 996 bool allowEmptyName = parserContext != ParserContext::Rewrite; 997 FailureOr<ast::OpNameDecl *> opNameDecl = 998 parseWrappedOperationName(allowEmptyName); 999 if (failed(opNameDecl)) 1000 return failure(); 1001 1002 // Check for the optional list of operands. 1003 SmallVector<ast::Expr *> operands; 1004 if (consumeIf(Token::l_paren)) { 1005 do { 1006 FailureOr<ast::Expr *> operand = parseExpr(); 1007 if (failed(operand)) 1008 return failure(); 1009 operands.push_back(*operand); 1010 } while (consumeIf(Token::comma)); 1011 1012 if (failed(parseToken(Token::r_paren, 1013 "expected `)` after operation operand list"))) 1014 return failure(); 1015 } 1016 1017 // Check for the optional list of attributes. 1018 SmallVector<ast::NamedAttributeDecl *> attributes; 1019 if (consumeIf(Token::l_brace)) { 1020 do { 1021 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1022 if (failed(decl)) 1023 return failure(); 1024 attributes.emplace_back(*decl); 1025 } while (consumeIf(Token::comma)); 1026 1027 if (failed(parseToken(Token::r_brace, 1028 "expected `}` after operation attribute list"))) 1029 return failure(); 1030 } 1031 1032 // Check for the optional list of result types. 1033 SmallVector<ast::Expr *> resultTypes; 1034 if (consumeIf(Token::arrow)) { 1035 if (failed(parseToken(Token::l_paren, 1036 "expected `(` before operation result type list"))) 1037 return failure(); 1038 1039 do { 1040 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1041 if (failed(resultTypeExpr)) 1042 return failure(); 1043 resultTypes.push_back(*resultTypeExpr); 1044 } while (consumeIf(Token::comma)); 1045 1046 if (failed(parseToken(Token::r_paren, 1047 "expected `)` after operation result type list"))) 1048 return failure(); 1049 } 1050 1051 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1052 resultTypes); 1053 } 1054 1055 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1056 llvm::SMRange loc = curToken.getLoc(); 1057 consumeToken(Token::l_paren); 1058 1059 DenseMap<StringRef, llvm::SMRange> usedNames; 1060 SmallVector<StringRef> elementNames; 1061 SmallVector<ast::Expr *> elements; 1062 if (curToken.isNot(Token::r_paren)) { 1063 do { 1064 // Check for the optional element name assignment before the value. 1065 StringRef elementName; 1066 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1067 Token elementNameTok = curToken; 1068 consumeToken(); 1069 1070 // The element name is only present if followed by an `=`. 1071 if (consumeIf(Token::equal)) { 1072 elementName = elementNameTok.getSpelling(); 1073 1074 // Check to see if this name is already used. 1075 auto elementNameIt = 1076 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1077 if (!elementNameIt.second) { 1078 return emitErrorAndNote( 1079 elementNameTok.getLoc(), 1080 llvm::formatv("duplicate tuple element label `{0}`", 1081 elementName), 1082 elementNameIt.first->getSecond(), 1083 "see previous label use here"); 1084 } 1085 } else { 1086 // Otherwise, we treat this as part of an expression so reset the 1087 // lexer. 1088 resetToken(elementNameTok.getLoc()); 1089 } 1090 } 1091 elementNames.push_back(elementName); 1092 1093 // Parse the tuple element value. 1094 FailureOr<ast::Expr *> element = parseExpr(); 1095 if (failed(element)) 1096 return failure(); 1097 elements.push_back(*element); 1098 } while (consumeIf(Token::comma)); 1099 } 1100 loc.End = curToken.getEndLoc(); 1101 if (failed( 1102 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1103 return failure(); 1104 return createTupleExpr(loc, elements, elementNames); 1105 } 1106 1107 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1108 llvm::SMRange loc = curToken.getLoc(); 1109 consumeToken(Token::kw_type); 1110 1111 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1112 // identifier. 1113 if (!consumeIf(Token::less)) { 1114 resetToken(loc); 1115 return parseIdentifierExpr(); 1116 } 1117 1118 if (!curToken.isString()) 1119 return emitError("expected string literal containing MLIR type"); 1120 std::string attrExpr = curToken.getStringValue(); 1121 consumeToken(); 1122 1123 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1124 return failure(); 1125 return ast::TypeExpr::create(ctx, loc, attrExpr); 1126 } 1127 1128 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 1129 StringRef name = curToken.getSpelling(); 1130 llvm::SMRange nameLoc = curToken.getLoc(); 1131 consumeToken(Token::underscore); 1132 1133 // Underscore expressions require a constraint list. 1134 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 1135 return failure(); 1136 1137 // Parse the constraints for the expression. 1138 SmallVector<ast::ConstraintRef> constraints; 1139 if (failed(parseVariableDeclConstraintList(constraints))) 1140 return failure(); 1141 1142 ast::Type type; 1143 if (failed(validateVariableConstraints(constraints, type))) 1144 return failure(); 1145 return createInlineVariableExpr(type, name, nameLoc, constraints); 1146 } 1147 1148 //===----------------------------------------------------------------------===// 1149 // Stmts 1150 1151 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 1152 FailureOr<ast::Stmt *> stmt; 1153 switch (curToken.getKind()) { 1154 case Token::kw_erase: 1155 stmt = parseEraseStmt(); 1156 break; 1157 case Token::kw_let: 1158 stmt = parseLetStmt(); 1159 break; 1160 case Token::kw_replace: 1161 stmt = parseReplaceStmt(); 1162 break; 1163 case Token::kw_rewrite: 1164 stmt = parseRewriteStmt(); 1165 break; 1166 default: 1167 stmt = parseExpr(); 1168 break; 1169 } 1170 if (failed(stmt) || 1171 (expectTerminalSemicolon && 1172 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 1173 return failure(); 1174 return stmt; 1175 } 1176 1177 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 1178 llvm::SMLoc startLoc = curToken.getStartLoc(); 1179 consumeToken(Token::l_brace); 1180 1181 // Push a new block scope and parse any nested statements. 1182 pushDeclScope(); 1183 SmallVector<ast::Stmt *> statements; 1184 while (curToken.isNot(Token::r_brace)) { 1185 FailureOr<ast::Stmt *> statement = parseStmt(); 1186 if (failed(statement)) 1187 return popDeclScope(), failure(); 1188 statements.push_back(*statement); 1189 } 1190 popDeclScope(); 1191 1192 // Consume the end brace. 1193 llvm::SMRange location(startLoc, curToken.getEndLoc()); 1194 consumeToken(Token::r_brace); 1195 1196 return ast::CompoundStmt::create(ctx, location, statements); 1197 } 1198 1199 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 1200 llvm::SMRange loc = curToken.getLoc(); 1201 consumeToken(Token::kw_erase); 1202 1203 // Parse the root operation expression. 1204 FailureOr<ast::Expr *> rootOp = parseExpr(); 1205 if (failed(rootOp)) 1206 return failure(); 1207 1208 return createEraseStmt(loc, *rootOp); 1209 } 1210 1211 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 1212 llvm::SMRange loc = curToken.getLoc(); 1213 consumeToken(Token::kw_let); 1214 1215 // Parse the name of the new variable. 1216 llvm::SMRange varLoc = curToken.getLoc(); 1217 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 1218 // `_` is a reserved variable name. 1219 if (curToken.is(Token::underscore)) { 1220 return emitError(varLoc, 1221 "`_` may only be used to define \"inline\" variables"); 1222 } 1223 return emitError(varLoc, 1224 "expected identifier after `let` to name a new variable"); 1225 } 1226 StringRef varName = curToken.getSpelling(); 1227 consumeToken(); 1228 1229 // Parse the optional set of constraints. 1230 SmallVector<ast::ConstraintRef> constraints; 1231 if (consumeIf(Token::colon) && 1232 failed(parseVariableDeclConstraintList(constraints))) 1233 return failure(); 1234 1235 // Parse the optional initializer expression. 1236 ast::Expr *initializer = nullptr; 1237 if (consumeIf(Token::equal)) { 1238 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 1239 if (failed(initOrFailure)) 1240 return failure(); 1241 initializer = *initOrFailure; 1242 1243 // Check that the constraints are compatible with having an initializer, 1244 // e.g. type constraints cannot be used with initializers. 1245 for (ast::ConstraintRef constraint : constraints) { 1246 LogicalResult result = 1247 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 1248 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 1249 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 1250 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 1251 return this->emitError( 1252 constraint.referenceLoc, 1253 "type constraints are not permitted on variables with " 1254 "initializers"); 1255 } 1256 return success(); 1257 }) 1258 .Default(success()); 1259 if (failed(result)) 1260 return failure(); 1261 } 1262 } 1263 1264 FailureOr<ast::VariableDecl *> varDecl = 1265 createVariableDecl(varName, varLoc, initializer, constraints); 1266 if (failed(varDecl)) 1267 return failure(); 1268 return ast::LetStmt::create(ctx, loc, *varDecl); 1269 } 1270 1271 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 1272 llvm::SMRange loc = curToken.getLoc(); 1273 consumeToken(Token::kw_replace); 1274 1275 // Parse the root operation expression. 1276 FailureOr<ast::Expr *> rootOp = parseExpr(); 1277 if (failed(rootOp)) 1278 return failure(); 1279 1280 if (failed( 1281 parseToken(Token::kw_with, "expected `with` after root operation"))) 1282 return failure(); 1283 1284 // The replacement portion of this statement is within a rewrite context. 1285 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1286 ParserContext::Rewrite); 1287 1288 // Parse the replacement values. 1289 SmallVector<ast::Expr *> replValues; 1290 if (consumeIf(Token::l_paren)) { 1291 if (consumeIf(Token::r_paren)) { 1292 return emitError( 1293 loc, "expected at least one replacement value, consider using " 1294 "`erase` if no replacement values are desired"); 1295 } 1296 1297 do { 1298 FailureOr<ast::Expr *> replExpr = parseExpr(); 1299 if (failed(replExpr)) 1300 return failure(); 1301 replValues.emplace_back(*replExpr); 1302 } while (consumeIf(Token::comma)); 1303 1304 if (failed(parseToken(Token::r_paren, 1305 "expected `)` after replacement values"))) 1306 return failure(); 1307 } else { 1308 FailureOr<ast::Expr *> replExpr = parseExpr(); 1309 if (failed(replExpr)) 1310 return failure(); 1311 replValues.emplace_back(*replExpr); 1312 } 1313 1314 return createReplaceStmt(loc, *rootOp, replValues); 1315 } 1316 1317 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 1318 llvm::SMRange loc = curToken.getLoc(); 1319 consumeToken(Token::kw_rewrite); 1320 1321 // Parse the root operation. 1322 FailureOr<ast::Expr *> rootOp = parseExpr(); 1323 if (failed(rootOp)) 1324 return failure(); 1325 1326 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 1327 return failure(); 1328 1329 if (curToken.isNot(Token::l_brace)) 1330 return emitError("expected `{` to start rewrite body"); 1331 1332 // The rewrite body of this statement is within a rewrite context. 1333 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1334 ParserContext::Rewrite); 1335 1336 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 1337 if (failed(rewriteBody)) 1338 return failure(); 1339 1340 return createRewriteStmt(loc, *rootOp, *rewriteBody); 1341 } 1342 1343 //===----------------------------------------------------------------------===// 1344 // Creation+Analysis 1345 //===----------------------------------------------------------------------===// 1346 1347 //===----------------------------------------------------------------------===// 1348 // Decls 1349 1350 FailureOr<ast::PatternDecl *> 1351 Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name, 1352 const ParsedPatternMetadata &metadata, 1353 ast::CompoundStmt *body) { 1354 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 1355 metadata.hasBoundedRecursion, body); 1356 } 1357 1358 FailureOr<ast::VariableDecl *> 1359 Parser::createVariableDecl(StringRef name, llvm::SMRange loc, 1360 ast::Expr *initializer, 1361 ArrayRef<ast::ConstraintRef> constraints) { 1362 // The type of the variable, which is expected to be inferred by either a 1363 // constraint or an initializer expression. 1364 ast::Type type; 1365 if (failed(validateVariableConstraints(constraints, type))) 1366 return failure(); 1367 1368 if (initializer) { 1369 // Update the variable type based on the initializer, or try to convert the 1370 // initializer to the existing type. 1371 if (!type) 1372 type = initializer->getType(); 1373 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 1374 type = mergedType; 1375 else if (failed(convertExpressionTo(initializer, type))) 1376 return failure(); 1377 1378 // Otherwise, if there is no initializer check that the type has already 1379 // been resolved from the constraint list. 1380 } else if (!type) { 1381 return emitErrorAndNote( 1382 loc, "unable to infer type for variable `" + name + "`", loc, 1383 "the type of a variable must be inferable from the constraint " 1384 "list or the initializer"); 1385 } 1386 1387 // Try to define a variable with the given name. 1388 FailureOr<ast::VariableDecl *> varDecl = 1389 defineVariableDecl(name, loc, type, initializer, constraints); 1390 if (failed(varDecl)) 1391 return failure(); 1392 1393 return *varDecl; 1394 } 1395 1396 LogicalResult 1397 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 1398 ast::Type &inferredType) { 1399 for (const ast::ConstraintRef &ref : constraints) 1400 if (failed(validateVariableConstraint(ref, inferredType))) 1401 return failure(); 1402 return success(); 1403 } 1404 1405 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 1406 ast::Type &inferredType) { 1407 ast::Type constraintType; 1408 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 1409 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1410 if (failed(validateTypeConstraintExpr(typeExpr))) 1411 return failure(); 1412 } 1413 constraintType = ast::AttributeType::get(ctx); 1414 } else if (const auto *cst = 1415 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 1416 constraintType = ast::OperationType::get(ctx, cst->getName()); 1417 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 1418 constraintType = typeTy; 1419 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 1420 constraintType = typeRangeTy; 1421 } else if (const auto *cst = 1422 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 1423 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1424 if (failed(validateTypeConstraintExpr(typeExpr))) 1425 return failure(); 1426 } 1427 constraintType = valueTy; 1428 } else if (const auto *cst = 1429 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 1430 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1431 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 1432 return failure(); 1433 } 1434 constraintType = valueRangeTy; 1435 } else { 1436 llvm_unreachable("unknown constraint type"); 1437 } 1438 1439 // Check that the constraint type is compatible with the current inferred 1440 // type. 1441 if (!inferredType) { 1442 inferredType = constraintType; 1443 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 1444 inferredType = mergedTy; 1445 } else { 1446 return emitError(ref.referenceLoc, 1447 llvm::formatv("constraint type `{0}` is incompatible " 1448 "with the previously inferred type `{1}`", 1449 constraintType, inferredType)); 1450 } 1451 return success(); 1452 } 1453 1454 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 1455 ast::Type typeExprType = typeExpr->getType(); 1456 if (typeExprType != typeTy) { 1457 return emitError(typeExpr->getLoc(), 1458 "expected expression of `Type` in type constraint"); 1459 } 1460 return success(); 1461 } 1462 1463 LogicalResult 1464 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 1465 ast::Type typeExprType = typeExpr->getType(); 1466 if (typeExprType != typeRangeTy) { 1467 return emitError(typeExpr->getLoc(), 1468 "expected expression of `TypeRange` in type constraint"); 1469 } 1470 return success(); 1471 } 1472 1473 //===----------------------------------------------------------------------===// 1474 // Exprs 1475 1476 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc, 1477 ast::Decl *decl) { 1478 // Check the type of decl being referenced. 1479 ast::Type declType; 1480 if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 1481 declType = varDecl->getType(); 1482 else 1483 return emitError(loc, "invalid reference to `" + 1484 decl->getName()->getName() + "`"); 1485 1486 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 1487 } 1488 1489 FailureOr<ast::DeclRefExpr *> 1490 Parser::createInlineVariableExpr(ast::Type type, StringRef name, 1491 llvm::SMRange loc, 1492 ArrayRef<ast::ConstraintRef> constraints) { 1493 FailureOr<ast::VariableDecl *> decl = 1494 defineVariableDecl(name, loc, type, constraints); 1495 if (failed(decl)) 1496 return failure(); 1497 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 1498 } 1499 1500 FailureOr<ast::MemberAccessExpr *> 1501 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 1502 llvm::SMRange loc) { 1503 // Validate the member name for the given parent expression. 1504 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 1505 if (failed(memberType)) 1506 return failure(); 1507 1508 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 1509 } 1510 1511 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 1512 StringRef name, 1513 llvm::SMRange loc) { 1514 ast::Type parentType = parentExpr->getType(); 1515 if (parentType.isa<ast::OperationType>()) { 1516 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 1517 return valueRangeTy; 1518 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 1519 // Handle indexed results. 1520 unsigned index = 0; 1521 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 1522 index < tupleType.size()) { 1523 return tupleType.getElementTypes()[index]; 1524 } 1525 1526 // Handle named results. 1527 auto elementNames = tupleType.getElementNames(); 1528 const auto *it = llvm::find(elementNames, name); 1529 if (it != elementNames.end()) 1530 return tupleType.getElementTypes()[it - elementNames.begin()]; 1531 } 1532 return emitError( 1533 loc, 1534 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 1535 name, parentType)); 1536 } 1537 1538 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 1539 llvm::SMRange loc, const ast::OpNameDecl *name, 1540 MutableArrayRef<ast::Expr *> operands, 1541 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 1542 MutableArrayRef<ast::Expr *> results) { 1543 Optional<StringRef> opNameRef = name->getName(); 1544 1545 // Verify the inputs operands. 1546 if (failed(validateOperationOperands(loc, opNameRef, operands))) 1547 return failure(); 1548 1549 // Verify the attribute list. 1550 for (ast::NamedAttributeDecl *attr : attributes) { 1551 // Check for an attribute type, or a type awaiting resolution. 1552 ast::Type attrType = attr->getValue()->getType(); 1553 if (!attrType.isa<ast::AttributeType>()) { 1554 return emitError( 1555 attr->getValue()->getLoc(), 1556 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 1557 } 1558 } 1559 1560 // Verify the result types. 1561 if (failed(validateOperationResults(loc, opNameRef, results))) 1562 return failure(); 1563 1564 return ast::OperationExpr::create(ctx, loc, name, operands, results, 1565 attributes); 1566 } 1567 1568 LogicalResult 1569 Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name, 1570 MutableArrayRef<ast::Expr *> operands) { 1571 return validateOperationOperandsOrResults(loc, name, operands, valueTy, 1572 valueRangeTy); 1573 } 1574 1575 LogicalResult 1576 Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name, 1577 MutableArrayRef<ast::Expr *> results) { 1578 return validateOperationOperandsOrResults(loc, name, results, typeTy, 1579 typeRangeTy); 1580 } 1581 1582 LogicalResult Parser::validateOperationOperandsOrResults( 1583 llvm::SMRange loc, Optional<StringRef> name, 1584 MutableArrayRef<ast::Expr *> values, ast::Type singleTy, 1585 ast::Type rangeTy) { 1586 // All operation types accept a single range parameter. 1587 if (values.size() == 1) { 1588 if (failed(convertExpressionTo(values[0], rangeTy))) 1589 return failure(); 1590 return success(); 1591 } 1592 1593 // Otherwise, accept the value groups as they have been defined and just 1594 // ensure they are one of the expected types. 1595 for (ast::Expr *&valueExpr : values) { 1596 ast::Type valueExprType = valueExpr->getType(); 1597 1598 // Check if this is one of the expected types. 1599 if (valueExprType == rangeTy || valueExprType == singleTy) 1600 continue; 1601 1602 // If the operand is an Operation, allow converting to a Value or 1603 // ValueRange. This situations arises quite often with nested operation 1604 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 1605 if (singleTy == valueTy) { 1606 if (valueExprType.isa<ast::OperationType>()) { 1607 valueExpr = convertOpToValue(valueExpr); 1608 continue; 1609 } 1610 } 1611 1612 return emitError( 1613 valueExpr->getLoc(), 1614 llvm::formatv( 1615 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 1616 singleTy, rangeTy, valueExprType)); 1617 } 1618 return success(); 1619 } 1620 1621 FailureOr<ast::TupleExpr *> 1622 Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements, 1623 ArrayRef<StringRef> elementNames) { 1624 for (const ast::Expr *element : elements) { 1625 ast::Type eleTy = element->getType(); 1626 if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) { 1627 return emitError( 1628 element->getLoc(), 1629 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 1630 } 1631 } 1632 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 1633 } 1634 1635 //===----------------------------------------------------------------------===// 1636 // Stmts 1637 1638 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc, 1639 ast::Expr *rootOp) { 1640 // Check that root is an Operation. 1641 ast::Type rootType = rootOp->getType(); 1642 if (!rootType.isa<ast::OperationType>()) 1643 return emitError(rootOp->getLoc(), "expected `Op` expression"); 1644 1645 return ast::EraseStmt::create(ctx, loc, rootOp); 1646 } 1647 1648 FailureOr<ast::ReplaceStmt *> 1649 Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp, 1650 MutableArrayRef<ast::Expr *> replValues) { 1651 // Check that root is an Operation. 1652 ast::Type rootType = rootOp->getType(); 1653 if (!rootType.isa<ast::OperationType>()) { 1654 return emitError( 1655 rootOp->getLoc(), 1656 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 1657 } 1658 1659 // If there are multiple replacement values, we implicitly convert any Op 1660 // expressions to the value form. 1661 bool shouldConvertOpToValues = replValues.size() > 1; 1662 for (ast::Expr *&replExpr : replValues) { 1663 ast::Type replType = replExpr->getType(); 1664 1665 // Check that replExpr is an Operation, Value, or ValueRange. 1666 if (replType.isa<ast::OperationType>()) { 1667 if (shouldConvertOpToValues) 1668 replExpr = convertOpToValue(replExpr); 1669 continue; 1670 } 1671 1672 if (replType != valueTy && replType != valueRangeTy) { 1673 return emitError(replExpr->getLoc(), 1674 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 1675 "expression, but got `{0}`", 1676 replType)); 1677 } 1678 } 1679 1680 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 1681 } 1682 1683 FailureOr<ast::RewriteStmt *> 1684 Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, 1685 ast::CompoundStmt *rewriteBody) { 1686 // Check that root is an Operation. 1687 ast::Type rootType = rootOp->getType(); 1688 if (!rootType.isa<ast::OperationType>()) { 1689 return emitError( 1690 rootOp->getLoc(), 1691 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 1692 } 1693 1694 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 1695 } 1696 1697 //===----------------------------------------------------------------------===// 1698 // Parser 1699 //===----------------------------------------------------------------------===// 1700 1701 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 1702 llvm::SourceMgr &sourceMgr) { 1703 Parser parser(ctx, sourceMgr); 1704 return parser.parseModule(); 1705 } 1706