1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/Tools/PDLL/AST/Context.h" 13 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 14 #include "mlir/Tools/PDLL/AST/Nodes.h" 15 #include "mlir/Tools/PDLL/AST/Types.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/SaveAndRestore.h" 20 #include "llvm/Support/ScopedPrinter.h" 21 #include <string> 22 23 using namespace mlir; 24 using namespace mlir::pdll; 25 26 //===----------------------------------------------------------------------===// 27 // Parser 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 class Parser { 32 public: 33 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 34 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 35 curToken(lexer.lexToken()), curDeclScope(nullptr), 36 valueTy(ast::ValueType::get(ctx)), 37 valueRangeTy(ast::ValueRangeType::get(ctx)), 38 typeTy(ast::TypeType::get(ctx)), 39 typeRangeTy(ast::TypeRangeType::get(ctx)) {} 40 41 /// Try to parse a new module. Returns nullptr in the case of failure. 42 FailureOr<ast::Module *> parseModule(); 43 44 private: 45 /// The current context of the parser. It allows for the parser to know a bit 46 /// about the construct it is nested within during parsing. This is used 47 /// specifically to provide additional verification during parsing, e.g. to 48 /// prevent using rewrites within a match context, matcher constraints within 49 /// a rewrite section, etc. 50 enum class ParserContext { 51 /// The parser is in the global context. 52 Global, 53 /// The parser is currently within a Constraint, which disallows all types 54 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 55 Constraint, 56 /// The parser is currently within the matcher portion of a Pattern, which 57 /// is allows a terminal operation rewrite statement but no other rewrite 58 /// transformations. 59 PatternMatch, 60 /// The parser is currently within a Rewrite, which disallows calls to 61 /// constraints, requires operation expressions to have names, etc. 62 Rewrite, 63 }; 64 65 //===--------------------------------------------------------------------===// 66 // Parsing 67 //===--------------------------------------------------------------------===// 68 69 /// Push a new decl scope onto the lexer. 70 ast::DeclScope *pushDeclScope() { 71 ast::DeclScope *newScope = 72 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 73 return (curDeclScope = newScope); 74 } 75 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 76 77 /// Pop the last decl scope from the lexer. 78 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 79 80 /// Parse the body of an AST module. 81 LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); 82 83 /// Try to convert the given expression to `type`. Returns failure and emits 84 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 85 /// invoked to attach notes to the emitted error diagnostic. On success, 86 /// `expr` is updated to the expression used to convert to `type`. 87 LogicalResult convertExpressionTo( 88 ast::Expr *&expr, ast::Type type, 89 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 90 91 /// Given an operation expression, convert it to a Value or ValueRange 92 /// typed expression. 93 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 94 95 //===--------------------------------------------------------------------===// 96 // Directives 97 98 LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); 99 LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); 100 101 //===--------------------------------------------------------------------===// 102 // Decls 103 104 /// This structure contains the set of pattern metadata that may be parsed. 105 struct ParsedPatternMetadata { 106 Optional<uint16_t> benefit; 107 bool hasBoundedRecursion = false; 108 }; 109 110 FailureOr<ast::Decl *> parseTopLevelDecl(); 111 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 112 113 /// Parse an argument variable as part of the signature of a 114 /// UserConstraintDecl or UserRewriteDecl. 115 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 116 117 /// Parse a result variable as part of the signature of a UserConstraintDecl 118 /// or UserRewriteDecl. 119 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 120 121 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 122 /// defined in a non-global context. 123 FailureOr<ast::UserConstraintDecl *> 124 parseUserConstraintDecl(bool isInline = false); 125 126 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 127 /// non-global context, such as within a Pattern/Constraint/etc. 128 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 129 130 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 131 /// PDLL constructs. 132 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 133 const ast::Name &name, bool isInline, 134 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 135 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 136 137 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 138 /// defined in a non-global context. 139 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 140 141 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 142 /// non-global context, such as within a Pattern/Rewrite/etc. 143 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 144 145 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 146 /// PDLL constructs. 147 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 148 const ast::Name &name, bool isInline, 149 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 150 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 151 152 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 153 /// effectively the same syntax, and only differ on slight semantics (given 154 /// the different parsing contexts). 155 template <typename T, typename ParseUserPDLLDeclFnT> 156 FailureOr<T *> parseUserConstraintOrRewriteDecl( 157 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 158 StringRef anonymousNamePrefix, bool isInline); 159 160 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 161 /// These decls have effectively the same syntax. 162 template <typename T> 163 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 164 const ast::Name &name, bool isInline, 165 ArrayRef<ast::VariableDecl *> arguments, 166 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 167 168 /// Parse the functional signature (i.e. the arguments and results) of a 169 /// UserConstraintDecl or UserRewriteDecl. 170 LogicalResult parseUserConstraintOrRewriteSignature( 171 SmallVectorImpl<ast::VariableDecl *> &arguments, 172 SmallVectorImpl<ast::VariableDecl *> &results, 173 ast::DeclScope *&argumentScope, ast::Type &resultType); 174 175 /// Validate the return (which if present is specified by bodyIt) of a 176 /// UserConstraintDecl or UserRewriteDecl. 177 LogicalResult validateUserConstraintOrRewriteReturn( 178 StringRef declType, ast::CompoundStmt *body, 179 ArrayRef<ast::Stmt *>::iterator bodyIt, 180 ArrayRef<ast::Stmt *>::iterator bodyE, 181 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 182 183 FailureOr<ast::CompoundStmt *> 184 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 185 bool expectTerminalSemicolon = true); 186 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 187 FailureOr<ast::Decl *> parsePatternDecl(); 188 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 189 190 /// Check to see if a decl has already been defined with the given name, if 191 /// one has emit and error and return failure. Returns success otherwise. 192 LogicalResult checkDefineNamedDecl(const ast::Name &name); 193 194 /// Try to define a variable decl with the given components, returns the 195 /// variable on success. 196 FailureOr<ast::VariableDecl *> 197 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 198 ast::Expr *initExpr, 199 ArrayRef<ast::ConstraintRef> constraints); 200 FailureOr<ast::VariableDecl *> 201 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 202 ArrayRef<ast::ConstraintRef> constraints); 203 204 /// Parse the constraint reference list for a variable decl. 205 LogicalResult parseVariableDeclConstraintList( 206 SmallVectorImpl<ast::ConstraintRef> &constraints); 207 208 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 209 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 210 211 /// Try to parse a single reference to a constraint. `typeConstraint` is the 212 /// location of a previously parsed type constraint for the entity that will 213 /// be constrained by the parsed constraint. `existingConstraints` are any 214 /// existing constraints that have already been parsed for the same entity 215 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 216 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 217 FailureOr<ast::ConstraintRef> 218 parseConstraint(Optional<SMRange> &typeConstraint, 219 ArrayRef<ast::ConstraintRef> existingConstraints, 220 bool allowInlineTypeConstraints); 221 222 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 223 /// argument or result variable. The constraints for these variables do not 224 /// allow inline type constraints, and only permit a single constraint. 225 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 226 227 //===--------------------------------------------------------------------===// 228 // Exprs 229 230 FailureOr<ast::Expr *> parseExpr(); 231 232 /// Identifier expressions. 233 FailureOr<ast::Expr *> parseAttributeExpr(); 234 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 235 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 236 FailureOr<ast::Expr *> parseIdentifierExpr(); 237 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 238 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 239 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 240 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 241 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 242 FailureOr<ast::Expr *> parseOperationExpr(); 243 FailureOr<ast::Expr *> parseTupleExpr(); 244 FailureOr<ast::Expr *> parseTypeExpr(); 245 FailureOr<ast::Expr *> parseUnderscoreExpr(); 246 247 //===--------------------------------------------------------------------===// 248 // Stmts 249 250 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 251 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 252 FailureOr<ast::EraseStmt *> parseEraseStmt(); 253 FailureOr<ast::LetStmt *> parseLetStmt(); 254 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 255 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 256 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 257 258 //===--------------------------------------------------------------------===// 259 // Creation+Analysis 260 //===--------------------------------------------------------------------===// 261 262 //===--------------------------------------------------------------------===// 263 // Decls 264 265 /// Try to extract a callable from the given AST node. Returns nullptr on 266 /// failure. 267 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 268 269 /// Try to create a pattern decl with the given components, returning the 270 /// Pattern on success. 271 FailureOr<ast::PatternDecl *> 272 createPatternDecl(SMRange loc, const ast::Name *name, 273 const ParsedPatternMetadata &metadata, 274 ast::CompoundStmt *body); 275 276 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 277 /// of results, defined as part of the signature. 278 ast::Type 279 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 280 281 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 282 template <typename T> 283 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 284 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 285 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 286 ast::CompoundStmt *body); 287 288 /// Try to create a variable decl with the given components, returning the 289 /// Variable on success. 290 FailureOr<ast::VariableDecl *> 291 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 292 ArrayRef<ast::ConstraintRef> constraints); 293 294 /// Create a variable for an argument or result defined as part of the 295 /// signature of a UserConstraintDecl/UserRewriteDecl. 296 FailureOr<ast::VariableDecl *> 297 createArgOrResultVariableDecl(StringRef name, SMRange loc, 298 const ast::ConstraintRef &constraint); 299 300 /// Validate the constraints used to constraint a variable decl. 301 /// `inferredType` is the type of the variable inferred by the constraints 302 /// within the list, and is updated to the most refined type as determined by 303 /// the constraints. Returns success if the constraint list is valid, failure 304 /// otherwise. 305 LogicalResult 306 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 307 ast::Type &inferredType); 308 /// Validate a single reference to a constraint. `inferredType` contains the 309 /// currently inferred variabled type and is refined within the type defined 310 /// by the constraint. Returns success if the constraint is valid, failure 311 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 312 /// defined constraints) may be used with the variable. 313 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 314 ast::Type &inferredType, 315 bool allowNonCoreConstraints = true); 316 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 317 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 318 319 //===--------------------------------------------------------------------===// 320 // Exprs 321 322 FailureOr<ast::CallExpr *> 323 createCallExpr(SMRange loc, ast::Expr *parentExpr, 324 MutableArrayRef<ast::Expr *> arguments); 325 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 326 FailureOr<ast::DeclRefExpr *> 327 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 328 ArrayRef<ast::ConstraintRef> constraints); 329 FailureOr<ast::MemberAccessExpr *> 330 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 331 332 /// Validate the member access `name` into the given parent expression. On 333 /// success, this also returns the type of the member accessed. 334 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 335 StringRef name, SMRange loc); 336 FailureOr<ast::OperationExpr *> 337 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 338 MutableArrayRef<ast::Expr *> operands, 339 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 340 MutableArrayRef<ast::Expr *> results); 341 LogicalResult 342 validateOperationOperands(SMRange loc, Optional<StringRef> name, 343 MutableArrayRef<ast::Expr *> operands); 344 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 345 MutableArrayRef<ast::Expr *> results); 346 LogicalResult 347 validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name, 348 MutableArrayRef<ast::Expr *> values, 349 ast::Type singleTy, ast::Type rangeTy); 350 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 351 ArrayRef<ast::Expr *> elements, 352 ArrayRef<StringRef> elementNames); 353 354 //===--------------------------------------------------------------------===// 355 // Stmts 356 357 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 358 FailureOr<ast::ReplaceStmt *> 359 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 360 MutableArrayRef<ast::Expr *> replValues); 361 FailureOr<ast::RewriteStmt *> 362 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 363 ast::CompoundStmt *rewriteBody); 364 365 //===--------------------------------------------------------------------===// 366 // Lexer Utilities 367 //===--------------------------------------------------------------------===// 368 369 /// If the current token has the specified kind, consume it and return true. 370 /// If not, return false. 371 bool consumeIf(Token::Kind kind) { 372 if (curToken.isNot(kind)) 373 return false; 374 consumeToken(kind); 375 return true; 376 } 377 378 /// Advance the current lexer onto the next token. 379 void consumeToken() { 380 assert(curToken.isNot(Token::eof, Token::error) && 381 "shouldn't advance past EOF or errors"); 382 curToken = lexer.lexToken(); 383 } 384 385 /// Advance the current lexer onto the next token, asserting what the expected 386 /// current token is. This is preferred to the above method because it leads 387 /// to more self-documenting code with better checking. 388 void consumeToken(Token::Kind kind) { 389 assert(curToken.is(kind) && "consumed an unexpected token"); 390 consumeToken(); 391 } 392 393 /// Reset the lexer to the location at the given position. 394 void resetToken(SMRange tokLoc) { 395 lexer.resetPointer(tokLoc.Start.getPointer()); 396 curToken = lexer.lexToken(); 397 } 398 399 /// Consume the specified token if present and return success. On failure, 400 /// output a diagnostic and return failure. 401 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 402 if (curToken.getKind() != kind) 403 return emitError(curToken.getLoc(), msg); 404 consumeToken(); 405 return success(); 406 } 407 LogicalResult emitError(SMRange loc, const Twine &msg) { 408 lexer.emitError(loc, msg); 409 return failure(); 410 } 411 LogicalResult emitError(const Twine &msg) { 412 return emitError(curToken.getLoc(), msg); 413 } 414 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 415 const Twine ¬e) { 416 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 417 return failure(); 418 } 419 420 //===--------------------------------------------------------------------===// 421 // Fields 422 //===--------------------------------------------------------------------===// 423 424 /// The owning AST context. 425 ast::Context &ctx; 426 427 /// The lexer of this parser. 428 Lexer lexer; 429 430 /// The current token within the lexer. 431 Token curToken; 432 433 /// The most recently defined decl scope. 434 ast::DeclScope *curDeclScope; 435 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 436 437 /// The current context of the parser. 438 ParserContext parserContext = ParserContext::Global; 439 440 /// Cached types to simplify verification and expression creation. 441 ast::Type valueTy, valueRangeTy; 442 ast::Type typeTy, typeRangeTy; 443 444 /// A counter used when naming anonymous constraints and rewrites. 445 unsigned anonymousDeclNameCounter = 0; 446 }; 447 } // namespace 448 449 FailureOr<ast::Module *> Parser::parseModule() { 450 SMLoc moduleLoc = curToken.getStartLoc(); 451 pushDeclScope(); 452 453 // Parse the top-level decls of the module. 454 SmallVector<ast::Decl *> decls; 455 if (failed(parseModuleBody(decls))) 456 return popDeclScope(), failure(); 457 458 popDeclScope(); 459 return ast::Module::create(ctx, moduleLoc, decls); 460 } 461 462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { 463 while (curToken.isNot(Token::eof)) { 464 if (curToken.is(Token::directive)) { 465 if (failed(parseDirective(decls))) 466 return failure(); 467 continue; 468 } 469 470 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 471 if (failed(decl)) 472 return failure(); 473 decls.push_back(*decl); 474 } 475 return success(); 476 } 477 478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 479 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 480 valueRangeTy); 481 } 482 483 LogicalResult Parser::convertExpressionTo( 484 ast::Expr *&expr, ast::Type type, 485 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 486 ast::Type exprType = expr->getType(); 487 if (exprType == type) 488 return success(); 489 490 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 491 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 492 expr->getLoc(), llvm::formatv("unable to convert expression of type " 493 "`{0}` to the expected type of " 494 "`{1}`", 495 exprType, type)); 496 if (noteAttachFn) 497 noteAttachFn(*diag); 498 return diag; 499 }; 500 501 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 502 // Two operation types are compatible if they have the same name, or if the 503 // expected type is more general. 504 if (auto opType = type.dyn_cast<ast::OperationType>()) { 505 if (opType.getName()) 506 return emitConvertError(); 507 return success(); 508 } 509 510 // An operation can always convert to a ValueRange. 511 if (type == valueRangeTy) { 512 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 513 valueRangeTy); 514 return success(); 515 } 516 517 // Allow conversion to a single value by constraining the result range. 518 if (type == valueTy) { 519 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 520 valueTy); 521 return success(); 522 } 523 return emitConvertError(); 524 } 525 526 // FIXME: Decide how to allow/support converting a single result to multiple, 527 // and multiple to a single result. For now, we just allow Single->Range, 528 // but this isn't something really supported in the PDL dialect. We should 529 // figure out some way to support both. 530 if ((exprType == valueTy || exprType == valueRangeTy) && 531 (type == valueTy || type == valueRangeTy)) 532 return success(); 533 if ((exprType == typeTy || exprType == typeRangeTy) && 534 (type == typeTy || type == typeRangeTy)) 535 return success(); 536 537 // Handle tuple types. 538 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 539 auto tupleType = type.dyn_cast<ast::TupleType>(); 540 if (!tupleType || tupleType.size() != exprTupleType.size()) 541 return emitConvertError(); 542 543 // Build a new tuple expression using each of the elements of the current 544 // tuple. 545 SmallVector<ast::Expr *> newExprs; 546 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 547 newExprs.push_back(ast::MemberAccessExpr::create( 548 ctx, expr->getLoc(), expr, llvm::to_string(i), 549 exprTupleType.getElementTypes()[i])); 550 551 auto diagFn = [&](ast::Diagnostic &diag) { 552 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 553 i, exprTupleType)); 554 if (noteAttachFn) 555 noteAttachFn(diag); 556 }; 557 if (failed(convertExpressionTo(newExprs.back(), 558 tupleType.getElementTypes()[i], diagFn))) 559 return failure(); 560 } 561 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 562 tupleType.getElementNames()); 563 return success(); 564 } 565 566 return emitConvertError(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // Directives 571 572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { 573 StringRef directive = curToken.getSpelling(); 574 if (directive == "#include") 575 return parseInclude(decls); 576 577 return emitError("unknown directive `" + directive + "`"); 578 } 579 580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { 581 SMRange loc = curToken.getLoc(); 582 consumeToken(Token::directive); 583 584 // Parse the file being included. 585 if (!curToken.isString()) 586 return emitError(loc, 587 "expected string file name after `include` directive"); 588 SMRange fileLoc = curToken.getLoc(); 589 std::string filenameStr = curToken.getStringValue(); 590 StringRef filename = filenameStr; 591 consumeToken(); 592 593 // Check the type of include. If ending with `.pdll`, this is another pdl file 594 // to be parsed along with the current module. 595 if (filename.endswith(".pdll")) { 596 if (failed(lexer.pushInclude(filename))) 597 return emitError(fileLoc, 598 "unable to open include file `" + filename + "`"); 599 600 // If we added the include successfully, parse it into the current module. 601 // Make sure to save the current token so that we can restore it when we 602 // finish parsing the nested file. 603 Token oldToken = curToken; 604 curToken = lexer.lexToken(); 605 LogicalResult result = parseModuleBody(decls); 606 curToken = oldToken; 607 return result; 608 } 609 610 return emitError(fileLoc, "expected include filename to end with `.pdll`"); 611 } 612 613 //===----------------------------------------------------------------------===// 614 // Decls 615 616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 617 FailureOr<ast::Decl *> decl; 618 switch (curToken.getKind()) { 619 case Token::kw_Constraint: 620 decl = parseUserConstraintDecl(); 621 break; 622 case Token::kw_Pattern: 623 decl = parsePatternDecl(); 624 break; 625 case Token::kw_Rewrite: 626 decl = parseUserRewriteDecl(); 627 break; 628 default: 629 return emitError("expected top-level declaration, such as a `Pattern`"); 630 } 631 if (failed(decl)) 632 return failure(); 633 634 // If the decl has a name, add it to the current scope. 635 if (const ast::Name *name = (*decl)->getName()) { 636 if (failed(checkDefineNamedDecl(*name))) 637 return failure(); 638 curDeclScope->add(*decl); 639 } 640 return decl; 641 } 642 643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 644 std::string attrNameStr; 645 if (curToken.isString()) 646 attrNameStr = curToken.getStringValue(); 647 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 648 attrNameStr = curToken.getSpelling().str(); 649 else 650 return emitError("expected identifier or string attribute name"); 651 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 652 consumeToken(); 653 654 // Check for a value of the attribute. 655 ast::Expr *attrValue = nullptr; 656 if (consumeIf(Token::equal)) { 657 FailureOr<ast::Expr *> attrExpr = parseExpr(); 658 if (failed(attrExpr)) 659 return failure(); 660 attrValue = *attrExpr; 661 } else { 662 // If there isn't a concrete value, create an expression representing a 663 // UnitAttr. 664 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 665 } 666 667 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 668 } 669 670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 671 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 672 bool expectTerminalSemicolon) { 673 consumeToken(Token::equal_arrow); 674 675 // Parse the single statement of the lambda body. 676 SMLoc bodyStartLoc = curToken.getStartLoc(); 677 pushDeclScope(); 678 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 679 bool failedToParse = 680 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 681 popDeclScope(); 682 if (failedToParse) 683 return failure(); 684 685 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 686 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 687 } 688 689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 690 // Ensure that the argument is named. 691 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 692 return emitError("expected identifier argument name"); 693 694 // Parse the argument similarly to a normal variable. 695 StringRef name = curToken.getSpelling(); 696 SMRange nameLoc = curToken.getLoc(); 697 consumeToken(); 698 699 if (failed( 700 parseToken(Token::colon, "expected `:` before argument constraint"))) 701 return failure(); 702 703 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 704 if (failed(cst)) 705 return failure(); 706 707 return createArgOrResultVariableDecl(name, nameLoc, *cst); 708 } 709 710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 711 // Check to see if this result is named. 712 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 713 // Check to see if this name actually refers to a Constraint. 714 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 715 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 716 // If yes, and this is a Rewrite, give a nice error message as non-Core 717 // constraints are not supported on Rewrite results. 718 if (parserContext == ParserContext::Rewrite) { 719 return emitError( 720 "`Rewrite` results are only permitted to use core constraints, " 721 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 722 } 723 724 // Otherwise, parse this as an unnamed result variable. 725 } else { 726 // If it wasn't a constraint, parse the result similarly to a variable. If 727 // there is already an existing decl, we will emit an error when defining 728 // this variable later. 729 StringRef name = curToken.getSpelling(); 730 SMRange nameLoc = curToken.getLoc(); 731 consumeToken(); 732 733 if (failed(parseToken(Token::colon, 734 "expected `:` before result constraint"))) 735 return failure(); 736 737 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 738 if (failed(cst)) 739 return failure(); 740 741 return createArgOrResultVariableDecl(name, nameLoc, *cst); 742 } 743 } 744 745 // If it isn't named, we parse the constraint directly and create an unnamed 746 // result variable. 747 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 748 if (failed(cst)) 749 return failure(); 750 751 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 752 } 753 754 FailureOr<ast::UserConstraintDecl *> 755 Parser::parseUserConstraintDecl(bool isInline) { 756 // Constraints and rewrites have very similar formats, dispatch to a shared 757 // interface for parsing. 758 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 759 [&](auto &&...args) { return parseUserPDLLConstraintDecl(args...); }, 760 ParserContext::Constraint, "constraint", isInline); 761 } 762 763 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 764 FailureOr<ast::UserConstraintDecl *> decl = 765 parseUserConstraintDecl(/*isInline=*/true); 766 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 767 return failure(); 768 769 curDeclScope->add(*decl); 770 return decl; 771 } 772 773 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 774 const ast::Name &name, bool isInline, 775 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 776 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 777 // Push the argument scope back onto the list, so that the body can 778 // reference arguments. 779 pushDeclScope(argumentScope); 780 781 // Parse the body of the constraint. The body is either defined as a compound 782 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 783 ast::CompoundStmt *body; 784 if (curToken.is(Token::equal_arrow)) { 785 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 786 [&](ast::Stmt *&stmt) -> LogicalResult { 787 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 788 if (!stmtExpr) { 789 return emitError(stmt->getLoc(), 790 "expected `Constraint` lambda body to contain a " 791 "single expression"); 792 } 793 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 794 return success(); 795 }, 796 /*expectTerminalSemicolon=*/!isInline); 797 if (failed(bodyResult)) 798 return failure(); 799 body = *bodyResult; 800 } else { 801 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 802 if (failed(bodyResult)) 803 return failure(); 804 body = *bodyResult; 805 806 // Verify the structure of the body. 807 auto bodyIt = body->begin(), bodyE = body->end(); 808 for (; bodyIt != bodyE; ++bodyIt) 809 if (isa<ast::ReturnStmt>(*bodyIt)) 810 break; 811 if (failed(validateUserConstraintOrRewriteReturn( 812 "Constraint", body, bodyIt, bodyE, results, resultType))) 813 return failure(); 814 } 815 popDeclScope(); 816 817 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 818 name, arguments, results, resultType, body); 819 } 820 821 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 822 // Constraints and rewrites have very similar formats, dispatch to a shared 823 // interface for parsing. 824 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 825 [&](auto &&...args) { return parseUserPDLLRewriteDecl(args...); }, 826 ParserContext::Rewrite, "rewrite", isInline); 827 } 828 829 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 830 FailureOr<ast::UserRewriteDecl *> decl = 831 parseUserRewriteDecl(/*isInline=*/true); 832 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 833 return failure(); 834 835 curDeclScope->add(*decl); 836 return decl; 837 } 838 839 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 840 const ast::Name &name, bool isInline, 841 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 842 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 843 // Push the argument scope back onto the list, so that the body can 844 // reference arguments. 845 curDeclScope = argumentScope; 846 ast::CompoundStmt *body; 847 if (curToken.is(Token::equal_arrow)) { 848 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 849 [&](ast::Stmt *&statement) -> LogicalResult { 850 if (isa<ast::OpRewriteStmt>(statement)) 851 return success(); 852 853 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 854 if (!statementExpr) { 855 return emitError( 856 statement->getLoc(), 857 "expected `Rewrite` lambda body to contain a single expression " 858 "or an operation rewrite statement; such as `erase`, " 859 "`replace`, or `rewrite`"); 860 } 861 statement = 862 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 863 return success(); 864 }, 865 /*expectTerminalSemicolon=*/!isInline); 866 if (failed(bodyResult)) 867 return failure(); 868 body = *bodyResult; 869 } else { 870 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 871 if (failed(bodyResult)) 872 return failure(); 873 body = *bodyResult; 874 } 875 popDeclScope(); 876 877 // Verify the structure of the body. 878 auto bodyIt = body->begin(), bodyE = body->end(); 879 for (; bodyIt != bodyE; ++bodyIt) 880 if (isa<ast::ReturnStmt>(*bodyIt)) 881 break; 882 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 883 bodyE, results, resultType))) 884 return failure(); 885 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 886 name, arguments, results, resultType, body); 887 } 888 889 template <typename T, typename ParseUserPDLLDeclFnT> 890 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 891 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 892 StringRef anonymousNamePrefix, bool isInline) { 893 SMRange loc = curToken.getLoc(); 894 consumeToken(); 895 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 896 897 // Parse the name of the decl. 898 const ast::Name *name = nullptr; 899 if (curToken.isNot(Token::identifier)) { 900 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 901 // in C++, so being unnamed is fine. 902 if (!isInline) 903 return emitError("expected identifier name"); 904 905 // Create a unique anonymous name to use, as the name for this decl is not 906 // important. 907 std::string anonName = 908 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 909 anonymousDeclNameCounter++) 910 .str(); 911 name = &ast::Name::create(ctx, anonName, loc); 912 } else { 913 // If a name was provided, we can use it directly. 914 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 915 consumeToken(Token::identifier); 916 } 917 918 // Parse the functional signature of the decl. 919 SmallVector<ast::VariableDecl *> arguments, results; 920 ast::DeclScope *argumentScope; 921 ast::Type resultType; 922 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 923 argumentScope, resultType))) 924 return failure(); 925 926 // Check to see which type of constraint this is. If the constraint contains a 927 // compound body, this is a PDLL decl. 928 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 929 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 930 resultType); 931 932 // Otherwise, this is a native decl. 933 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 934 results, resultType); 935 } 936 937 template <typename T> 938 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 939 const ast::Name &name, bool isInline, 940 ArrayRef<ast::VariableDecl *> arguments, 941 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 942 // If followed by a string, the native code body has also been specified. 943 std::string codeStrStorage; 944 Optional<StringRef> optCodeStr; 945 if (curToken.isString()) { 946 codeStrStorage = curToken.getStringValue(); 947 optCodeStr = codeStrStorage; 948 consumeToken(); 949 } else if (isInline) { 950 return emitError(name.getLoc(), 951 "external declarations must be declared in global scope"); 952 } 953 if (failed(parseToken(Token::semicolon, 954 "expected `;` after native declaration"))) 955 return failure(); 956 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 957 } 958 959 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 960 SmallVectorImpl<ast::VariableDecl *> &arguments, 961 SmallVectorImpl<ast::VariableDecl *> &results, 962 ast::DeclScope *&argumentScope, ast::Type &resultType) { 963 // Parse the argument list of the decl. 964 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 965 return failure(); 966 967 argumentScope = pushDeclScope(); 968 if (curToken.isNot(Token::r_paren)) { 969 do { 970 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 971 if (failed(argument)) 972 return failure(); 973 arguments.emplace_back(*argument); 974 } while (consumeIf(Token::comma)); 975 } 976 popDeclScope(); 977 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 978 return failure(); 979 980 // Parse the results of the decl. 981 pushDeclScope(); 982 if (consumeIf(Token::arrow)) { 983 auto parseResultFn = [&]() -> LogicalResult { 984 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 985 if (failed(result)) 986 return failure(); 987 results.emplace_back(*result); 988 return success(); 989 }; 990 991 // Check for a list of results. 992 if (consumeIf(Token::l_paren)) { 993 do { 994 if (failed(parseResultFn())) 995 return failure(); 996 } while (consumeIf(Token::comma)); 997 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 998 return failure(); 999 1000 // Otherwise, there is only one result. 1001 } else if (failed(parseResultFn())) { 1002 return failure(); 1003 } 1004 } 1005 popDeclScope(); 1006 1007 // Compute the result type of the decl. 1008 resultType = createUserConstraintRewriteResultType(results); 1009 1010 // Verify that results are only named if there are more than one. 1011 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1012 return emitError( 1013 results.front()->getLoc(), 1014 "cannot create a single-element tuple with an element label"); 1015 } 1016 return success(); 1017 } 1018 1019 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1020 StringRef declType, ast::CompoundStmt *body, 1021 ArrayRef<ast::Stmt *>::iterator bodyIt, 1022 ArrayRef<ast::Stmt *>::iterator bodyE, 1023 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1024 // Handle if a `return` was provided. 1025 if (bodyIt != bodyE) { 1026 // Emit an error if we have trailing statements after the return. 1027 if (std::next(bodyIt) != bodyE) { 1028 return emitError( 1029 (*std::next(bodyIt))->getLoc(), 1030 llvm::formatv("`return` terminated the `{0}` body, but found " 1031 "trailing statements afterwards", 1032 declType)); 1033 } 1034 1035 // Otherwise if a return wasn't provided, check that no results are 1036 // expected. 1037 } else if (!results.empty()) { 1038 return emitError( 1039 {body->getLoc().End, body->getLoc().End}, 1040 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1041 declType, resultType)); 1042 } 1043 return success(); 1044 } 1045 1046 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1047 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1048 if (isa<ast::OpRewriteStmt>(statement)) 1049 return success(); 1050 return emitError( 1051 statement->getLoc(), 1052 "expected Pattern lambda body to contain a single operation " 1053 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1054 }); 1055 } 1056 1057 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1058 SMRange loc = curToken.getLoc(); 1059 consumeToken(Token::kw_Pattern); 1060 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1061 ParserContext::PatternMatch); 1062 1063 // Check for an optional identifier for the pattern name. 1064 const ast::Name *name = nullptr; 1065 if (curToken.is(Token::identifier)) { 1066 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1067 consumeToken(Token::identifier); 1068 } 1069 1070 // Parse any pattern metadata. 1071 ParsedPatternMetadata metadata; 1072 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1073 return failure(); 1074 1075 // Parse the pattern body. 1076 ast::CompoundStmt *body; 1077 1078 // Handle a lambda body. 1079 if (curToken.is(Token::equal_arrow)) { 1080 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1081 if (failed(bodyResult)) 1082 return failure(); 1083 body = *bodyResult; 1084 } else { 1085 if (curToken.isNot(Token::l_brace)) 1086 return emitError("expected `{` or `=>` to start pattern body"); 1087 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1088 if (failed(bodyResult)) 1089 return failure(); 1090 body = *bodyResult; 1091 1092 // Verify the body of the pattern. 1093 auto bodyIt = body->begin(), bodyE = body->end(); 1094 for (; bodyIt != bodyE; ++bodyIt) { 1095 if (isa<ast::ReturnStmt>(*bodyIt)) { 1096 return emitError((*bodyIt)->getLoc(), 1097 "`return` statements are only permitted within a " 1098 "`Constraint` or `Rewrite` body"); 1099 } 1100 // Break when we've found the rewrite statement. 1101 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1102 break; 1103 } 1104 if (bodyIt == bodyE) { 1105 return emitError(loc, 1106 "expected Pattern body to terminate with an operation " 1107 "rewrite statement, such as `erase`"); 1108 } 1109 if (std::next(bodyIt) != bodyE) { 1110 return emitError((*std::next(bodyIt))->getLoc(), 1111 "Pattern body was terminated by an operation " 1112 "rewrite statement, but found trailing statements"); 1113 } 1114 } 1115 1116 return createPatternDecl(loc, name, metadata, body); 1117 } 1118 1119 LogicalResult 1120 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1121 Optional<SMRange> benefitLoc; 1122 Optional<SMRange> hasBoundedRecursionLoc; 1123 1124 do { 1125 if (curToken.isNot(Token::identifier)) 1126 return emitError("expected pattern metadata identifier"); 1127 StringRef metadataStr = curToken.getSpelling(); 1128 SMRange metadataLoc = curToken.getLoc(); 1129 consumeToken(Token::identifier); 1130 1131 // Parse the benefit metadata: benefit(<integer-value>) 1132 if (metadataStr == "benefit") { 1133 if (benefitLoc) { 1134 return emitErrorAndNote(metadataLoc, 1135 "pattern benefit has already been specified", 1136 *benefitLoc, "see previous definition here"); 1137 } 1138 if (failed(parseToken(Token::l_paren, 1139 "expected `(` before pattern benefit"))) 1140 return failure(); 1141 1142 uint16_t benefitValue = 0; 1143 if (curToken.isNot(Token::integer)) 1144 return emitError("expected integral pattern benefit"); 1145 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1146 return emitError( 1147 "expected pattern benefit to fit within a 16-bit integer"); 1148 consumeToken(Token::integer); 1149 1150 metadata.benefit = benefitValue; 1151 benefitLoc = metadataLoc; 1152 1153 if (failed( 1154 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1155 return failure(); 1156 continue; 1157 } 1158 1159 // Parse the bounded recursion metadata: recursion 1160 if (metadataStr == "recursion") { 1161 if (hasBoundedRecursionLoc) { 1162 return emitErrorAndNote( 1163 metadataLoc, 1164 "pattern recursion metadata has already been specified", 1165 *hasBoundedRecursionLoc, "see previous definition here"); 1166 } 1167 metadata.hasBoundedRecursion = true; 1168 hasBoundedRecursionLoc = metadataLoc; 1169 continue; 1170 } 1171 1172 return emitError(metadataLoc, "unknown pattern metadata"); 1173 } while (consumeIf(Token::comma)); 1174 1175 return success(); 1176 } 1177 1178 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1179 consumeToken(Token::less); 1180 1181 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1182 if (failed(typeExpr) || 1183 failed(parseToken(Token::greater, 1184 "expected `>` after variable type constraint"))) 1185 return failure(); 1186 return typeExpr; 1187 } 1188 1189 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1190 assert(curDeclScope && "defining decl outside of a decl scope"); 1191 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1192 return emitErrorAndNote( 1193 name.getLoc(), "`" + name.getName() + "` has already been defined", 1194 lastDecl->getName()->getLoc(), "see previous definition here"); 1195 } 1196 return success(); 1197 } 1198 1199 FailureOr<ast::VariableDecl *> 1200 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1201 ast::Expr *initExpr, 1202 ArrayRef<ast::ConstraintRef> constraints) { 1203 assert(curDeclScope && "defining variable outside of decl scope"); 1204 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1205 1206 // If the name of the variable indicates a special variable, we don't add it 1207 // to the scope. This variable is local to the definition point. 1208 if (name.empty() || name == "_") { 1209 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1210 constraints); 1211 } 1212 if (failed(checkDefineNamedDecl(nameDecl))) 1213 return failure(); 1214 1215 auto *varDecl = 1216 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1217 curDeclScope->add(varDecl); 1218 return varDecl; 1219 } 1220 1221 FailureOr<ast::VariableDecl *> 1222 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1223 ArrayRef<ast::ConstraintRef> constraints) { 1224 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1225 constraints); 1226 } 1227 1228 LogicalResult Parser::parseVariableDeclConstraintList( 1229 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1230 Optional<SMRange> typeConstraint; 1231 auto parseSingleConstraint = [&] { 1232 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1233 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); 1234 if (failed(constraint)) 1235 return failure(); 1236 constraints.push_back(*constraint); 1237 return success(); 1238 }; 1239 1240 // Check to see if this is a single constraint, or a list. 1241 if (!consumeIf(Token::l_square)) 1242 return parseSingleConstraint(); 1243 1244 do { 1245 if (failed(parseSingleConstraint())) 1246 return failure(); 1247 } while (consumeIf(Token::comma)); 1248 return parseToken(Token::r_square, "expected `]` after constraint list"); 1249 } 1250 1251 FailureOr<ast::ConstraintRef> 1252 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1253 ArrayRef<ast::ConstraintRef> existingConstraints, 1254 bool allowInlineTypeConstraints) { 1255 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1256 if (!allowInlineTypeConstraints) { 1257 return emitError( 1258 curToken.getLoc(), 1259 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1260 "permitted on arguments or results"); 1261 } 1262 if (typeConstraint) 1263 return emitErrorAndNote( 1264 curToken.getLoc(), 1265 "the type of this variable has already been constrained", 1266 *typeConstraint, "see previous constraint location here"); 1267 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1268 if (failed(constraintExpr)) 1269 return failure(); 1270 typeExpr = *constraintExpr; 1271 typeConstraint = typeExpr->getLoc(); 1272 return success(); 1273 }; 1274 1275 SMRange loc = curToken.getLoc(); 1276 switch (curToken.getKind()) { 1277 case Token::kw_Attr: { 1278 consumeToken(Token::kw_Attr); 1279 1280 // Check for a type constraint. 1281 ast::Expr *typeExpr = nullptr; 1282 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1283 return failure(); 1284 return ast::ConstraintRef( 1285 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1286 } 1287 case Token::kw_Op: { 1288 consumeToken(Token::kw_Op); 1289 1290 // Parse an optional operation name. If the name isn't provided, this refers 1291 // to "any" operation. 1292 FailureOr<ast::OpNameDecl *> opName = 1293 parseWrappedOperationName(/*allowEmptyName=*/true); 1294 if (failed(opName)) 1295 return failure(); 1296 1297 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1298 loc); 1299 } 1300 case Token::kw_Type: 1301 consumeToken(Token::kw_Type); 1302 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1303 case Token::kw_TypeRange: 1304 consumeToken(Token::kw_TypeRange); 1305 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1306 loc); 1307 case Token::kw_Value: { 1308 consumeToken(Token::kw_Value); 1309 1310 // Check for a type constraint. 1311 ast::Expr *typeExpr = nullptr; 1312 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1313 return failure(); 1314 1315 return ast::ConstraintRef( 1316 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1317 } 1318 case Token::kw_ValueRange: { 1319 consumeToken(Token::kw_ValueRange); 1320 1321 // Check for a type constraint. 1322 ast::Expr *typeExpr = nullptr; 1323 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1324 return failure(); 1325 1326 return ast::ConstraintRef( 1327 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1328 } 1329 1330 case Token::kw_Constraint: { 1331 // Handle an inline constraint. 1332 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1333 if (failed(decl)) 1334 return failure(); 1335 return ast::ConstraintRef(*decl, loc); 1336 } 1337 case Token::identifier: { 1338 StringRef constraintName = curToken.getSpelling(); 1339 consumeToken(Token::identifier); 1340 1341 // Lookup the referenced constraint. 1342 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1343 if (!cstDecl) { 1344 return emitError(loc, "unknown reference to constraint `" + 1345 constraintName + "`"); 1346 } 1347 1348 // Handle a reference to a proper constraint. 1349 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1350 return ast::ConstraintRef(cst, loc); 1351 1352 return emitErrorAndNote( 1353 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1354 "see the definition of `" + constraintName + "` here"); 1355 } 1356 default: 1357 break; 1358 } 1359 return emitError(loc, "expected identifier constraint"); 1360 } 1361 1362 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1363 Optional<SMRange> typeConstraint; 1364 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1365 /*allowInlineTypeConstraints=*/false); 1366 } 1367 1368 //===----------------------------------------------------------------------===// 1369 // Exprs 1370 1371 FailureOr<ast::Expr *> Parser::parseExpr() { 1372 if (curToken.is(Token::underscore)) 1373 return parseUnderscoreExpr(); 1374 1375 // Parse the LHS expression. 1376 FailureOr<ast::Expr *> lhsExpr; 1377 switch (curToken.getKind()) { 1378 case Token::kw_attr: 1379 lhsExpr = parseAttributeExpr(); 1380 break; 1381 case Token::kw_Constraint: 1382 lhsExpr = parseInlineConstraintLambdaExpr(); 1383 break; 1384 case Token::identifier: 1385 lhsExpr = parseIdentifierExpr(); 1386 break; 1387 case Token::kw_op: 1388 lhsExpr = parseOperationExpr(); 1389 break; 1390 case Token::kw_Rewrite: 1391 lhsExpr = parseInlineRewriteLambdaExpr(); 1392 break; 1393 case Token::kw_type: 1394 lhsExpr = parseTypeExpr(); 1395 break; 1396 case Token::l_paren: 1397 lhsExpr = parseTupleExpr(); 1398 break; 1399 default: 1400 return emitError("expected expression"); 1401 } 1402 if (failed(lhsExpr)) 1403 return failure(); 1404 1405 // Check for an operator expression. 1406 while (true) { 1407 switch (curToken.getKind()) { 1408 case Token::dot: 1409 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1410 break; 1411 case Token::l_paren: 1412 lhsExpr = parseCallExpr(*lhsExpr); 1413 break; 1414 default: 1415 return lhsExpr; 1416 } 1417 if (failed(lhsExpr)) 1418 return failure(); 1419 } 1420 } 1421 1422 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1423 SMRange loc = curToken.getLoc(); 1424 consumeToken(Token::kw_attr); 1425 1426 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1427 // identifier. 1428 if (!consumeIf(Token::less)) { 1429 resetToken(loc); 1430 return parseIdentifierExpr(); 1431 } 1432 1433 if (!curToken.isString()) 1434 return emitError("expected string literal containing MLIR attribute"); 1435 std::string attrExpr = curToken.getStringValue(); 1436 consumeToken(); 1437 1438 if (failed( 1439 parseToken(Token::greater, "expected `>` after attribute literal"))) 1440 return failure(); 1441 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1442 } 1443 1444 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1445 SMRange loc = curToken.getLoc(); 1446 consumeToken(Token::l_paren); 1447 1448 // Parse the arguments of the call. 1449 SmallVector<ast::Expr *> arguments; 1450 if (curToken.isNot(Token::r_paren)) { 1451 do { 1452 FailureOr<ast::Expr *> argument = parseExpr(); 1453 if (failed(argument)) 1454 return failure(); 1455 arguments.push_back(*argument); 1456 } while (consumeIf(Token::comma)); 1457 } 1458 loc.End = curToken.getEndLoc(); 1459 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1460 return failure(); 1461 1462 return createCallExpr(loc, parentExpr, arguments); 1463 } 1464 1465 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1466 ast::Decl *decl = curDeclScope->lookup(name); 1467 if (!decl) 1468 return emitError(loc, "undefined reference to `" + name + "`"); 1469 1470 return createDeclRefExpr(loc, decl); 1471 } 1472 1473 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1474 StringRef name = curToken.getSpelling(); 1475 SMRange nameLoc = curToken.getLoc(); 1476 consumeToken(); 1477 1478 // Check to see if this is a decl ref expression that defines a variable 1479 // inline. 1480 if (consumeIf(Token::colon)) { 1481 SmallVector<ast::ConstraintRef> constraints; 1482 if (failed(parseVariableDeclConstraintList(constraints))) 1483 return failure(); 1484 ast::Type type; 1485 if (failed(validateVariableConstraints(constraints, type))) 1486 return failure(); 1487 return createInlineVariableExpr(type, name, nameLoc, constraints); 1488 } 1489 1490 return parseDeclRefExpr(name, nameLoc); 1491 } 1492 1493 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1494 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1495 if (failed(decl)) 1496 return failure(); 1497 1498 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1499 ast::ConstraintType::get(ctx)); 1500 } 1501 1502 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1503 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1504 if (failed(decl)) 1505 return failure(); 1506 1507 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1508 ast::RewriteType::get(ctx)); 1509 } 1510 1511 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1512 SMRange loc = curToken.getLoc(); 1513 consumeToken(Token::dot); 1514 1515 // Parse the member name. 1516 Token memberNameTok = curToken; 1517 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1518 !memberNameTok.isKeyword()) 1519 return emitError(loc, "expected identifier or numeric member name"); 1520 StringRef memberName = memberNameTok.getSpelling(); 1521 consumeToken(); 1522 1523 return createMemberAccessExpr(parentExpr, memberName, loc); 1524 } 1525 1526 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1527 SMRange loc = curToken.getLoc(); 1528 1529 // Handle the case of an no operation name. 1530 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1531 if (allowEmptyName) 1532 return ast::OpNameDecl::create(ctx, SMRange()); 1533 return emitError("expected dialect namespace"); 1534 } 1535 StringRef name = curToken.getSpelling(); 1536 consumeToken(); 1537 1538 // Otherwise, this is a literal operation name. 1539 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1540 return failure(); 1541 1542 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1543 return emitError("expected operation name after dialect namespace"); 1544 1545 name = StringRef(name.data(), name.size() + 1); 1546 do { 1547 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1548 loc.End = curToken.getEndLoc(); 1549 consumeToken(); 1550 } while (curToken.isAny(Token::identifier, Token::dot) || 1551 curToken.isKeyword()); 1552 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1553 } 1554 1555 FailureOr<ast::OpNameDecl *> 1556 Parser::parseWrappedOperationName(bool allowEmptyName) { 1557 if (!consumeIf(Token::less)) 1558 return ast::OpNameDecl::create(ctx, SMRange()); 1559 1560 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1561 if (failed(opNameDecl)) 1562 return failure(); 1563 1564 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1565 return failure(); 1566 return opNameDecl; 1567 } 1568 1569 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1570 SMRange loc = curToken.getLoc(); 1571 consumeToken(Token::kw_op); 1572 1573 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1574 // identifier. 1575 if (curToken.isNot(Token::less)) { 1576 resetToken(loc); 1577 return parseIdentifierExpr(); 1578 } 1579 1580 // Parse the operation name. The name may be elided, in which case the 1581 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1582 // `Operation*`). Operation names within a rewrite context must be named. 1583 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1584 FailureOr<ast::OpNameDecl *> opNameDecl = 1585 parseWrappedOperationName(allowEmptyName); 1586 if (failed(opNameDecl)) 1587 return failure(); 1588 1589 // Check for the optional list of operands. 1590 SmallVector<ast::Expr *> operands; 1591 if (consumeIf(Token::l_paren)) { 1592 do { 1593 FailureOr<ast::Expr *> operand = parseExpr(); 1594 if (failed(operand)) 1595 return failure(); 1596 operands.push_back(*operand); 1597 } while (consumeIf(Token::comma)); 1598 1599 if (failed(parseToken(Token::r_paren, 1600 "expected `)` after operation operand list"))) 1601 return failure(); 1602 } 1603 1604 // Check for the optional list of attributes. 1605 SmallVector<ast::NamedAttributeDecl *> attributes; 1606 if (consumeIf(Token::l_brace)) { 1607 do { 1608 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1609 if (failed(decl)) 1610 return failure(); 1611 attributes.emplace_back(*decl); 1612 } while (consumeIf(Token::comma)); 1613 1614 if (failed(parseToken(Token::r_brace, 1615 "expected `}` after operation attribute list"))) 1616 return failure(); 1617 } 1618 1619 // Check for the optional list of result types. 1620 SmallVector<ast::Expr *> resultTypes; 1621 if (consumeIf(Token::arrow)) { 1622 if (failed(parseToken(Token::l_paren, 1623 "expected `(` before operation result type list"))) 1624 return failure(); 1625 1626 do { 1627 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1628 if (failed(resultTypeExpr)) 1629 return failure(); 1630 resultTypes.push_back(*resultTypeExpr); 1631 } while (consumeIf(Token::comma)); 1632 1633 if (failed(parseToken(Token::r_paren, 1634 "expected `)` after operation result type list"))) 1635 return failure(); 1636 } 1637 1638 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1639 resultTypes); 1640 } 1641 1642 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1643 SMRange loc = curToken.getLoc(); 1644 consumeToken(Token::l_paren); 1645 1646 DenseMap<StringRef, SMRange> usedNames; 1647 SmallVector<StringRef> elementNames; 1648 SmallVector<ast::Expr *> elements; 1649 if (curToken.isNot(Token::r_paren)) { 1650 do { 1651 // Check for the optional element name assignment before the value. 1652 StringRef elementName; 1653 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1654 Token elementNameTok = curToken; 1655 consumeToken(); 1656 1657 // The element name is only present if followed by an `=`. 1658 if (consumeIf(Token::equal)) { 1659 elementName = elementNameTok.getSpelling(); 1660 1661 // Check to see if this name is already used. 1662 auto elementNameIt = 1663 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1664 if (!elementNameIt.second) { 1665 return emitErrorAndNote( 1666 elementNameTok.getLoc(), 1667 llvm::formatv("duplicate tuple element label `{0}`", 1668 elementName), 1669 elementNameIt.first->getSecond(), 1670 "see previous label use here"); 1671 } 1672 } else { 1673 // Otherwise, we treat this as part of an expression so reset the 1674 // lexer. 1675 resetToken(elementNameTok.getLoc()); 1676 } 1677 } 1678 elementNames.push_back(elementName); 1679 1680 // Parse the tuple element value. 1681 FailureOr<ast::Expr *> element = parseExpr(); 1682 if (failed(element)) 1683 return failure(); 1684 elements.push_back(*element); 1685 } while (consumeIf(Token::comma)); 1686 } 1687 loc.End = curToken.getEndLoc(); 1688 if (failed( 1689 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1690 return failure(); 1691 return createTupleExpr(loc, elements, elementNames); 1692 } 1693 1694 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1695 SMRange loc = curToken.getLoc(); 1696 consumeToken(Token::kw_type); 1697 1698 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1699 // identifier. 1700 if (!consumeIf(Token::less)) { 1701 resetToken(loc); 1702 return parseIdentifierExpr(); 1703 } 1704 1705 if (!curToken.isString()) 1706 return emitError("expected string literal containing MLIR type"); 1707 std::string attrExpr = curToken.getStringValue(); 1708 consumeToken(); 1709 1710 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1711 return failure(); 1712 return ast::TypeExpr::create(ctx, loc, attrExpr); 1713 } 1714 1715 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 1716 StringRef name = curToken.getSpelling(); 1717 SMRange nameLoc = curToken.getLoc(); 1718 consumeToken(Token::underscore); 1719 1720 // Underscore expressions require a constraint list. 1721 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 1722 return failure(); 1723 1724 // Parse the constraints for the expression. 1725 SmallVector<ast::ConstraintRef> constraints; 1726 if (failed(parseVariableDeclConstraintList(constraints))) 1727 return failure(); 1728 1729 ast::Type type; 1730 if (failed(validateVariableConstraints(constraints, type))) 1731 return failure(); 1732 return createInlineVariableExpr(type, name, nameLoc, constraints); 1733 } 1734 1735 //===----------------------------------------------------------------------===// 1736 // Stmts 1737 1738 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 1739 FailureOr<ast::Stmt *> stmt; 1740 switch (curToken.getKind()) { 1741 case Token::kw_erase: 1742 stmt = parseEraseStmt(); 1743 break; 1744 case Token::kw_let: 1745 stmt = parseLetStmt(); 1746 break; 1747 case Token::kw_replace: 1748 stmt = parseReplaceStmt(); 1749 break; 1750 case Token::kw_return: 1751 stmt = parseReturnStmt(); 1752 break; 1753 case Token::kw_rewrite: 1754 stmt = parseRewriteStmt(); 1755 break; 1756 default: 1757 stmt = parseExpr(); 1758 break; 1759 } 1760 if (failed(stmt) || 1761 (expectTerminalSemicolon && 1762 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 1763 return failure(); 1764 return stmt; 1765 } 1766 1767 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 1768 SMLoc startLoc = curToken.getStartLoc(); 1769 consumeToken(Token::l_brace); 1770 1771 // Push a new block scope and parse any nested statements. 1772 pushDeclScope(); 1773 SmallVector<ast::Stmt *> statements; 1774 while (curToken.isNot(Token::r_brace)) { 1775 FailureOr<ast::Stmt *> statement = parseStmt(); 1776 if (failed(statement)) 1777 return popDeclScope(), failure(); 1778 statements.push_back(*statement); 1779 } 1780 popDeclScope(); 1781 1782 // Consume the end brace. 1783 SMRange location(startLoc, curToken.getEndLoc()); 1784 consumeToken(Token::r_brace); 1785 1786 return ast::CompoundStmt::create(ctx, location, statements); 1787 } 1788 1789 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 1790 if (parserContext == ParserContext::Constraint) 1791 return emitError("`erase` cannot be used within a Constraint"); 1792 SMRange loc = curToken.getLoc(); 1793 consumeToken(Token::kw_erase); 1794 1795 // Parse the root operation expression. 1796 FailureOr<ast::Expr *> rootOp = parseExpr(); 1797 if (failed(rootOp)) 1798 return failure(); 1799 1800 return createEraseStmt(loc, *rootOp); 1801 } 1802 1803 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 1804 SMRange loc = curToken.getLoc(); 1805 consumeToken(Token::kw_let); 1806 1807 // Parse the name of the new variable. 1808 SMRange varLoc = curToken.getLoc(); 1809 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 1810 // `_` is a reserved variable name. 1811 if (curToken.is(Token::underscore)) { 1812 return emitError(varLoc, 1813 "`_` may only be used to define \"inline\" variables"); 1814 } 1815 return emitError(varLoc, 1816 "expected identifier after `let` to name a new variable"); 1817 } 1818 StringRef varName = curToken.getSpelling(); 1819 consumeToken(); 1820 1821 // Parse the optional set of constraints. 1822 SmallVector<ast::ConstraintRef> constraints; 1823 if (consumeIf(Token::colon) && 1824 failed(parseVariableDeclConstraintList(constraints))) 1825 return failure(); 1826 1827 // Parse the optional initializer expression. 1828 ast::Expr *initializer = nullptr; 1829 if (consumeIf(Token::equal)) { 1830 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 1831 if (failed(initOrFailure)) 1832 return failure(); 1833 initializer = *initOrFailure; 1834 1835 // Check that the constraints are compatible with having an initializer, 1836 // e.g. type constraints cannot be used with initializers. 1837 for (ast::ConstraintRef constraint : constraints) { 1838 LogicalResult result = 1839 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 1840 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 1841 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 1842 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 1843 return this->emitError( 1844 constraint.referenceLoc, 1845 "type constraints are not permitted on variables with " 1846 "initializers"); 1847 } 1848 return success(); 1849 }) 1850 .Default(success()); 1851 if (failed(result)) 1852 return failure(); 1853 } 1854 } 1855 1856 FailureOr<ast::VariableDecl *> varDecl = 1857 createVariableDecl(varName, varLoc, initializer, constraints); 1858 if (failed(varDecl)) 1859 return failure(); 1860 return ast::LetStmt::create(ctx, loc, *varDecl); 1861 } 1862 1863 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 1864 if (parserContext == ParserContext::Constraint) 1865 return emitError("`replace` cannot be used within a Constraint"); 1866 SMRange loc = curToken.getLoc(); 1867 consumeToken(Token::kw_replace); 1868 1869 // Parse the root operation expression. 1870 FailureOr<ast::Expr *> rootOp = parseExpr(); 1871 if (failed(rootOp)) 1872 return failure(); 1873 1874 if (failed( 1875 parseToken(Token::kw_with, "expected `with` after root operation"))) 1876 return failure(); 1877 1878 // The replacement portion of this statement is within a rewrite context. 1879 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1880 ParserContext::Rewrite); 1881 1882 // Parse the replacement values. 1883 SmallVector<ast::Expr *> replValues; 1884 if (consumeIf(Token::l_paren)) { 1885 if (consumeIf(Token::r_paren)) { 1886 return emitError( 1887 loc, "expected at least one replacement value, consider using " 1888 "`erase` if no replacement values are desired"); 1889 } 1890 1891 do { 1892 FailureOr<ast::Expr *> replExpr = parseExpr(); 1893 if (failed(replExpr)) 1894 return failure(); 1895 replValues.emplace_back(*replExpr); 1896 } while (consumeIf(Token::comma)); 1897 1898 if (failed(parseToken(Token::r_paren, 1899 "expected `)` after replacement values"))) 1900 return failure(); 1901 } else { 1902 FailureOr<ast::Expr *> replExpr = parseExpr(); 1903 if (failed(replExpr)) 1904 return failure(); 1905 replValues.emplace_back(*replExpr); 1906 } 1907 1908 return createReplaceStmt(loc, *rootOp, replValues); 1909 } 1910 1911 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 1912 SMRange loc = curToken.getLoc(); 1913 consumeToken(Token::kw_return); 1914 1915 // Parse the result value. 1916 FailureOr<ast::Expr *> resultExpr = parseExpr(); 1917 if (failed(resultExpr)) 1918 return failure(); 1919 1920 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 1921 } 1922 1923 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 1924 if (parserContext == ParserContext::Constraint) 1925 return emitError("`rewrite` cannot be used within a Constraint"); 1926 SMRange loc = curToken.getLoc(); 1927 consumeToken(Token::kw_rewrite); 1928 1929 // Parse the root operation. 1930 FailureOr<ast::Expr *> rootOp = parseExpr(); 1931 if (failed(rootOp)) 1932 return failure(); 1933 1934 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 1935 return failure(); 1936 1937 if (curToken.isNot(Token::l_brace)) 1938 return emitError("expected `{` to start rewrite body"); 1939 1940 // The rewrite body of this statement is within a rewrite context. 1941 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1942 ParserContext::Rewrite); 1943 1944 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 1945 if (failed(rewriteBody)) 1946 return failure(); 1947 1948 // Verify the rewrite body. 1949 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 1950 if (isa<ast::ReturnStmt>(stmt)) { 1951 return emitError(stmt->getLoc(), 1952 "`return` statements are only permitted within a " 1953 "`Constraint` or `Rewrite` body"); 1954 } 1955 } 1956 1957 return createRewriteStmt(loc, *rootOp, *rewriteBody); 1958 } 1959 1960 //===----------------------------------------------------------------------===// 1961 // Creation+Analysis 1962 //===----------------------------------------------------------------------===// 1963 1964 //===----------------------------------------------------------------------===// 1965 // Decls 1966 1967 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 1968 // Unwrap reference expressions. 1969 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 1970 node = init->getDecl(); 1971 return dyn_cast<ast::CallableDecl>(node); 1972 } 1973 1974 FailureOr<ast::PatternDecl *> 1975 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 1976 const ParsedPatternMetadata &metadata, 1977 ast::CompoundStmt *body) { 1978 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 1979 metadata.hasBoundedRecursion, body); 1980 } 1981 1982 ast::Type Parser::createUserConstraintRewriteResultType( 1983 ArrayRef<ast::VariableDecl *> results) { 1984 // Single result decls use the type of the single result. 1985 if (results.size() == 1) 1986 return results[0]->getType(); 1987 1988 // Multiple results use a tuple type, with the types and names grabbed from 1989 // the result variable decls. 1990 auto resultTypes = llvm::map_range( 1991 results, [&](const auto *result) { return result->getType(); }); 1992 auto resultNames = llvm::map_range( 1993 results, [&](const auto *result) { return result->getName().getName(); }); 1994 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 1995 llvm::to_vector(resultNames)); 1996 } 1997 1998 template <typename T> 1999 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2000 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2001 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2002 ast::CompoundStmt *body) { 2003 if (!body->getChildren().empty()) { 2004 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2005 ast::Expr *resultExpr = retStmt->getResultExpr(); 2006 2007 // Process the result of the decl. If no explicit signature results 2008 // were provided, check for return type inference. Otherwise, check that 2009 // the return expression can be converted to the expected type. 2010 if (results.empty()) 2011 resultType = resultExpr->getType(); 2012 else if (failed(convertExpressionTo(resultExpr, resultType))) 2013 return failure(); 2014 else 2015 retStmt->setResultExpr(resultExpr); 2016 } 2017 } 2018 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2019 } 2020 2021 FailureOr<ast::VariableDecl *> 2022 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2023 ArrayRef<ast::ConstraintRef> constraints) { 2024 // The type of the variable, which is expected to be inferred by either a 2025 // constraint or an initializer expression. 2026 ast::Type type; 2027 if (failed(validateVariableConstraints(constraints, type))) 2028 return failure(); 2029 2030 if (initializer) { 2031 // Update the variable type based on the initializer, or try to convert the 2032 // initializer to the existing type. 2033 if (!type) 2034 type = initializer->getType(); 2035 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2036 type = mergedType; 2037 else if (failed(convertExpressionTo(initializer, type))) 2038 return failure(); 2039 2040 // Otherwise, if there is no initializer check that the type has already 2041 // been resolved from the constraint list. 2042 } else if (!type) { 2043 return emitErrorAndNote( 2044 loc, "unable to infer type for variable `" + name + "`", loc, 2045 "the type of a variable must be inferable from the constraint " 2046 "list or the initializer"); 2047 } 2048 2049 // Constraint types cannot be used when defining variables. 2050 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2051 return emitError( 2052 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2053 } 2054 2055 // Try to define a variable with the given name. 2056 FailureOr<ast::VariableDecl *> varDecl = 2057 defineVariableDecl(name, loc, type, initializer, constraints); 2058 if (failed(varDecl)) 2059 return failure(); 2060 2061 return *varDecl; 2062 } 2063 2064 FailureOr<ast::VariableDecl *> 2065 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2066 const ast::ConstraintRef &constraint) { 2067 // Constraint arguments may apply more complex constraints via the arguments. 2068 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2069 ast::Type argType; 2070 if (failed(validateVariableConstraint(constraint, argType, 2071 allowNonCoreConstraints))) 2072 return failure(); 2073 return defineVariableDecl(name, loc, argType, constraint); 2074 } 2075 2076 LogicalResult 2077 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2078 ast::Type &inferredType) { 2079 for (const ast::ConstraintRef &ref : constraints) 2080 if (failed(validateVariableConstraint(ref, inferredType))) 2081 return failure(); 2082 return success(); 2083 } 2084 2085 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2086 ast::Type &inferredType, 2087 bool allowNonCoreConstraints) { 2088 ast::Type constraintType; 2089 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2090 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2091 if (failed(validateTypeConstraintExpr(typeExpr))) 2092 return failure(); 2093 } 2094 constraintType = ast::AttributeType::get(ctx); 2095 } else if (const auto *cst = 2096 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2097 constraintType = ast::OperationType::get(ctx, cst->getName()); 2098 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2099 constraintType = typeTy; 2100 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2101 constraintType = typeRangeTy; 2102 } else if (const auto *cst = 2103 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2104 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2105 if (failed(validateTypeConstraintExpr(typeExpr))) 2106 return failure(); 2107 } 2108 constraintType = valueTy; 2109 } else if (const auto *cst = 2110 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2111 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2112 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2113 return failure(); 2114 } 2115 constraintType = valueRangeTy; 2116 } else if (const auto *cst = 2117 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2118 if (!allowNonCoreConstraints) { 2119 return emitError(ref.referenceLoc, 2120 "`Rewrite` arguments and results are only permitted to " 2121 "use core constraints, such as `Attr`, `Op`, `Type`, " 2122 "`TypeRange`, `Value`, `ValueRange`"); 2123 } 2124 2125 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2126 if (inputs.size() != 1) { 2127 return emitErrorAndNote(ref.referenceLoc, 2128 "`Constraint`s applied via a variable constraint " 2129 "list must take a single input, but got " + 2130 Twine(inputs.size()), 2131 cst->getLoc(), 2132 "see definition of constraint here"); 2133 } 2134 constraintType = inputs.front()->getType(); 2135 } else { 2136 llvm_unreachable("unknown constraint type"); 2137 } 2138 2139 // Check that the constraint type is compatible with the current inferred 2140 // type. 2141 if (!inferredType) { 2142 inferredType = constraintType; 2143 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2144 inferredType = mergedTy; 2145 } else { 2146 return emitError(ref.referenceLoc, 2147 llvm::formatv("constraint type `{0}` is incompatible " 2148 "with the previously inferred type `{1}`", 2149 constraintType, inferredType)); 2150 } 2151 return success(); 2152 } 2153 2154 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2155 ast::Type typeExprType = typeExpr->getType(); 2156 if (typeExprType != typeTy) { 2157 return emitError(typeExpr->getLoc(), 2158 "expected expression of `Type` in type constraint"); 2159 } 2160 return success(); 2161 } 2162 2163 LogicalResult 2164 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2165 ast::Type typeExprType = typeExpr->getType(); 2166 if (typeExprType != typeRangeTy) { 2167 return emitError(typeExpr->getLoc(), 2168 "expected expression of `TypeRange` in type constraint"); 2169 } 2170 return success(); 2171 } 2172 2173 //===----------------------------------------------------------------------===// 2174 // Exprs 2175 2176 FailureOr<ast::CallExpr *> 2177 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2178 MutableArrayRef<ast::Expr *> arguments) { 2179 ast::Type parentType = parentExpr->getType(); 2180 2181 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2182 if (!callableDecl) { 2183 return emitError(loc, 2184 llvm::formatv("expected a reference to a callable " 2185 "`Constraint` or `Rewrite`, but got: `{0}`", 2186 parentType)); 2187 } 2188 if (parserContext == ParserContext::Rewrite) { 2189 if (isa<ast::UserConstraintDecl>(callableDecl)) 2190 return emitError( 2191 loc, "unable to invoke `Constraint` within a rewrite section"); 2192 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2193 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2194 } 2195 2196 // Verify the arguments of the call. 2197 /// Handle size mismatch. 2198 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2199 if (callArgs.size() != arguments.size()) { 2200 return emitErrorAndNote( 2201 loc, 2202 llvm::formatv("invalid number of arguments for {0} call; expected " 2203 "{1}, but got {2}", 2204 callableDecl->getCallableType(), callArgs.size(), 2205 arguments.size()), 2206 callableDecl->getLoc(), 2207 llvm::formatv("see the definition of {0} here", 2208 callableDecl->getName()->getName())); 2209 } 2210 2211 /// Handle argument type mismatch. 2212 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2213 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2214 callableDecl->getName()->getName()), 2215 callableDecl->getLoc()); 2216 }; 2217 for (auto it : llvm::zip(callArgs, arguments)) { 2218 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2219 attachDiagFn))) 2220 return failure(); 2221 } 2222 2223 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2224 callableDecl->getResultType()); 2225 } 2226 2227 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2228 ast::Decl *decl) { 2229 // Check the type of decl being referenced. 2230 ast::Type declType; 2231 if (isa<ast::ConstraintDecl>(decl)) 2232 declType = ast::ConstraintType::get(ctx); 2233 else if (isa<ast::UserRewriteDecl>(decl)) 2234 declType = ast::RewriteType::get(ctx); 2235 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2236 declType = varDecl->getType(); 2237 else 2238 return emitError(loc, "invalid reference to `" + 2239 decl->getName()->getName() + "`"); 2240 2241 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2242 } 2243 2244 FailureOr<ast::DeclRefExpr *> 2245 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2246 ArrayRef<ast::ConstraintRef> constraints) { 2247 FailureOr<ast::VariableDecl *> decl = 2248 defineVariableDecl(name, loc, type, constraints); 2249 if (failed(decl)) 2250 return failure(); 2251 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2252 } 2253 2254 FailureOr<ast::MemberAccessExpr *> 2255 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2256 SMRange loc) { 2257 // Validate the member name for the given parent expression. 2258 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2259 if (failed(memberType)) 2260 return failure(); 2261 2262 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2263 } 2264 2265 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2266 StringRef name, SMRange loc) { 2267 ast::Type parentType = parentExpr->getType(); 2268 if (parentType.isa<ast::OperationType>()) { 2269 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2270 return valueRangeTy; 2271 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2272 // Handle indexed results. 2273 unsigned index = 0; 2274 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2275 index < tupleType.size()) { 2276 return tupleType.getElementTypes()[index]; 2277 } 2278 2279 // Handle named results. 2280 auto elementNames = tupleType.getElementNames(); 2281 const auto *it = llvm::find(elementNames, name); 2282 if (it != elementNames.end()) 2283 return tupleType.getElementTypes()[it - elementNames.begin()]; 2284 } 2285 return emitError( 2286 loc, 2287 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2288 name, parentType)); 2289 } 2290 2291 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2292 SMRange loc, const ast::OpNameDecl *name, 2293 MutableArrayRef<ast::Expr *> operands, 2294 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2295 MutableArrayRef<ast::Expr *> results) { 2296 Optional<StringRef> opNameRef = name->getName(); 2297 2298 // Verify the inputs operands. 2299 if (failed(validateOperationOperands(loc, opNameRef, operands))) 2300 return failure(); 2301 2302 // Verify the attribute list. 2303 for (ast::NamedAttributeDecl *attr : attributes) { 2304 // Check for an attribute type, or a type awaiting resolution. 2305 ast::Type attrType = attr->getValue()->getType(); 2306 if (!attrType.isa<ast::AttributeType>()) { 2307 return emitError( 2308 attr->getValue()->getLoc(), 2309 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2310 } 2311 } 2312 2313 // Verify the result types. 2314 if (failed(validateOperationResults(loc, opNameRef, results))) 2315 return failure(); 2316 2317 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2318 attributes); 2319 } 2320 2321 LogicalResult 2322 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2323 MutableArrayRef<ast::Expr *> operands) { 2324 return validateOperationOperandsOrResults(loc, name, operands, valueTy, 2325 valueRangeTy); 2326 } 2327 2328 LogicalResult 2329 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2330 MutableArrayRef<ast::Expr *> results) { 2331 return validateOperationOperandsOrResults(loc, name, results, typeTy, 2332 typeRangeTy); 2333 } 2334 2335 LogicalResult Parser::validateOperationOperandsOrResults( 2336 SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2337 ast::Type singleTy, ast::Type rangeTy) { 2338 // All operation types accept a single range parameter. 2339 if (values.size() == 1) { 2340 if (failed(convertExpressionTo(values[0], rangeTy))) 2341 return failure(); 2342 return success(); 2343 } 2344 2345 // Otherwise, accept the value groups as they have been defined and just 2346 // ensure they are one of the expected types. 2347 for (ast::Expr *&valueExpr : values) { 2348 ast::Type valueExprType = valueExpr->getType(); 2349 2350 // Check if this is one of the expected types. 2351 if (valueExprType == rangeTy || valueExprType == singleTy) 2352 continue; 2353 2354 // If the operand is an Operation, allow converting to a Value or 2355 // ValueRange. This situations arises quite often with nested operation 2356 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2357 if (singleTy == valueTy) { 2358 if (valueExprType.isa<ast::OperationType>()) { 2359 valueExpr = convertOpToValue(valueExpr); 2360 continue; 2361 } 2362 } 2363 2364 return emitError( 2365 valueExpr->getLoc(), 2366 llvm::formatv( 2367 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2368 singleTy, rangeTy, valueExprType)); 2369 } 2370 return success(); 2371 } 2372 2373 FailureOr<ast::TupleExpr *> 2374 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2375 ArrayRef<StringRef> elementNames) { 2376 for (const ast::Expr *element : elements) { 2377 ast::Type eleTy = element->getType(); 2378 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2379 return emitError( 2380 element->getLoc(), 2381 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2382 } 2383 } 2384 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2385 } 2386 2387 //===----------------------------------------------------------------------===// 2388 // Stmts 2389 2390 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2391 ast::Expr *rootOp) { 2392 // Check that root is an Operation. 2393 ast::Type rootType = rootOp->getType(); 2394 if (!rootType.isa<ast::OperationType>()) 2395 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2396 2397 return ast::EraseStmt::create(ctx, loc, rootOp); 2398 } 2399 2400 FailureOr<ast::ReplaceStmt *> 2401 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2402 MutableArrayRef<ast::Expr *> replValues) { 2403 // Check that root is an Operation. 2404 ast::Type rootType = rootOp->getType(); 2405 if (!rootType.isa<ast::OperationType>()) { 2406 return emitError( 2407 rootOp->getLoc(), 2408 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2409 } 2410 2411 // If there are multiple replacement values, we implicitly convert any Op 2412 // expressions to the value form. 2413 bool shouldConvertOpToValues = replValues.size() > 1; 2414 for (ast::Expr *&replExpr : replValues) { 2415 ast::Type replType = replExpr->getType(); 2416 2417 // Check that replExpr is an Operation, Value, or ValueRange. 2418 if (replType.isa<ast::OperationType>()) { 2419 if (shouldConvertOpToValues) 2420 replExpr = convertOpToValue(replExpr); 2421 continue; 2422 } 2423 2424 if (replType != valueTy && replType != valueRangeTy) { 2425 return emitError(replExpr->getLoc(), 2426 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2427 "expression, but got `{0}`", 2428 replType)); 2429 } 2430 } 2431 2432 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2433 } 2434 2435 FailureOr<ast::RewriteStmt *> 2436 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2437 ast::CompoundStmt *rewriteBody) { 2438 // Check that root is an Operation. 2439 ast::Type rootType = rootOp->getType(); 2440 if (!rootType.isa<ast::OperationType>()) { 2441 return emitError( 2442 rootOp->getLoc(), 2443 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2444 } 2445 2446 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2447 } 2448 2449 //===----------------------------------------------------------------------===// 2450 // Parser 2451 //===----------------------------------------------------------------------===// 2452 2453 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 2454 llvm::SourceMgr &sourceMgr) { 2455 Parser parser(ctx, sourceMgr); 2456 return parser.parseModule(); 2457 } 2458