1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/Tools/PDLL/AST/Context.h" 13 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 14 #include "mlir/Tools/PDLL/AST/Nodes.h" 15 #include "mlir/Tools/PDLL/AST/Types.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/SaveAndRestore.h" 20 #include "llvm/Support/ScopedPrinter.h" 21 #include <string> 22 23 using namespace mlir; 24 using namespace mlir::pdll; 25 26 //===----------------------------------------------------------------------===// 27 // Parser 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 class Parser { 32 public: 33 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 34 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 35 curToken(lexer.lexToken()), curDeclScope(nullptr), 36 valueTy(ast::ValueType::get(ctx)), 37 valueRangeTy(ast::ValueRangeType::get(ctx)), 38 typeTy(ast::TypeType::get(ctx)), 39 typeRangeTy(ast::TypeRangeType::get(ctx)) {} 40 41 /// Try to parse a new module. Returns nullptr in the case of failure. 42 FailureOr<ast::Module *> parseModule(); 43 44 private: 45 /// The current context of the parser. It allows for the parser to know a bit 46 /// about the construct it is nested within during parsing. This is used 47 /// specifically to provide additional verification during parsing, e.g. to 48 /// prevent using rewrites within a match context, matcher constraints within 49 /// a rewrite section, etc. 50 enum class ParserContext { 51 /// The parser is in the global context. 52 Global, 53 /// The parser is currently within the matcher portion of a Pattern, which 54 /// is allows a terminal operation rewrite statement but no other rewrite 55 /// transformations. 56 PatternMatch, 57 /// The parser is currently within a Rewrite, which disallows calls to 58 /// constraints, requires operation expressions to have names, etc. 59 Rewrite, 60 }; 61 62 //===--------------------------------------------------------------------===// 63 // Parsing 64 //===--------------------------------------------------------------------===// 65 66 /// Push a new decl scope onto the lexer. 67 ast::DeclScope *pushDeclScope() { 68 ast::DeclScope *newScope = 69 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 70 return (curDeclScope = newScope); 71 } 72 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 73 74 /// Pop the last decl scope from the lexer. 75 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 76 77 /// Parse the body of an AST module. 78 LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); 79 80 /// Try to convert the given expression to `type`. Returns failure and emits 81 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 82 /// invoked to attach notes to the emitted error diagnostic. On success, 83 /// `expr` is updated to the expression used to convert to `type`. 84 LogicalResult convertExpressionTo( 85 ast::Expr *&expr, ast::Type type, 86 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 87 88 /// Given an operation expression, convert it to a Value or ValueRange 89 /// typed expression. 90 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 91 92 //===--------------------------------------------------------------------===// 93 // Directives 94 95 LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); 96 LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); 97 98 //===--------------------------------------------------------------------===// 99 // Decls 100 101 /// This structure contains the set of pattern metadata that may be parsed. 102 struct ParsedPatternMetadata { 103 Optional<uint16_t> benefit; 104 bool hasBoundedRecursion = false; 105 }; 106 107 FailureOr<ast::Decl *> parseTopLevelDecl(); 108 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 109 FailureOr<ast::CompoundStmt *> 110 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 111 bool expectTerminalSemicolon = true); 112 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 113 FailureOr<ast::Decl *> parsePatternDecl(); 114 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 115 116 /// Check to see if a decl has already been defined with the given name, if 117 /// one has emit and error and return failure. Returns success otherwise. 118 LogicalResult checkDefineNamedDecl(const ast::Name &name); 119 120 /// Try to define a variable decl with the given components, returns the 121 /// variable on success. 122 FailureOr<ast::VariableDecl *> 123 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 124 ast::Expr *initExpr, 125 ArrayRef<ast::ConstraintRef> constraints); 126 FailureOr<ast::VariableDecl *> 127 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 128 ArrayRef<ast::ConstraintRef> constraints); 129 130 /// Parse the constraint reference list for a variable decl. 131 LogicalResult parseVariableDeclConstraintList( 132 SmallVectorImpl<ast::ConstraintRef> &constraints); 133 134 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 135 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 136 137 /// Try to parse a single reference to a constraint. `typeConstraint` is the 138 /// location of a previously parsed type constraint for the entity that will 139 /// be constrained by the parsed constraint. `existingConstraints` are any 140 /// existing constraints that have already been parsed for the same entity 141 /// that will be constrained by this constraint. 142 FailureOr<ast::ConstraintRef> 143 parseConstraint(Optional<SMRange> &typeConstraint, 144 ArrayRef<ast::ConstraintRef> existingConstraints); 145 146 //===--------------------------------------------------------------------===// 147 // Exprs 148 149 FailureOr<ast::Expr *> parseExpr(); 150 151 /// Identifier expressions. 152 FailureOr<ast::Expr *> parseAttributeExpr(); 153 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 154 FailureOr<ast::Expr *> parseIdentifierExpr(); 155 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 156 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 157 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 158 FailureOr<ast::Expr *> parseOperationExpr(); 159 FailureOr<ast::Expr *> parseTupleExpr(); 160 FailureOr<ast::Expr *> parseTypeExpr(); 161 FailureOr<ast::Expr *> parseUnderscoreExpr(); 162 163 //===--------------------------------------------------------------------===// 164 // Stmts 165 166 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 167 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 168 FailureOr<ast::EraseStmt *> parseEraseStmt(); 169 FailureOr<ast::LetStmt *> parseLetStmt(); 170 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 171 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 172 173 //===--------------------------------------------------------------------===// 174 // Creation+Analysis 175 //===--------------------------------------------------------------------===// 176 177 //===--------------------------------------------------------------------===// 178 // Decls 179 180 /// Try to create a pattern decl with the given components, returning the 181 /// Pattern on success. 182 FailureOr<ast::PatternDecl *> 183 createPatternDecl(SMRange loc, const ast::Name *name, 184 const ParsedPatternMetadata &metadata, 185 ast::CompoundStmt *body); 186 187 /// Try to create a variable decl with the given components, returning the 188 /// Variable on success. 189 FailureOr<ast::VariableDecl *> 190 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 191 ArrayRef<ast::ConstraintRef> constraints); 192 193 /// Validate the constraints used to constraint a variable decl. 194 /// `inferredType` is the type of the variable inferred by the constraints 195 /// within the list, and is updated to the most refined type as determined by 196 /// the constraints. Returns success if the constraint list is valid, failure 197 /// otherwise. 198 LogicalResult 199 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 200 ast::Type &inferredType); 201 /// Validate a single reference to a constraint. `inferredType` contains the 202 /// currently inferred variabled type and is refined within the type defined 203 /// by the constraint. Returns success if the constraint is valid, failure 204 /// otherwise. 205 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 206 ast::Type &inferredType); 207 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 208 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 209 210 //===--------------------------------------------------------------------===// 211 // Exprs 212 213 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, 214 ast::Decl *decl); 215 FailureOr<ast::DeclRefExpr *> 216 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 217 ArrayRef<ast::ConstraintRef> constraints); 218 FailureOr<ast::MemberAccessExpr *> 219 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 220 SMRange loc); 221 222 /// Validate the member access `name` into the given parent expression. On 223 /// success, this also returns the type of the member accessed. 224 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 225 StringRef name, SMRange loc); 226 FailureOr<ast::OperationExpr *> 227 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 228 MutableArrayRef<ast::Expr *> operands, 229 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 230 MutableArrayRef<ast::Expr *> results); 231 LogicalResult 232 validateOperationOperands(SMRange loc, Optional<StringRef> name, 233 MutableArrayRef<ast::Expr *> operands); 234 LogicalResult validateOperationResults(SMRange loc, 235 Optional<StringRef> name, 236 MutableArrayRef<ast::Expr *> results); 237 LogicalResult 238 validateOperationOperandsOrResults(SMRange loc, 239 Optional<StringRef> name, 240 MutableArrayRef<ast::Expr *> values, 241 ast::Type singleTy, ast::Type rangeTy); 242 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 243 ArrayRef<ast::Expr *> elements, 244 ArrayRef<StringRef> elementNames); 245 246 //===--------------------------------------------------------------------===// 247 // Stmts 248 249 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, 250 ast::Expr *rootOp); 251 FailureOr<ast::ReplaceStmt *> 252 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 253 MutableArrayRef<ast::Expr *> replValues); 254 FailureOr<ast::RewriteStmt *> 255 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 256 ast::CompoundStmt *rewriteBody); 257 258 //===--------------------------------------------------------------------===// 259 // Lexer Utilities 260 //===--------------------------------------------------------------------===// 261 262 /// If the current token has the specified kind, consume it and return true. 263 /// If not, return false. 264 bool consumeIf(Token::Kind kind) { 265 if (curToken.isNot(kind)) 266 return false; 267 consumeToken(kind); 268 return true; 269 } 270 271 /// Advance the current lexer onto the next token. 272 void consumeToken() { 273 assert(curToken.isNot(Token::eof, Token::error) && 274 "shouldn't advance past EOF or errors"); 275 curToken = lexer.lexToken(); 276 } 277 278 /// Advance the current lexer onto the next token, asserting what the expected 279 /// current token is. This is preferred to the above method because it leads 280 /// to more self-documenting code with better checking. 281 void consumeToken(Token::Kind kind) { 282 assert(curToken.is(kind) && "consumed an unexpected token"); 283 consumeToken(); 284 } 285 286 /// Reset the lexer to the location at the given position. 287 void resetToken(SMRange tokLoc) { 288 lexer.resetPointer(tokLoc.Start.getPointer()); 289 curToken = lexer.lexToken(); 290 } 291 292 /// Consume the specified token if present and return success. On failure, 293 /// output a diagnostic and return failure. 294 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 295 if (curToken.getKind() != kind) 296 return emitError(curToken.getLoc(), msg); 297 consumeToken(); 298 return success(); 299 } 300 LogicalResult emitError(SMRange loc, const Twine &msg) { 301 lexer.emitError(loc, msg); 302 return failure(); 303 } 304 LogicalResult emitError(const Twine &msg) { 305 return emitError(curToken.getLoc(), msg); 306 } 307 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, 308 SMRange noteLoc, const Twine ¬e) { 309 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 310 return failure(); 311 } 312 313 //===--------------------------------------------------------------------===// 314 // Fields 315 //===--------------------------------------------------------------------===// 316 317 /// The owning AST context. 318 ast::Context &ctx; 319 320 /// The lexer of this parser. 321 Lexer lexer; 322 323 /// The current token within the lexer. 324 Token curToken; 325 326 /// The most recently defined decl scope. 327 ast::DeclScope *curDeclScope; 328 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 329 330 /// The current context of the parser. 331 ParserContext parserContext = ParserContext::Global; 332 333 /// Cached types to simplify verification and expression creation. 334 ast::Type valueTy, valueRangeTy; 335 ast::Type typeTy, typeRangeTy; 336 }; 337 } // namespace 338 339 FailureOr<ast::Module *> Parser::parseModule() { 340 SMLoc moduleLoc = curToken.getStartLoc(); 341 pushDeclScope(); 342 343 // Parse the top-level decls of the module. 344 SmallVector<ast::Decl *> decls; 345 if (failed(parseModuleBody(decls))) 346 return popDeclScope(), failure(); 347 348 popDeclScope(); 349 return ast::Module::create(ctx, moduleLoc, decls); 350 } 351 352 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { 353 while (curToken.isNot(Token::eof)) { 354 if (curToken.is(Token::directive)) { 355 if (failed(parseDirective(decls))) 356 return failure(); 357 continue; 358 } 359 360 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 361 if (failed(decl)) 362 return failure(); 363 decls.push_back(*decl); 364 } 365 return success(); 366 } 367 368 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 369 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 370 valueRangeTy); 371 } 372 373 LogicalResult Parser::convertExpressionTo( 374 ast::Expr *&expr, ast::Type type, 375 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 376 ast::Type exprType = expr->getType(); 377 if (exprType == type) 378 return success(); 379 380 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 381 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 382 expr->getLoc(), llvm::formatv("unable to convert expression of type " 383 "`{0}` to the expected type of " 384 "`{1}`", 385 exprType, type)); 386 if (noteAttachFn) 387 noteAttachFn(*diag); 388 return diag; 389 }; 390 391 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 392 // Two operation types are compatible if they have the same name, or if the 393 // expected type is more general. 394 if (auto opType = type.dyn_cast<ast::OperationType>()) { 395 if (opType.getName()) 396 return emitConvertError(); 397 return success(); 398 } 399 400 // An operation can always convert to a ValueRange. 401 if (type == valueRangeTy) { 402 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 403 valueRangeTy); 404 return success(); 405 } 406 407 // Allow conversion to a single value by constraining the result range. 408 if (type == valueTy) { 409 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 410 valueTy); 411 return success(); 412 } 413 return emitConvertError(); 414 } 415 416 // FIXME: Decide how to allow/support converting a single result to multiple, 417 // and multiple to a single result. For now, we just allow Single->Range, 418 // but this isn't something really supported in the PDL dialect. We should 419 // figure out some way to support both. 420 if ((exprType == valueTy || exprType == valueRangeTy) && 421 (type == valueTy || type == valueRangeTy)) 422 return success(); 423 if ((exprType == typeTy || exprType == typeRangeTy) && 424 (type == typeTy || type == typeRangeTy)) 425 return success(); 426 427 // Handle tuple types. 428 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 429 auto tupleType = type.dyn_cast<ast::TupleType>(); 430 if (!tupleType || tupleType.size() != exprTupleType.size()) 431 return emitConvertError(); 432 433 // Build a new tuple expression using each of the elements of the current 434 // tuple. 435 SmallVector<ast::Expr *> newExprs; 436 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 437 newExprs.push_back(ast::MemberAccessExpr::create( 438 ctx, expr->getLoc(), expr, llvm::to_string(i), 439 exprTupleType.getElementTypes()[i])); 440 441 auto diagFn = [&](ast::Diagnostic &diag) { 442 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 443 i, exprTupleType)); 444 if (noteAttachFn) 445 noteAttachFn(diag); 446 }; 447 if (failed(convertExpressionTo(newExprs.back(), 448 tupleType.getElementTypes()[i], diagFn))) 449 return failure(); 450 } 451 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 452 tupleType.getElementNames()); 453 return success(); 454 } 455 456 return emitConvertError(); 457 } 458 459 //===----------------------------------------------------------------------===// 460 // Directives 461 462 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { 463 StringRef directive = curToken.getSpelling(); 464 if (directive == "#include") 465 return parseInclude(decls); 466 467 return emitError("unknown directive `" + directive + "`"); 468 } 469 470 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { 471 SMRange loc = curToken.getLoc(); 472 consumeToken(Token::directive); 473 474 // Parse the file being included. 475 if (!curToken.isString()) 476 return emitError(loc, 477 "expected string file name after `include` directive"); 478 SMRange fileLoc = curToken.getLoc(); 479 std::string filenameStr = curToken.getStringValue(); 480 StringRef filename = filenameStr; 481 consumeToken(); 482 483 // Check the type of include. If ending with `.pdll`, this is another pdl file 484 // to be parsed along with the current module. 485 if (filename.endswith(".pdll")) { 486 if (failed(lexer.pushInclude(filename))) 487 return emitError(fileLoc, 488 "unable to open include file `" + filename + "`"); 489 490 // If we added the include successfully, parse it into the current module. 491 // Make sure to save the current token so that we can restore it when we 492 // finish parsing the nested file. 493 Token oldToken = curToken; 494 curToken = lexer.lexToken(); 495 LogicalResult result = parseModuleBody(decls); 496 curToken = oldToken; 497 return result; 498 } 499 500 return emitError(fileLoc, "expected include filename to end with `.pdll`"); 501 } 502 503 //===----------------------------------------------------------------------===// 504 // Decls 505 506 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 507 FailureOr<ast::Decl *> decl; 508 switch (curToken.getKind()) { 509 case Token::kw_Pattern: 510 decl = parsePatternDecl(); 511 break; 512 default: 513 return emitError("expected top-level declaration, such as a `Pattern`"); 514 } 515 if (failed(decl)) 516 return failure(); 517 518 // If the decl has a name, add it to the current scope. 519 if (const ast::Name *name = (*decl)->getName()) { 520 if (failed(checkDefineNamedDecl(*name))) 521 return failure(); 522 curDeclScope->add(*decl); 523 } 524 return decl; 525 } 526 527 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 528 std::string attrNameStr; 529 if (curToken.isString()) 530 attrNameStr = curToken.getStringValue(); 531 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 532 attrNameStr = curToken.getSpelling().str(); 533 else 534 return emitError("expected identifier or string attribute name"); 535 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 536 consumeToken(); 537 538 // Check for a value of the attribute. 539 ast::Expr *attrValue = nullptr; 540 if (consumeIf(Token::equal)) { 541 FailureOr<ast::Expr *> attrExpr = parseExpr(); 542 if (failed(attrExpr)) 543 return failure(); 544 attrValue = *attrExpr; 545 } else { 546 // If there isn't a concrete value, create an expression representing a 547 // UnitAttr. 548 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 549 } 550 551 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 552 } 553 554 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 555 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 556 bool expectTerminalSemicolon) { 557 consumeToken(Token::equal_arrow); 558 559 // Parse the single statement of the lambda body. 560 SMLoc bodyStartLoc = curToken.getStartLoc(); 561 pushDeclScope(); 562 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 563 bool failedToParse = 564 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 565 popDeclScope(); 566 if (failedToParse) 567 return failure(); 568 569 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 570 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 571 } 572 573 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 574 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 575 if (isa<ast::OpRewriteStmt>(statement)) 576 return success(); 577 return emitError( 578 statement->getLoc(), 579 "expected Pattern lambda body to contain a single operation " 580 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 581 }); 582 } 583 584 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 585 SMRange loc = curToken.getLoc(); 586 consumeToken(Token::kw_Pattern); 587 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 588 ParserContext::PatternMatch); 589 590 // Check for an optional identifier for the pattern name. 591 const ast::Name *name = nullptr; 592 if (curToken.is(Token::identifier)) { 593 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 594 consumeToken(Token::identifier); 595 } 596 597 // Parse any pattern metadata. 598 ParsedPatternMetadata metadata; 599 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 600 return failure(); 601 602 // Parse the pattern body. 603 ast::CompoundStmt *body; 604 605 // Handle a lambda body. 606 if (curToken.is(Token::equal_arrow)) { 607 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 608 if (failed(bodyResult)) 609 return failure(); 610 body = *bodyResult; 611 } else { 612 if (curToken.isNot(Token::l_brace)) 613 return emitError("expected `{` or `=>` to start pattern body"); 614 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 615 if (failed(bodyResult)) 616 return failure(); 617 body = *bodyResult; 618 619 // Verify the body of the pattern. 620 auto bodyIt = body->begin(), bodyE = body->end(); 621 for (; bodyIt != bodyE; ++bodyIt) { 622 // Break when we've found the rewrite statement. 623 if (isa<ast::OpRewriteStmt>(*bodyIt)) 624 break; 625 } 626 if (bodyIt == bodyE) { 627 return emitError(loc, 628 "expected Pattern body to terminate with an operation " 629 "rewrite statement, such as `erase`"); 630 } 631 if (std::next(bodyIt) != bodyE) { 632 return emitError((*std::next(bodyIt))->getLoc(), 633 "Pattern body was terminated by an operation " 634 "rewrite statement, but found trailing statements"); 635 } 636 } 637 638 return createPatternDecl(loc, name, metadata, body); 639 } 640 641 LogicalResult 642 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 643 Optional<SMRange> benefitLoc; 644 Optional<SMRange> hasBoundedRecursionLoc; 645 646 do { 647 if (curToken.isNot(Token::identifier)) 648 return emitError("expected pattern metadata identifier"); 649 StringRef metadataStr = curToken.getSpelling(); 650 SMRange metadataLoc = curToken.getLoc(); 651 consumeToken(Token::identifier); 652 653 // Parse the benefit metadata: benefit(<integer-value>) 654 if (metadataStr == "benefit") { 655 if (benefitLoc) { 656 return emitErrorAndNote(metadataLoc, 657 "pattern benefit has already been specified", 658 *benefitLoc, "see previous definition here"); 659 } 660 if (failed(parseToken(Token::l_paren, 661 "expected `(` before pattern benefit"))) 662 return failure(); 663 664 uint16_t benefitValue = 0; 665 if (curToken.isNot(Token::integer)) 666 return emitError("expected integral pattern benefit"); 667 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 668 return emitError( 669 "expected pattern benefit to fit within a 16-bit integer"); 670 consumeToken(Token::integer); 671 672 metadata.benefit = benefitValue; 673 benefitLoc = metadataLoc; 674 675 if (failed( 676 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 677 return failure(); 678 continue; 679 } 680 681 // Parse the bounded recursion metadata: recursion 682 if (metadataStr == "recursion") { 683 if (hasBoundedRecursionLoc) { 684 return emitErrorAndNote( 685 metadataLoc, 686 "pattern recursion metadata has already been specified", 687 *hasBoundedRecursionLoc, "see previous definition here"); 688 } 689 metadata.hasBoundedRecursion = true; 690 hasBoundedRecursionLoc = metadataLoc; 691 continue; 692 } 693 694 return emitError(metadataLoc, "unknown pattern metadata"); 695 } while (consumeIf(Token::comma)); 696 697 return success(); 698 } 699 700 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 701 consumeToken(Token::less); 702 703 FailureOr<ast::Expr *> typeExpr = parseExpr(); 704 if (failed(typeExpr) || 705 failed(parseToken(Token::greater, 706 "expected `>` after variable type constraint"))) 707 return failure(); 708 return typeExpr; 709 } 710 711 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 712 assert(curDeclScope && "defining decl outside of a decl scope"); 713 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 714 return emitErrorAndNote( 715 name.getLoc(), "`" + name.getName() + "` has already been defined", 716 lastDecl->getName()->getLoc(), "see previous definition here"); 717 } 718 return success(); 719 } 720 721 FailureOr<ast::VariableDecl *> 722 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, 723 ast::Type type, ast::Expr *initExpr, 724 ArrayRef<ast::ConstraintRef> constraints) { 725 assert(curDeclScope && "defining variable outside of decl scope"); 726 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 727 728 // If the name of the variable indicates a special variable, we don't add it 729 // to the scope. This variable is local to the definition point. 730 if (name.empty() || name == "_") { 731 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 732 constraints); 733 } 734 if (failed(checkDefineNamedDecl(nameDecl))) 735 return failure(); 736 737 auto *varDecl = 738 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 739 curDeclScope->add(varDecl); 740 return varDecl; 741 } 742 743 FailureOr<ast::VariableDecl *> 744 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, 745 ast::Type type, 746 ArrayRef<ast::ConstraintRef> constraints) { 747 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 748 constraints); 749 } 750 751 LogicalResult Parser::parseVariableDeclConstraintList( 752 SmallVectorImpl<ast::ConstraintRef> &constraints) { 753 Optional<SMRange> typeConstraint; 754 auto parseSingleConstraint = [&] { 755 FailureOr<ast::ConstraintRef> constraint = 756 parseConstraint(typeConstraint, constraints); 757 if (failed(constraint)) 758 return failure(); 759 constraints.push_back(*constraint); 760 return success(); 761 }; 762 763 // Check to see if this is a single constraint, or a list. 764 if (!consumeIf(Token::l_square)) 765 return parseSingleConstraint(); 766 767 do { 768 if (failed(parseSingleConstraint())) 769 return failure(); 770 } while (consumeIf(Token::comma)); 771 return parseToken(Token::r_square, "expected `]` after constraint list"); 772 } 773 774 FailureOr<ast::ConstraintRef> 775 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 776 ArrayRef<ast::ConstraintRef> existingConstraints) { 777 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 778 if (typeConstraint) 779 return emitErrorAndNote( 780 curToken.getLoc(), 781 "the type of this variable has already been constrained", 782 *typeConstraint, "see previous constraint location here"); 783 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 784 if (failed(constraintExpr)) 785 return failure(); 786 typeExpr = *constraintExpr; 787 typeConstraint = typeExpr->getLoc(); 788 return success(); 789 }; 790 791 SMRange loc = curToken.getLoc(); 792 switch (curToken.getKind()) { 793 case Token::kw_Attr: { 794 consumeToken(Token::kw_Attr); 795 796 // Check for a type constraint. 797 ast::Expr *typeExpr = nullptr; 798 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 799 return failure(); 800 return ast::ConstraintRef( 801 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 802 } 803 case Token::kw_Op: { 804 consumeToken(Token::kw_Op); 805 806 // Parse an optional operation name. If the name isn't provided, this refers 807 // to "any" operation. 808 FailureOr<ast::OpNameDecl *> opName = 809 parseWrappedOperationName(/*allowEmptyName=*/true); 810 if (failed(opName)) 811 return failure(); 812 813 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 814 loc); 815 } 816 case Token::kw_Type: 817 consumeToken(Token::kw_Type); 818 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 819 case Token::kw_TypeRange: 820 consumeToken(Token::kw_TypeRange); 821 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 822 loc); 823 case Token::kw_Value: { 824 consumeToken(Token::kw_Value); 825 826 // Check for a type constraint. 827 ast::Expr *typeExpr = nullptr; 828 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 829 return failure(); 830 831 return ast::ConstraintRef( 832 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 833 } 834 case Token::kw_ValueRange: { 835 consumeToken(Token::kw_ValueRange); 836 837 // Check for a type constraint. 838 ast::Expr *typeExpr = nullptr; 839 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 840 return failure(); 841 842 return ast::ConstraintRef( 843 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 844 } 845 case Token::identifier: { 846 StringRef constraintName = curToken.getSpelling(); 847 consumeToken(Token::identifier); 848 849 // Lookup the referenced constraint. 850 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 851 if (!cstDecl) { 852 return emitError(loc, "unknown reference to constraint `" + 853 constraintName + "`"); 854 } 855 856 // Handle a reference to a proper constraint. 857 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 858 return ast::ConstraintRef(cst, loc); 859 860 return emitErrorAndNote( 861 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 862 "see the definition of `" + constraintName + "` here"); 863 } 864 default: 865 break; 866 } 867 return emitError(loc, "expected identifier constraint"); 868 } 869 870 //===----------------------------------------------------------------------===// 871 // Exprs 872 873 FailureOr<ast::Expr *> Parser::parseExpr() { 874 if (curToken.is(Token::underscore)) 875 return parseUnderscoreExpr(); 876 877 // Parse the LHS expression. 878 FailureOr<ast::Expr *> lhsExpr; 879 switch (curToken.getKind()) { 880 case Token::kw_attr: 881 lhsExpr = parseAttributeExpr(); 882 break; 883 case Token::identifier: 884 lhsExpr = parseIdentifierExpr(); 885 break; 886 case Token::kw_op: 887 lhsExpr = parseOperationExpr(); 888 break; 889 case Token::kw_type: 890 lhsExpr = parseTypeExpr(); 891 break; 892 case Token::l_paren: 893 lhsExpr = parseTupleExpr(); 894 break; 895 default: 896 return emitError("expected expression"); 897 } 898 if (failed(lhsExpr)) 899 return failure(); 900 901 // Check for an operator expression. 902 while (true) { 903 switch (curToken.getKind()) { 904 case Token::dot: 905 lhsExpr = parseMemberAccessExpr(*lhsExpr); 906 break; 907 default: 908 return lhsExpr; 909 } 910 if (failed(lhsExpr)) 911 return failure(); 912 } 913 } 914 915 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 916 SMRange loc = curToken.getLoc(); 917 consumeToken(Token::kw_attr); 918 919 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 920 // identifier. 921 if (!consumeIf(Token::less)) { 922 resetToken(loc); 923 return parseIdentifierExpr(); 924 } 925 926 if (!curToken.isString()) 927 return emitError("expected string literal containing MLIR attribute"); 928 std::string attrExpr = curToken.getStringValue(); 929 consumeToken(); 930 931 if (failed( 932 parseToken(Token::greater, "expected `>` after attribute literal"))) 933 return failure(); 934 return ast::AttributeExpr::create(ctx, loc, attrExpr); 935 } 936 937 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, 938 SMRange loc) { 939 ast::Decl *decl = curDeclScope->lookup(name); 940 if (!decl) 941 return emitError(loc, "undefined reference to `" + name + "`"); 942 943 return createDeclRefExpr(loc, decl); 944 } 945 946 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 947 StringRef name = curToken.getSpelling(); 948 SMRange nameLoc = curToken.getLoc(); 949 consumeToken(); 950 951 // Check to see if this is a decl ref expression that defines a variable 952 // inline. 953 if (consumeIf(Token::colon)) { 954 SmallVector<ast::ConstraintRef> constraints; 955 if (failed(parseVariableDeclConstraintList(constraints))) 956 return failure(); 957 ast::Type type; 958 if (failed(validateVariableConstraints(constraints, type))) 959 return failure(); 960 return createInlineVariableExpr(type, name, nameLoc, constraints); 961 } 962 963 return parseDeclRefExpr(name, nameLoc); 964 } 965 966 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 967 SMRange loc = curToken.getLoc(); 968 consumeToken(Token::dot); 969 970 // Parse the member name. 971 Token memberNameTok = curToken; 972 if (memberNameTok.isNot(Token::identifier, Token::integer) && 973 !memberNameTok.isKeyword()) 974 return emitError(loc, "expected identifier or numeric member name"); 975 StringRef memberName = memberNameTok.getSpelling(); 976 consumeToken(); 977 978 return createMemberAccessExpr(parentExpr, memberName, loc); 979 } 980 981 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 982 SMRange loc = curToken.getLoc(); 983 984 // Handle the case of an no operation name. 985 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 986 if (allowEmptyName) 987 return ast::OpNameDecl::create(ctx, SMRange()); 988 return emitError("expected dialect namespace"); 989 } 990 StringRef name = curToken.getSpelling(); 991 consumeToken(); 992 993 // Otherwise, this is a literal operation name. 994 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 995 return failure(); 996 997 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 998 return emitError("expected operation name after dialect namespace"); 999 1000 name = StringRef(name.data(), name.size() + 1); 1001 do { 1002 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1003 loc.End = curToken.getEndLoc(); 1004 consumeToken(); 1005 } while (curToken.isAny(Token::identifier, Token::dot) || 1006 curToken.isKeyword()); 1007 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1008 } 1009 1010 FailureOr<ast::OpNameDecl *> 1011 Parser::parseWrappedOperationName(bool allowEmptyName) { 1012 if (!consumeIf(Token::less)) 1013 return ast::OpNameDecl::create(ctx, SMRange()); 1014 1015 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1016 if (failed(opNameDecl)) 1017 return failure(); 1018 1019 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1020 return failure(); 1021 return opNameDecl; 1022 } 1023 1024 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1025 SMRange loc = curToken.getLoc(); 1026 consumeToken(Token::kw_op); 1027 1028 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1029 // identifier. 1030 if (curToken.isNot(Token::less)) { 1031 resetToken(loc); 1032 return parseIdentifierExpr(); 1033 } 1034 1035 // Parse the operation name. The name may be elided, in which case the 1036 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1037 // `Operation*`). Operation names within a rewrite context must be named. 1038 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1039 FailureOr<ast::OpNameDecl *> opNameDecl = 1040 parseWrappedOperationName(allowEmptyName); 1041 if (failed(opNameDecl)) 1042 return failure(); 1043 1044 // Check for the optional list of operands. 1045 SmallVector<ast::Expr *> operands; 1046 if (consumeIf(Token::l_paren)) { 1047 do { 1048 FailureOr<ast::Expr *> operand = parseExpr(); 1049 if (failed(operand)) 1050 return failure(); 1051 operands.push_back(*operand); 1052 } while (consumeIf(Token::comma)); 1053 1054 if (failed(parseToken(Token::r_paren, 1055 "expected `)` after operation operand list"))) 1056 return failure(); 1057 } 1058 1059 // Check for the optional list of attributes. 1060 SmallVector<ast::NamedAttributeDecl *> attributes; 1061 if (consumeIf(Token::l_brace)) { 1062 do { 1063 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1064 if (failed(decl)) 1065 return failure(); 1066 attributes.emplace_back(*decl); 1067 } while (consumeIf(Token::comma)); 1068 1069 if (failed(parseToken(Token::r_brace, 1070 "expected `}` after operation attribute list"))) 1071 return failure(); 1072 } 1073 1074 // Check for the optional list of result types. 1075 SmallVector<ast::Expr *> resultTypes; 1076 if (consumeIf(Token::arrow)) { 1077 if (failed(parseToken(Token::l_paren, 1078 "expected `(` before operation result type list"))) 1079 return failure(); 1080 1081 do { 1082 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1083 if (failed(resultTypeExpr)) 1084 return failure(); 1085 resultTypes.push_back(*resultTypeExpr); 1086 } while (consumeIf(Token::comma)); 1087 1088 if (failed(parseToken(Token::r_paren, 1089 "expected `)` after operation result type list"))) 1090 return failure(); 1091 } 1092 1093 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1094 resultTypes); 1095 } 1096 1097 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1098 SMRange loc = curToken.getLoc(); 1099 consumeToken(Token::l_paren); 1100 1101 DenseMap<StringRef, SMRange> usedNames; 1102 SmallVector<StringRef> elementNames; 1103 SmallVector<ast::Expr *> elements; 1104 if (curToken.isNot(Token::r_paren)) { 1105 do { 1106 // Check for the optional element name assignment before the value. 1107 StringRef elementName; 1108 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1109 Token elementNameTok = curToken; 1110 consumeToken(); 1111 1112 // The element name is only present if followed by an `=`. 1113 if (consumeIf(Token::equal)) { 1114 elementName = elementNameTok.getSpelling(); 1115 1116 // Check to see if this name is already used. 1117 auto elementNameIt = 1118 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1119 if (!elementNameIt.second) { 1120 return emitErrorAndNote( 1121 elementNameTok.getLoc(), 1122 llvm::formatv("duplicate tuple element label `{0}`", 1123 elementName), 1124 elementNameIt.first->getSecond(), 1125 "see previous label use here"); 1126 } 1127 } else { 1128 // Otherwise, we treat this as part of an expression so reset the 1129 // lexer. 1130 resetToken(elementNameTok.getLoc()); 1131 } 1132 } 1133 elementNames.push_back(elementName); 1134 1135 // Parse the tuple element value. 1136 FailureOr<ast::Expr *> element = parseExpr(); 1137 if (failed(element)) 1138 return failure(); 1139 elements.push_back(*element); 1140 } while (consumeIf(Token::comma)); 1141 } 1142 loc.End = curToken.getEndLoc(); 1143 if (failed( 1144 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1145 return failure(); 1146 return createTupleExpr(loc, elements, elementNames); 1147 } 1148 1149 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1150 SMRange loc = curToken.getLoc(); 1151 consumeToken(Token::kw_type); 1152 1153 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1154 // identifier. 1155 if (!consumeIf(Token::less)) { 1156 resetToken(loc); 1157 return parseIdentifierExpr(); 1158 } 1159 1160 if (!curToken.isString()) 1161 return emitError("expected string literal containing MLIR type"); 1162 std::string attrExpr = curToken.getStringValue(); 1163 consumeToken(); 1164 1165 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1166 return failure(); 1167 return ast::TypeExpr::create(ctx, loc, attrExpr); 1168 } 1169 1170 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 1171 StringRef name = curToken.getSpelling(); 1172 SMRange nameLoc = curToken.getLoc(); 1173 consumeToken(Token::underscore); 1174 1175 // Underscore expressions require a constraint list. 1176 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 1177 return failure(); 1178 1179 // Parse the constraints for the expression. 1180 SmallVector<ast::ConstraintRef> constraints; 1181 if (failed(parseVariableDeclConstraintList(constraints))) 1182 return failure(); 1183 1184 ast::Type type; 1185 if (failed(validateVariableConstraints(constraints, type))) 1186 return failure(); 1187 return createInlineVariableExpr(type, name, nameLoc, constraints); 1188 } 1189 1190 //===----------------------------------------------------------------------===// 1191 // Stmts 1192 1193 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 1194 FailureOr<ast::Stmt *> stmt; 1195 switch (curToken.getKind()) { 1196 case Token::kw_erase: 1197 stmt = parseEraseStmt(); 1198 break; 1199 case Token::kw_let: 1200 stmt = parseLetStmt(); 1201 break; 1202 case Token::kw_replace: 1203 stmt = parseReplaceStmt(); 1204 break; 1205 case Token::kw_rewrite: 1206 stmt = parseRewriteStmt(); 1207 break; 1208 default: 1209 stmt = parseExpr(); 1210 break; 1211 } 1212 if (failed(stmt) || 1213 (expectTerminalSemicolon && 1214 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 1215 return failure(); 1216 return stmt; 1217 } 1218 1219 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 1220 SMLoc startLoc = curToken.getStartLoc(); 1221 consumeToken(Token::l_brace); 1222 1223 // Push a new block scope and parse any nested statements. 1224 pushDeclScope(); 1225 SmallVector<ast::Stmt *> statements; 1226 while (curToken.isNot(Token::r_brace)) { 1227 FailureOr<ast::Stmt *> statement = parseStmt(); 1228 if (failed(statement)) 1229 return popDeclScope(), failure(); 1230 statements.push_back(*statement); 1231 } 1232 popDeclScope(); 1233 1234 // Consume the end brace. 1235 SMRange location(startLoc, curToken.getEndLoc()); 1236 consumeToken(Token::r_brace); 1237 1238 return ast::CompoundStmt::create(ctx, location, statements); 1239 } 1240 1241 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 1242 SMRange loc = curToken.getLoc(); 1243 consumeToken(Token::kw_erase); 1244 1245 // Parse the root operation expression. 1246 FailureOr<ast::Expr *> rootOp = parseExpr(); 1247 if (failed(rootOp)) 1248 return failure(); 1249 1250 return createEraseStmt(loc, *rootOp); 1251 } 1252 1253 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 1254 SMRange loc = curToken.getLoc(); 1255 consumeToken(Token::kw_let); 1256 1257 // Parse the name of the new variable. 1258 SMRange varLoc = curToken.getLoc(); 1259 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 1260 // `_` is a reserved variable name. 1261 if (curToken.is(Token::underscore)) { 1262 return emitError(varLoc, 1263 "`_` may only be used to define \"inline\" variables"); 1264 } 1265 return emitError(varLoc, 1266 "expected identifier after `let` to name a new variable"); 1267 } 1268 StringRef varName = curToken.getSpelling(); 1269 consumeToken(); 1270 1271 // Parse the optional set of constraints. 1272 SmallVector<ast::ConstraintRef> constraints; 1273 if (consumeIf(Token::colon) && 1274 failed(parseVariableDeclConstraintList(constraints))) 1275 return failure(); 1276 1277 // Parse the optional initializer expression. 1278 ast::Expr *initializer = nullptr; 1279 if (consumeIf(Token::equal)) { 1280 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 1281 if (failed(initOrFailure)) 1282 return failure(); 1283 initializer = *initOrFailure; 1284 1285 // Check that the constraints are compatible with having an initializer, 1286 // e.g. type constraints cannot be used with initializers. 1287 for (ast::ConstraintRef constraint : constraints) { 1288 LogicalResult result = 1289 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 1290 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 1291 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 1292 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 1293 return this->emitError( 1294 constraint.referenceLoc, 1295 "type constraints are not permitted on variables with " 1296 "initializers"); 1297 } 1298 return success(); 1299 }) 1300 .Default(success()); 1301 if (failed(result)) 1302 return failure(); 1303 } 1304 } 1305 1306 FailureOr<ast::VariableDecl *> varDecl = 1307 createVariableDecl(varName, varLoc, initializer, constraints); 1308 if (failed(varDecl)) 1309 return failure(); 1310 return ast::LetStmt::create(ctx, loc, *varDecl); 1311 } 1312 1313 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 1314 SMRange loc = curToken.getLoc(); 1315 consumeToken(Token::kw_replace); 1316 1317 // Parse the root operation expression. 1318 FailureOr<ast::Expr *> rootOp = parseExpr(); 1319 if (failed(rootOp)) 1320 return failure(); 1321 1322 if (failed( 1323 parseToken(Token::kw_with, "expected `with` after root operation"))) 1324 return failure(); 1325 1326 // The replacement portion of this statement is within a rewrite context. 1327 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1328 ParserContext::Rewrite); 1329 1330 // Parse the replacement values. 1331 SmallVector<ast::Expr *> replValues; 1332 if (consumeIf(Token::l_paren)) { 1333 if (consumeIf(Token::r_paren)) { 1334 return emitError( 1335 loc, "expected at least one replacement value, consider using " 1336 "`erase` if no replacement values are desired"); 1337 } 1338 1339 do { 1340 FailureOr<ast::Expr *> replExpr = parseExpr(); 1341 if (failed(replExpr)) 1342 return failure(); 1343 replValues.emplace_back(*replExpr); 1344 } while (consumeIf(Token::comma)); 1345 1346 if (failed(parseToken(Token::r_paren, 1347 "expected `)` after replacement values"))) 1348 return failure(); 1349 } else { 1350 FailureOr<ast::Expr *> replExpr = parseExpr(); 1351 if (failed(replExpr)) 1352 return failure(); 1353 replValues.emplace_back(*replExpr); 1354 } 1355 1356 return createReplaceStmt(loc, *rootOp, replValues); 1357 } 1358 1359 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 1360 SMRange loc = curToken.getLoc(); 1361 consumeToken(Token::kw_rewrite); 1362 1363 // Parse the root operation. 1364 FailureOr<ast::Expr *> rootOp = parseExpr(); 1365 if (failed(rootOp)) 1366 return failure(); 1367 1368 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 1369 return failure(); 1370 1371 if (curToken.isNot(Token::l_brace)) 1372 return emitError("expected `{` to start rewrite body"); 1373 1374 // The rewrite body of this statement is within a rewrite context. 1375 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1376 ParserContext::Rewrite); 1377 1378 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 1379 if (failed(rewriteBody)) 1380 return failure(); 1381 1382 return createRewriteStmt(loc, *rootOp, *rewriteBody); 1383 } 1384 1385 //===----------------------------------------------------------------------===// 1386 // Creation+Analysis 1387 //===----------------------------------------------------------------------===// 1388 1389 //===----------------------------------------------------------------------===// 1390 // Decls 1391 1392 FailureOr<ast::PatternDecl *> 1393 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 1394 const ParsedPatternMetadata &metadata, 1395 ast::CompoundStmt *body) { 1396 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 1397 metadata.hasBoundedRecursion, body); 1398 } 1399 1400 FailureOr<ast::VariableDecl *> 1401 Parser::createVariableDecl(StringRef name, SMRange loc, 1402 ast::Expr *initializer, 1403 ArrayRef<ast::ConstraintRef> constraints) { 1404 // The type of the variable, which is expected to be inferred by either a 1405 // constraint or an initializer expression. 1406 ast::Type type; 1407 if (failed(validateVariableConstraints(constraints, type))) 1408 return failure(); 1409 1410 if (initializer) { 1411 // Update the variable type based on the initializer, or try to convert the 1412 // initializer to the existing type. 1413 if (!type) 1414 type = initializer->getType(); 1415 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 1416 type = mergedType; 1417 else if (failed(convertExpressionTo(initializer, type))) 1418 return failure(); 1419 1420 // Otherwise, if there is no initializer check that the type has already 1421 // been resolved from the constraint list. 1422 } else if (!type) { 1423 return emitErrorAndNote( 1424 loc, "unable to infer type for variable `" + name + "`", loc, 1425 "the type of a variable must be inferable from the constraint " 1426 "list or the initializer"); 1427 } 1428 1429 // Try to define a variable with the given name. 1430 FailureOr<ast::VariableDecl *> varDecl = 1431 defineVariableDecl(name, loc, type, initializer, constraints); 1432 if (failed(varDecl)) 1433 return failure(); 1434 1435 return *varDecl; 1436 } 1437 1438 LogicalResult 1439 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 1440 ast::Type &inferredType) { 1441 for (const ast::ConstraintRef &ref : constraints) 1442 if (failed(validateVariableConstraint(ref, inferredType))) 1443 return failure(); 1444 return success(); 1445 } 1446 1447 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 1448 ast::Type &inferredType) { 1449 ast::Type constraintType; 1450 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 1451 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1452 if (failed(validateTypeConstraintExpr(typeExpr))) 1453 return failure(); 1454 } 1455 constraintType = ast::AttributeType::get(ctx); 1456 } else if (const auto *cst = 1457 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 1458 constraintType = ast::OperationType::get(ctx, cst->getName()); 1459 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 1460 constraintType = typeTy; 1461 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 1462 constraintType = typeRangeTy; 1463 } else if (const auto *cst = 1464 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 1465 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1466 if (failed(validateTypeConstraintExpr(typeExpr))) 1467 return failure(); 1468 } 1469 constraintType = valueTy; 1470 } else if (const auto *cst = 1471 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 1472 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 1473 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 1474 return failure(); 1475 } 1476 constraintType = valueRangeTy; 1477 } else { 1478 llvm_unreachable("unknown constraint type"); 1479 } 1480 1481 // Check that the constraint type is compatible with the current inferred 1482 // type. 1483 if (!inferredType) { 1484 inferredType = constraintType; 1485 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 1486 inferredType = mergedTy; 1487 } else { 1488 return emitError(ref.referenceLoc, 1489 llvm::formatv("constraint type `{0}` is incompatible " 1490 "with the previously inferred type `{1}`", 1491 constraintType, inferredType)); 1492 } 1493 return success(); 1494 } 1495 1496 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 1497 ast::Type typeExprType = typeExpr->getType(); 1498 if (typeExprType != typeTy) { 1499 return emitError(typeExpr->getLoc(), 1500 "expected expression of `Type` in type constraint"); 1501 } 1502 return success(); 1503 } 1504 1505 LogicalResult 1506 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 1507 ast::Type typeExprType = typeExpr->getType(); 1508 if (typeExprType != typeRangeTy) { 1509 return emitError(typeExpr->getLoc(), 1510 "expected expression of `TypeRange` in type constraint"); 1511 } 1512 return success(); 1513 } 1514 1515 //===----------------------------------------------------------------------===// 1516 // Exprs 1517 1518 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 1519 ast::Decl *decl) { 1520 // Check the type of decl being referenced. 1521 ast::Type declType; 1522 if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 1523 declType = varDecl->getType(); 1524 else 1525 return emitError(loc, "invalid reference to `" + 1526 decl->getName()->getName() + "`"); 1527 1528 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 1529 } 1530 1531 FailureOr<ast::DeclRefExpr *> 1532 Parser::createInlineVariableExpr(ast::Type type, StringRef name, 1533 SMRange loc, 1534 ArrayRef<ast::ConstraintRef> constraints) { 1535 FailureOr<ast::VariableDecl *> decl = 1536 defineVariableDecl(name, loc, type, constraints); 1537 if (failed(decl)) 1538 return failure(); 1539 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 1540 } 1541 1542 FailureOr<ast::MemberAccessExpr *> 1543 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 1544 SMRange loc) { 1545 // Validate the member name for the given parent expression. 1546 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 1547 if (failed(memberType)) 1548 return failure(); 1549 1550 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 1551 } 1552 1553 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 1554 StringRef name, 1555 SMRange loc) { 1556 ast::Type parentType = parentExpr->getType(); 1557 if (parentType.isa<ast::OperationType>()) { 1558 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 1559 return valueRangeTy; 1560 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 1561 // Handle indexed results. 1562 unsigned index = 0; 1563 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 1564 index < tupleType.size()) { 1565 return tupleType.getElementTypes()[index]; 1566 } 1567 1568 // Handle named results. 1569 auto elementNames = tupleType.getElementNames(); 1570 const auto *it = llvm::find(elementNames, name); 1571 if (it != elementNames.end()) 1572 return tupleType.getElementTypes()[it - elementNames.begin()]; 1573 } 1574 return emitError( 1575 loc, 1576 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 1577 name, parentType)); 1578 } 1579 1580 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 1581 SMRange loc, const ast::OpNameDecl *name, 1582 MutableArrayRef<ast::Expr *> operands, 1583 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 1584 MutableArrayRef<ast::Expr *> results) { 1585 Optional<StringRef> opNameRef = name->getName(); 1586 1587 // Verify the inputs operands. 1588 if (failed(validateOperationOperands(loc, opNameRef, operands))) 1589 return failure(); 1590 1591 // Verify the attribute list. 1592 for (ast::NamedAttributeDecl *attr : attributes) { 1593 // Check for an attribute type, or a type awaiting resolution. 1594 ast::Type attrType = attr->getValue()->getType(); 1595 if (!attrType.isa<ast::AttributeType>()) { 1596 return emitError( 1597 attr->getValue()->getLoc(), 1598 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 1599 } 1600 } 1601 1602 // Verify the result types. 1603 if (failed(validateOperationResults(loc, opNameRef, results))) 1604 return failure(); 1605 1606 return ast::OperationExpr::create(ctx, loc, name, operands, results, 1607 attributes); 1608 } 1609 1610 LogicalResult 1611 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 1612 MutableArrayRef<ast::Expr *> operands) { 1613 return validateOperationOperandsOrResults(loc, name, operands, valueTy, 1614 valueRangeTy); 1615 } 1616 1617 LogicalResult 1618 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 1619 MutableArrayRef<ast::Expr *> results) { 1620 return validateOperationOperandsOrResults(loc, name, results, typeTy, 1621 typeRangeTy); 1622 } 1623 1624 LogicalResult Parser::validateOperationOperandsOrResults( 1625 SMRange loc, Optional<StringRef> name, 1626 MutableArrayRef<ast::Expr *> values, ast::Type singleTy, 1627 ast::Type rangeTy) { 1628 // All operation types accept a single range parameter. 1629 if (values.size() == 1) { 1630 if (failed(convertExpressionTo(values[0], rangeTy))) 1631 return failure(); 1632 return success(); 1633 } 1634 1635 // Otherwise, accept the value groups as they have been defined and just 1636 // ensure they are one of the expected types. 1637 for (ast::Expr *&valueExpr : values) { 1638 ast::Type valueExprType = valueExpr->getType(); 1639 1640 // Check if this is one of the expected types. 1641 if (valueExprType == rangeTy || valueExprType == singleTy) 1642 continue; 1643 1644 // If the operand is an Operation, allow converting to a Value or 1645 // ValueRange. This situations arises quite often with nested operation 1646 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 1647 if (singleTy == valueTy) { 1648 if (valueExprType.isa<ast::OperationType>()) { 1649 valueExpr = convertOpToValue(valueExpr); 1650 continue; 1651 } 1652 } 1653 1654 return emitError( 1655 valueExpr->getLoc(), 1656 llvm::formatv( 1657 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 1658 singleTy, rangeTy, valueExprType)); 1659 } 1660 return success(); 1661 } 1662 1663 FailureOr<ast::TupleExpr *> 1664 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 1665 ArrayRef<StringRef> elementNames) { 1666 for (const ast::Expr *element : elements) { 1667 ast::Type eleTy = element->getType(); 1668 if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) { 1669 return emitError( 1670 element->getLoc(), 1671 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 1672 } 1673 } 1674 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 1675 } 1676 1677 //===----------------------------------------------------------------------===// 1678 // Stmts 1679 1680 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 1681 ast::Expr *rootOp) { 1682 // Check that root is an Operation. 1683 ast::Type rootType = rootOp->getType(); 1684 if (!rootType.isa<ast::OperationType>()) 1685 return emitError(rootOp->getLoc(), "expected `Op` expression"); 1686 1687 return ast::EraseStmt::create(ctx, loc, rootOp); 1688 } 1689 1690 FailureOr<ast::ReplaceStmt *> 1691 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 1692 MutableArrayRef<ast::Expr *> replValues) { 1693 // Check that root is an Operation. 1694 ast::Type rootType = rootOp->getType(); 1695 if (!rootType.isa<ast::OperationType>()) { 1696 return emitError( 1697 rootOp->getLoc(), 1698 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 1699 } 1700 1701 // If there are multiple replacement values, we implicitly convert any Op 1702 // expressions to the value form. 1703 bool shouldConvertOpToValues = replValues.size() > 1; 1704 for (ast::Expr *&replExpr : replValues) { 1705 ast::Type replType = replExpr->getType(); 1706 1707 // Check that replExpr is an Operation, Value, or ValueRange. 1708 if (replType.isa<ast::OperationType>()) { 1709 if (shouldConvertOpToValues) 1710 replExpr = convertOpToValue(replExpr); 1711 continue; 1712 } 1713 1714 if (replType != valueTy && replType != valueRangeTy) { 1715 return emitError(replExpr->getLoc(), 1716 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 1717 "expression, but got `{0}`", 1718 replType)); 1719 } 1720 } 1721 1722 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 1723 } 1724 1725 FailureOr<ast::RewriteStmt *> 1726 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 1727 ast::CompoundStmt *rewriteBody) { 1728 // Check that root is an Operation. 1729 ast::Type rootType = rootOp->getType(); 1730 if (!rootType.isa<ast::OperationType>()) { 1731 return emitError( 1732 rootOp->getLoc(), 1733 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 1734 } 1735 1736 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 1737 } 1738 1739 //===----------------------------------------------------------------------===// 1740 // Parser 1741 //===----------------------------------------------------------------------===// 1742 1743 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 1744 llvm::SourceMgr &sourceMgr) { 1745 Parser parser(ctx, sourceMgr); 1746 return parser.parseModule(); 1747 } 1748