1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/Tools/PDLL/AST/Context.h" 13 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 14 #include "mlir/Tools/PDLL/AST/Nodes.h" 15 #include "mlir/Tools/PDLL/AST/Types.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/SaveAndRestore.h" 20 #include "llvm/Support/ScopedPrinter.h" 21 #include <string> 22 23 using namespace mlir; 24 using namespace mlir::pdll; 25 26 //===----------------------------------------------------------------------===// 27 // Parser 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 class Parser { 32 public: 33 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 34 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 35 curToken(lexer.lexToken()), curDeclScope(nullptr), 36 valueTy(ast::ValueType::get(ctx)), 37 valueRangeTy(ast::ValueRangeType::get(ctx)), 38 typeTy(ast::TypeType::get(ctx)), 39 typeRangeTy(ast::TypeRangeType::get(ctx)) {} 40 41 /// Try to parse a new module. Returns nullptr in the case of failure. 42 FailureOr<ast::Module *> parseModule(); 43 44 private: 45 /// The current context of the parser. It allows for the parser to know a bit 46 /// about the construct it is nested within during parsing. This is used 47 /// specifically to provide additional verification during parsing, e.g. to 48 /// prevent using rewrites within a match context, matcher constraints within 49 /// a rewrite section, etc. 50 enum class ParserContext { 51 /// The parser is in the global context. 52 Global, 53 /// The parser is currently within a Constraint, which disallows all types 54 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 55 Constraint, 56 /// The parser is currently within the matcher portion of a Pattern, which 57 /// is allows a terminal operation rewrite statement but no other rewrite 58 /// transformations. 59 PatternMatch, 60 /// The parser is currently within a Rewrite, which disallows calls to 61 /// constraints, requires operation expressions to have names, etc. 62 Rewrite, 63 }; 64 65 //===--------------------------------------------------------------------===// 66 // Parsing 67 //===--------------------------------------------------------------------===// 68 69 /// Push a new decl scope onto the lexer. 70 ast::DeclScope *pushDeclScope() { 71 ast::DeclScope *newScope = 72 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 73 return (curDeclScope = newScope); 74 } 75 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 76 77 /// Pop the last decl scope from the lexer. 78 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 79 80 /// Parse the body of an AST module. 81 LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); 82 83 /// Try to convert the given expression to `type`. Returns failure and emits 84 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 85 /// invoked to attach notes to the emitted error diagnostic. On success, 86 /// `expr` is updated to the expression used to convert to `type`. 87 LogicalResult convertExpressionTo( 88 ast::Expr *&expr, ast::Type type, 89 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 90 91 /// Given an operation expression, convert it to a Value or ValueRange 92 /// typed expression. 93 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 94 95 //===--------------------------------------------------------------------===// 96 // Directives 97 98 LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); 99 LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); 100 101 //===--------------------------------------------------------------------===// 102 // Decls 103 104 /// This structure contains the set of pattern metadata that may be parsed. 105 struct ParsedPatternMetadata { 106 Optional<uint16_t> benefit; 107 bool hasBoundedRecursion = false; 108 }; 109 110 FailureOr<ast::Decl *> parseTopLevelDecl(); 111 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 112 113 /// Parse an argument variable as part of the signature of a 114 /// UserConstraintDecl or UserRewriteDecl. 115 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 116 117 /// Parse a result variable as part of the signature of a UserConstraintDecl 118 /// or UserRewriteDecl. 119 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 120 121 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 122 /// defined in a non-global context. 123 FailureOr<ast::UserConstraintDecl *> 124 parseUserConstraintDecl(bool isInline = false); 125 126 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 127 /// non-global context, such as within a Pattern/Constraint/etc. 128 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 129 130 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 131 /// PDLL constructs. 132 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 133 const ast::Name &name, bool isInline, 134 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 135 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 136 137 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 138 /// defined in a non-global context. 139 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 140 141 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 142 /// non-global context, such as within a Pattern/Rewrite/etc. 143 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 144 145 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 146 /// PDLL constructs. 147 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 148 const ast::Name &name, bool isInline, 149 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 150 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 151 152 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 153 /// effectively the same syntax, and only differ on slight semantics (given 154 /// the different parsing contexts). 155 template <typename T, typename ParseUserPDLLDeclFnT> 156 FailureOr<T *> parseUserConstraintOrRewriteDecl( 157 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 158 StringRef anonymousNamePrefix, bool isInline); 159 160 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 161 /// These decls have effectively the same syntax. 162 template <typename T> 163 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 164 const ast::Name &name, bool isInline, 165 ArrayRef<ast::VariableDecl *> arguments, 166 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 167 168 /// Parse the functional signature (i.e. the arguments and results) of a 169 /// UserConstraintDecl or UserRewriteDecl. 170 LogicalResult parseUserConstraintOrRewriteSignature( 171 SmallVectorImpl<ast::VariableDecl *> &arguments, 172 SmallVectorImpl<ast::VariableDecl *> &results, 173 ast::DeclScope *&argumentScope, ast::Type &resultType); 174 175 /// Validate the return (which if present is specified by bodyIt) of a 176 /// UserConstraintDecl or UserRewriteDecl. 177 LogicalResult validateUserConstraintOrRewriteReturn( 178 StringRef declType, ast::CompoundStmt *body, 179 ArrayRef<ast::Stmt *>::iterator bodyIt, 180 ArrayRef<ast::Stmt *>::iterator bodyE, 181 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 182 183 FailureOr<ast::CompoundStmt *> 184 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 185 bool expectTerminalSemicolon = true); 186 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 187 FailureOr<ast::Decl *> parsePatternDecl(); 188 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 189 190 /// Check to see if a decl has already been defined with the given name, if 191 /// one has emit and error and return failure. Returns success otherwise. 192 LogicalResult checkDefineNamedDecl(const ast::Name &name); 193 194 /// Try to define a variable decl with the given components, returns the 195 /// variable on success. 196 FailureOr<ast::VariableDecl *> 197 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 198 ast::Expr *initExpr, 199 ArrayRef<ast::ConstraintRef> constraints); 200 FailureOr<ast::VariableDecl *> 201 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 202 ArrayRef<ast::ConstraintRef> constraints); 203 204 /// Parse the constraint reference list for a variable decl. 205 LogicalResult parseVariableDeclConstraintList( 206 SmallVectorImpl<ast::ConstraintRef> &constraints); 207 208 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 209 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 210 211 /// Try to parse a single reference to a constraint. `typeConstraint` is the 212 /// location of a previously parsed type constraint for the entity that will 213 /// be constrained by the parsed constraint. `existingConstraints` are any 214 /// existing constraints that have already been parsed for the same entity 215 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 216 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 217 FailureOr<ast::ConstraintRef> 218 parseConstraint(Optional<SMRange> &typeConstraint, 219 ArrayRef<ast::ConstraintRef> existingConstraints, 220 bool allowInlineTypeConstraints); 221 222 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 223 /// argument or result variable. The constraints for these variables do not 224 /// allow inline type constraints, and only permit a single constraint. 225 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 226 227 //===--------------------------------------------------------------------===// 228 // Exprs 229 230 FailureOr<ast::Expr *> parseExpr(); 231 232 /// Identifier expressions. 233 FailureOr<ast::Expr *> parseAttributeExpr(); 234 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 235 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 236 FailureOr<ast::Expr *> parseIdentifierExpr(); 237 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 238 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 239 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 240 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 241 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 242 FailureOr<ast::Expr *> parseOperationExpr(); 243 FailureOr<ast::Expr *> parseTupleExpr(); 244 FailureOr<ast::Expr *> parseTypeExpr(); 245 FailureOr<ast::Expr *> parseUnderscoreExpr(); 246 247 //===--------------------------------------------------------------------===// 248 // Stmts 249 250 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 251 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 252 FailureOr<ast::EraseStmt *> parseEraseStmt(); 253 FailureOr<ast::LetStmt *> parseLetStmt(); 254 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 255 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 256 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 257 258 //===--------------------------------------------------------------------===// 259 // Creation+Analysis 260 //===--------------------------------------------------------------------===// 261 262 //===--------------------------------------------------------------------===// 263 // Decls 264 265 /// Try to extract a callable from the given AST node. Returns nullptr on 266 /// failure. 267 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 268 269 /// Try to create a pattern decl with the given components, returning the 270 /// Pattern on success. 271 FailureOr<ast::PatternDecl *> 272 createPatternDecl(SMRange loc, const ast::Name *name, 273 const ParsedPatternMetadata &metadata, 274 ast::CompoundStmt *body); 275 276 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 277 /// of results, defined as part of the signature. 278 ast::Type 279 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 280 281 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 282 template <typename T> 283 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 284 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 285 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 286 ast::CompoundStmt *body); 287 288 /// Try to create a variable decl with the given components, returning the 289 /// Variable on success. 290 FailureOr<ast::VariableDecl *> 291 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 292 ArrayRef<ast::ConstraintRef> constraints); 293 294 /// Create a variable for an argument or result defined as part of the 295 /// signature of a UserConstraintDecl/UserRewriteDecl. 296 FailureOr<ast::VariableDecl *> 297 createArgOrResultVariableDecl(StringRef name, SMRange loc, 298 const ast::ConstraintRef &constraint); 299 300 /// Validate the constraints used to constraint a variable decl. 301 /// `inferredType` is the type of the variable inferred by the constraints 302 /// within the list, and is updated to the most refined type as determined by 303 /// the constraints. Returns success if the constraint list is valid, failure 304 /// otherwise. 305 LogicalResult 306 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 307 ast::Type &inferredType); 308 /// Validate a single reference to a constraint. `inferredType` contains the 309 /// currently inferred variabled type and is refined within the type defined 310 /// by the constraint. Returns success if the constraint is valid, failure 311 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 312 /// defined constraints) may be used with the variable. 313 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 314 ast::Type &inferredType, 315 bool allowNonCoreConstraints = true); 316 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 317 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 318 319 //===--------------------------------------------------------------------===// 320 // Exprs 321 322 FailureOr<ast::CallExpr *> 323 createCallExpr(SMRange loc, ast::Expr *parentExpr, 324 MutableArrayRef<ast::Expr *> arguments); 325 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 326 FailureOr<ast::DeclRefExpr *> 327 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 328 ArrayRef<ast::ConstraintRef> constraints); 329 FailureOr<ast::MemberAccessExpr *> 330 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 331 332 /// Validate the member access `name` into the given parent expression. On 333 /// success, this also returns the type of the member accessed. 334 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 335 StringRef name, SMRange loc); 336 FailureOr<ast::OperationExpr *> 337 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 338 MutableArrayRef<ast::Expr *> operands, 339 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 340 MutableArrayRef<ast::Expr *> results); 341 LogicalResult 342 validateOperationOperands(SMRange loc, Optional<StringRef> name, 343 MutableArrayRef<ast::Expr *> operands); 344 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 345 MutableArrayRef<ast::Expr *> results); 346 LogicalResult 347 validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name, 348 MutableArrayRef<ast::Expr *> values, 349 ast::Type singleTy, ast::Type rangeTy); 350 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 351 ArrayRef<ast::Expr *> elements, 352 ArrayRef<StringRef> elementNames); 353 354 //===--------------------------------------------------------------------===// 355 // Stmts 356 357 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 358 FailureOr<ast::ReplaceStmt *> 359 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 360 MutableArrayRef<ast::Expr *> replValues); 361 FailureOr<ast::RewriteStmt *> 362 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 363 ast::CompoundStmt *rewriteBody); 364 365 //===--------------------------------------------------------------------===// 366 // Lexer Utilities 367 //===--------------------------------------------------------------------===// 368 369 /// If the current token has the specified kind, consume it and return true. 370 /// If not, return false. 371 bool consumeIf(Token::Kind kind) { 372 if (curToken.isNot(kind)) 373 return false; 374 consumeToken(kind); 375 return true; 376 } 377 378 /// Advance the current lexer onto the next token. 379 void consumeToken() { 380 assert(curToken.isNot(Token::eof, Token::error) && 381 "shouldn't advance past EOF or errors"); 382 curToken = lexer.lexToken(); 383 } 384 385 /// Advance the current lexer onto the next token, asserting what the expected 386 /// current token is. This is preferred to the above method because it leads 387 /// to more self-documenting code with better checking. 388 void consumeToken(Token::Kind kind) { 389 assert(curToken.is(kind) && "consumed an unexpected token"); 390 consumeToken(); 391 } 392 393 /// Reset the lexer to the location at the given position. 394 void resetToken(SMRange tokLoc) { 395 lexer.resetPointer(tokLoc.Start.getPointer()); 396 curToken = lexer.lexToken(); 397 } 398 399 /// Consume the specified token if present and return success. On failure, 400 /// output a diagnostic and return failure. 401 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 402 if (curToken.getKind() != kind) 403 return emitError(curToken.getLoc(), msg); 404 consumeToken(); 405 return success(); 406 } 407 LogicalResult emitError(SMRange loc, const Twine &msg) { 408 lexer.emitError(loc, msg); 409 return failure(); 410 } 411 LogicalResult emitError(const Twine &msg) { 412 return emitError(curToken.getLoc(), msg); 413 } 414 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 415 const Twine ¬e) { 416 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 417 return failure(); 418 } 419 420 //===--------------------------------------------------------------------===// 421 // Fields 422 //===--------------------------------------------------------------------===// 423 424 /// The owning AST context. 425 ast::Context &ctx; 426 427 /// The lexer of this parser. 428 Lexer lexer; 429 430 /// The current token within the lexer. 431 Token curToken; 432 433 /// The most recently defined decl scope. 434 ast::DeclScope *curDeclScope; 435 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 436 437 /// The current context of the parser. 438 ParserContext parserContext = ParserContext::Global; 439 440 /// Cached types to simplify verification and expression creation. 441 ast::Type valueTy, valueRangeTy; 442 ast::Type typeTy, typeRangeTy; 443 444 /// A counter used when naming anonymous constraints and rewrites. 445 unsigned anonymousDeclNameCounter = 0; 446 }; 447 } // namespace 448 449 FailureOr<ast::Module *> Parser::parseModule() { 450 SMLoc moduleLoc = curToken.getStartLoc(); 451 pushDeclScope(); 452 453 // Parse the top-level decls of the module. 454 SmallVector<ast::Decl *> decls; 455 if (failed(parseModuleBody(decls))) 456 return popDeclScope(), failure(); 457 458 popDeclScope(); 459 return ast::Module::create(ctx, moduleLoc, decls); 460 } 461 462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { 463 while (curToken.isNot(Token::eof)) { 464 if (curToken.is(Token::directive)) { 465 if (failed(parseDirective(decls))) 466 return failure(); 467 continue; 468 } 469 470 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 471 if (failed(decl)) 472 return failure(); 473 decls.push_back(*decl); 474 } 475 return success(); 476 } 477 478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 479 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 480 valueRangeTy); 481 } 482 483 LogicalResult Parser::convertExpressionTo( 484 ast::Expr *&expr, ast::Type type, 485 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 486 ast::Type exprType = expr->getType(); 487 if (exprType == type) 488 return success(); 489 490 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 491 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 492 expr->getLoc(), llvm::formatv("unable to convert expression of type " 493 "`{0}` to the expected type of " 494 "`{1}`", 495 exprType, type)); 496 if (noteAttachFn) 497 noteAttachFn(*diag); 498 return diag; 499 }; 500 501 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 502 // Two operation types are compatible if they have the same name, or if the 503 // expected type is more general. 504 if (auto opType = type.dyn_cast<ast::OperationType>()) { 505 if (opType.getName()) 506 return emitConvertError(); 507 return success(); 508 } 509 510 // An operation can always convert to a ValueRange. 511 if (type == valueRangeTy) { 512 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 513 valueRangeTy); 514 return success(); 515 } 516 517 // Allow conversion to a single value by constraining the result range. 518 if (type == valueTy) { 519 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 520 valueTy); 521 return success(); 522 } 523 return emitConvertError(); 524 } 525 526 // FIXME: Decide how to allow/support converting a single result to multiple, 527 // and multiple to a single result. For now, we just allow Single->Range, 528 // but this isn't something really supported in the PDL dialect. We should 529 // figure out some way to support both. 530 if ((exprType == valueTy || exprType == valueRangeTy) && 531 (type == valueTy || type == valueRangeTy)) 532 return success(); 533 if ((exprType == typeTy || exprType == typeRangeTy) && 534 (type == typeTy || type == typeRangeTy)) 535 return success(); 536 537 // Handle tuple types. 538 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 539 auto tupleType = type.dyn_cast<ast::TupleType>(); 540 if (!tupleType || tupleType.size() != exprTupleType.size()) 541 return emitConvertError(); 542 543 // Build a new tuple expression using each of the elements of the current 544 // tuple. 545 SmallVector<ast::Expr *> newExprs; 546 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 547 newExprs.push_back(ast::MemberAccessExpr::create( 548 ctx, expr->getLoc(), expr, llvm::to_string(i), 549 exprTupleType.getElementTypes()[i])); 550 551 auto diagFn = [&](ast::Diagnostic &diag) { 552 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 553 i, exprTupleType)); 554 if (noteAttachFn) 555 noteAttachFn(diag); 556 }; 557 if (failed(convertExpressionTo(newExprs.back(), 558 tupleType.getElementTypes()[i], diagFn))) 559 return failure(); 560 } 561 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 562 tupleType.getElementNames()); 563 return success(); 564 } 565 566 return emitConvertError(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // Directives 571 572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { 573 StringRef directive = curToken.getSpelling(); 574 if (directive == "#include") 575 return parseInclude(decls); 576 577 return emitError("unknown directive `" + directive + "`"); 578 } 579 580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { 581 SMRange loc = curToken.getLoc(); 582 consumeToken(Token::directive); 583 584 // Parse the file being included. 585 if (!curToken.isString()) 586 return emitError(loc, 587 "expected string file name after `include` directive"); 588 SMRange fileLoc = curToken.getLoc(); 589 std::string filenameStr = curToken.getStringValue(); 590 StringRef filename = filenameStr; 591 consumeToken(); 592 593 // Check the type of include. If ending with `.pdll`, this is another pdl file 594 // to be parsed along with the current module. 595 if (filename.endswith(".pdll")) { 596 if (failed(lexer.pushInclude(filename))) 597 return emitError(fileLoc, 598 "unable to open include file `" + filename + "`"); 599 600 // If we added the include successfully, parse it into the current module. 601 // Make sure to save the current token so that we can restore it when we 602 // finish parsing the nested file. 603 Token oldToken = curToken; 604 curToken = lexer.lexToken(); 605 LogicalResult result = parseModuleBody(decls); 606 curToken = oldToken; 607 return result; 608 } 609 610 return emitError(fileLoc, "expected include filename to end with `.pdll`"); 611 } 612 613 //===----------------------------------------------------------------------===// 614 // Decls 615 616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 617 FailureOr<ast::Decl *> decl; 618 switch (curToken.getKind()) { 619 case Token::kw_Constraint: 620 decl = parseUserConstraintDecl(); 621 break; 622 case Token::kw_Pattern: 623 decl = parsePatternDecl(); 624 break; 625 case Token::kw_Rewrite: 626 decl = parseUserRewriteDecl(); 627 break; 628 default: 629 return emitError("expected top-level declaration, such as a `Pattern`"); 630 } 631 if (failed(decl)) 632 return failure(); 633 634 // If the decl has a name, add it to the current scope. 635 if (const ast::Name *name = (*decl)->getName()) { 636 if (failed(checkDefineNamedDecl(*name))) 637 return failure(); 638 curDeclScope->add(*decl); 639 } 640 return decl; 641 } 642 643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 644 std::string attrNameStr; 645 if (curToken.isString()) 646 attrNameStr = curToken.getStringValue(); 647 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 648 attrNameStr = curToken.getSpelling().str(); 649 else 650 return emitError("expected identifier or string attribute name"); 651 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 652 consumeToken(); 653 654 // Check for a value of the attribute. 655 ast::Expr *attrValue = nullptr; 656 if (consumeIf(Token::equal)) { 657 FailureOr<ast::Expr *> attrExpr = parseExpr(); 658 if (failed(attrExpr)) 659 return failure(); 660 attrValue = *attrExpr; 661 } else { 662 // If there isn't a concrete value, create an expression representing a 663 // UnitAttr. 664 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 665 } 666 667 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 668 } 669 670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 671 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 672 bool expectTerminalSemicolon) { 673 consumeToken(Token::equal_arrow); 674 675 // Parse the single statement of the lambda body. 676 SMLoc bodyStartLoc = curToken.getStartLoc(); 677 pushDeclScope(); 678 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 679 bool failedToParse = 680 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 681 popDeclScope(); 682 if (failedToParse) 683 return failure(); 684 685 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 686 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 687 } 688 689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 690 // Ensure that the argument is named. 691 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 692 return emitError("expected identifier argument name"); 693 694 // Parse the argument similarly to a normal variable. 695 StringRef name = curToken.getSpelling(); 696 SMRange nameLoc = curToken.getLoc(); 697 consumeToken(); 698 699 if (failed( 700 parseToken(Token::colon, "expected `:` before argument constraint"))) 701 return failure(); 702 703 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 704 if (failed(cst)) 705 return failure(); 706 707 return createArgOrResultVariableDecl(name, nameLoc, *cst); 708 } 709 710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 711 // Check to see if this result is named. 712 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 713 // Check to see if this name actually refers to a Constraint. 714 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 715 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 716 // If yes, and this is a Rewrite, give a nice error message as non-Core 717 // constraints are not supported on Rewrite results. 718 if (parserContext == ParserContext::Rewrite) { 719 return emitError( 720 "`Rewrite` results are only permitted to use core constraints, " 721 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 722 } 723 724 // Otherwise, parse this as an unnamed result variable. 725 } else { 726 // If it wasn't a constraint, parse the result similarly to a variable. If 727 // there is already an existing decl, we will emit an error when defining 728 // this variable later. 729 StringRef name = curToken.getSpelling(); 730 SMRange nameLoc = curToken.getLoc(); 731 consumeToken(); 732 733 if (failed(parseToken(Token::colon, 734 "expected `:` before result constraint"))) 735 return failure(); 736 737 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 738 if (failed(cst)) 739 return failure(); 740 741 return createArgOrResultVariableDecl(name, nameLoc, *cst); 742 } 743 } 744 745 // If it isn't named, we parse the constraint directly and create an unnamed 746 // result variable. 747 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 748 if (failed(cst)) 749 return failure(); 750 751 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 752 } 753 754 FailureOr<ast::UserConstraintDecl *> 755 Parser::parseUserConstraintDecl(bool isInline) { 756 // Constraints and rewrites have very similar formats, dispatch to a shared 757 // interface for parsing. 758 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 759 [&](auto &&...args) { 760 return this->parseUserPDLLConstraintDecl(args...); 761 }, 762 ParserContext::Constraint, "constraint", isInline); 763 } 764 765 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 766 FailureOr<ast::UserConstraintDecl *> decl = 767 parseUserConstraintDecl(/*isInline=*/true); 768 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 769 return failure(); 770 771 curDeclScope->add(*decl); 772 return decl; 773 } 774 775 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 776 const ast::Name &name, bool isInline, 777 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 778 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 779 // Push the argument scope back onto the list, so that the body can 780 // reference arguments. 781 pushDeclScope(argumentScope); 782 783 // Parse the body of the constraint. The body is either defined as a compound 784 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 785 ast::CompoundStmt *body; 786 if (curToken.is(Token::equal_arrow)) { 787 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 788 [&](ast::Stmt *&stmt) -> LogicalResult { 789 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 790 if (!stmtExpr) { 791 return emitError(stmt->getLoc(), 792 "expected `Constraint` lambda body to contain a " 793 "single expression"); 794 } 795 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 796 return success(); 797 }, 798 /*expectTerminalSemicolon=*/!isInline); 799 if (failed(bodyResult)) 800 return failure(); 801 body = *bodyResult; 802 } else { 803 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 804 if (failed(bodyResult)) 805 return failure(); 806 body = *bodyResult; 807 808 // Verify the structure of the body. 809 auto bodyIt = body->begin(), bodyE = body->end(); 810 for (; bodyIt != bodyE; ++bodyIt) 811 if (isa<ast::ReturnStmt>(*bodyIt)) 812 break; 813 if (failed(validateUserConstraintOrRewriteReturn( 814 "Constraint", body, bodyIt, bodyE, results, resultType))) 815 return failure(); 816 } 817 popDeclScope(); 818 819 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 820 name, arguments, results, resultType, body); 821 } 822 823 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 824 // Constraints and rewrites have very similar formats, dispatch to a shared 825 // interface for parsing. 826 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 827 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 828 ParserContext::Rewrite, "rewrite", isInline); 829 } 830 831 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 832 FailureOr<ast::UserRewriteDecl *> decl = 833 parseUserRewriteDecl(/*isInline=*/true); 834 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 835 return failure(); 836 837 curDeclScope->add(*decl); 838 return decl; 839 } 840 841 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 842 const ast::Name &name, bool isInline, 843 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 844 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 845 // Push the argument scope back onto the list, so that the body can 846 // reference arguments. 847 curDeclScope = argumentScope; 848 ast::CompoundStmt *body; 849 if (curToken.is(Token::equal_arrow)) { 850 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 851 [&](ast::Stmt *&statement) -> LogicalResult { 852 if (isa<ast::OpRewriteStmt>(statement)) 853 return success(); 854 855 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 856 if (!statementExpr) { 857 return emitError( 858 statement->getLoc(), 859 "expected `Rewrite` lambda body to contain a single expression " 860 "or an operation rewrite statement; such as `erase`, " 861 "`replace`, or `rewrite`"); 862 } 863 statement = 864 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 865 return success(); 866 }, 867 /*expectTerminalSemicolon=*/!isInline); 868 if (failed(bodyResult)) 869 return failure(); 870 body = *bodyResult; 871 } else { 872 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 873 if (failed(bodyResult)) 874 return failure(); 875 body = *bodyResult; 876 } 877 popDeclScope(); 878 879 // Verify the structure of the body. 880 auto bodyIt = body->begin(), bodyE = body->end(); 881 for (; bodyIt != bodyE; ++bodyIt) 882 if (isa<ast::ReturnStmt>(*bodyIt)) 883 break; 884 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 885 bodyE, results, resultType))) 886 return failure(); 887 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 888 name, arguments, results, resultType, body); 889 } 890 891 template <typename T, typename ParseUserPDLLDeclFnT> 892 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 893 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 894 StringRef anonymousNamePrefix, bool isInline) { 895 SMRange loc = curToken.getLoc(); 896 consumeToken(); 897 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 898 899 // Parse the name of the decl. 900 const ast::Name *name = nullptr; 901 if (curToken.isNot(Token::identifier)) { 902 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 903 // in C++, so being unnamed is fine. 904 if (!isInline) 905 return emitError("expected identifier name"); 906 907 // Create a unique anonymous name to use, as the name for this decl is not 908 // important. 909 std::string anonName = 910 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 911 anonymousDeclNameCounter++) 912 .str(); 913 name = &ast::Name::create(ctx, anonName, loc); 914 } else { 915 // If a name was provided, we can use it directly. 916 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 917 consumeToken(Token::identifier); 918 } 919 920 // Parse the functional signature of the decl. 921 SmallVector<ast::VariableDecl *> arguments, results; 922 ast::DeclScope *argumentScope; 923 ast::Type resultType; 924 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 925 argumentScope, resultType))) 926 return failure(); 927 928 // Check to see which type of constraint this is. If the constraint contains a 929 // compound body, this is a PDLL decl. 930 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 931 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 932 resultType); 933 934 // Otherwise, this is a native decl. 935 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 936 results, resultType); 937 } 938 939 template <typename T> 940 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 941 const ast::Name &name, bool isInline, 942 ArrayRef<ast::VariableDecl *> arguments, 943 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 944 // If followed by a string, the native code body has also been specified. 945 std::string codeStrStorage; 946 Optional<StringRef> optCodeStr; 947 if (curToken.isString()) { 948 codeStrStorage = curToken.getStringValue(); 949 optCodeStr = codeStrStorage; 950 consumeToken(); 951 } else if (isInline) { 952 return emitError(name.getLoc(), 953 "external declarations must be declared in global scope"); 954 } 955 if (failed(parseToken(Token::semicolon, 956 "expected `;` after native declaration"))) 957 return failure(); 958 // TODO: PDL should be able to support constraint results in certain 959 // situations, we should revise this. 960 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 961 return emitError( 962 "native Constraints currently do not support returning results"); 963 } 964 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 965 } 966 967 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 968 SmallVectorImpl<ast::VariableDecl *> &arguments, 969 SmallVectorImpl<ast::VariableDecl *> &results, 970 ast::DeclScope *&argumentScope, ast::Type &resultType) { 971 // Parse the argument list of the decl. 972 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 973 return failure(); 974 975 argumentScope = pushDeclScope(); 976 if (curToken.isNot(Token::r_paren)) { 977 do { 978 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 979 if (failed(argument)) 980 return failure(); 981 arguments.emplace_back(*argument); 982 } while (consumeIf(Token::comma)); 983 } 984 popDeclScope(); 985 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 986 return failure(); 987 988 // Parse the results of the decl. 989 pushDeclScope(); 990 if (consumeIf(Token::arrow)) { 991 auto parseResultFn = [&]() -> LogicalResult { 992 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 993 if (failed(result)) 994 return failure(); 995 results.emplace_back(*result); 996 return success(); 997 }; 998 999 // Check for a list of results. 1000 if (consumeIf(Token::l_paren)) { 1001 do { 1002 if (failed(parseResultFn())) 1003 return failure(); 1004 } while (consumeIf(Token::comma)); 1005 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1006 return failure(); 1007 1008 // Otherwise, there is only one result. 1009 } else if (failed(parseResultFn())) { 1010 return failure(); 1011 } 1012 } 1013 popDeclScope(); 1014 1015 // Compute the result type of the decl. 1016 resultType = createUserConstraintRewriteResultType(results); 1017 1018 // Verify that results are only named if there are more than one. 1019 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1020 return emitError( 1021 results.front()->getLoc(), 1022 "cannot create a single-element tuple with an element label"); 1023 } 1024 return success(); 1025 } 1026 1027 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1028 StringRef declType, ast::CompoundStmt *body, 1029 ArrayRef<ast::Stmt *>::iterator bodyIt, 1030 ArrayRef<ast::Stmt *>::iterator bodyE, 1031 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1032 // Handle if a `return` was provided. 1033 if (bodyIt != bodyE) { 1034 // Emit an error if we have trailing statements after the return. 1035 if (std::next(bodyIt) != bodyE) { 1036 return emitError( 1037 (*std::next(bodyIt))->getLoc(), 1038 llvm::formatv("`return` terminated the `{0}` body, but found " 1039 "trailing statements afterwards", 1040 declType)); 1041 } 1042 1043 // Otherwise if a return wasn't provided, check that no results are 1044 // expected. 1045 } else if (!results.empty()) { 1046 return emitError( 1047 {body->getLoc().End, body->getLoc().End}, 1048 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1049 declType, resultType)); 1050 } 1051 return success(); 1052 } 1053 1054 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1055 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1056 if (isa<ast::OpRewriteStmt>(statement)) 1057 return success(); 1058 return emitError( 1059 statement->getLoc(), 1060 "expected Pattern lambda body to contain a single operation " 1061 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1062 }); 1063 } 1064 1065 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1066 SMRange loc = curToken.getLoc(); 1067 consumeToken(Token::kw_Pattern); 1068 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1069 ParserContext::PatternMatch); 1070 1071 // Check for an optional identifier for the pattern name. 1072 const ast::Name *name = nullptr; 1073 if (curToken.is(Token::identifier)) { 1074 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1075 consumeToken(Token::identifier); 1076 } 1077 1078 // Parse any pattern metadata. 1079 ParsedPatternMetadata metadata; 1080 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1081 return failure(); 1082 1083 // Parse the pattern body. 1084 ast::CompoundStmt *body; 1085 1086 // Handle a lambda body. 1087 if (curToken.is(Token::equal_arrow)) { 1088 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1089 if (failed(bodyResult)) 1090 return failure(); 1091 body = *bodyResult; 1092 } else { 1093 if (curToken.isNot(Token::l_brace)) 1094 return emitError("expected `{` or `=>` to start pattern body"); 1095 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1096 if (failed(bodyResult)) 1097 return failure(); 1098 body = *bodyResult; 1099 1100 // Verify the body of the pattern. 1101 auto bodyIt = body->begin(), bodyE = body->end(); 1102 for (; bodyIt != bodyE; ++bodyIt) { 1103 if (isa<ast::ReturnStmt>(*bodyIt)) { 1104 return emitError((*bodyIt)->getLoc(), 1105 "`return` statements are only permitted within a " 1106 "`Constraint` or `Rewrite` body"); 1107 } 1108 // Break when we've found the rewrite statement. 1109 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1110 break; 1111 } 1112 if (bodyIt == bodyE) { 1113 return emitError(loc, 1114 "expected Pattern body to terminate with an operation " 1115 "rewrite statement, such as `erase`"); 1116 } 1117 if (std::next(bodyIt) != bodyE) { 1118 return emitError((*std::next(bodyIt))->getLoc(), 1119 "Pattern body was terminated by an operation " 1120 "rewrite statement, but found trailing statements"); 1121 } 1122 } 1123 1124 return createPatternDecl(loc, name, metadata, body); 1125 } 1126 1127 LogicalResult 1128 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1129 Optional<SMRange> benefitLoc; 1130 Optional<SMRange> hasBoundedRecursionLoc; 1131 1132 do { 1133 if (curToken.isNot(Token::identifier)) 1134 return emitError("expected pattern metadata identifier"); 1135 StringRef metadataStr = curToken.getSpelling(); 1136 SMRange metadataLoc = curToken.getLoc(); 1137 consumeToken(Token::identifier); 1138 1139 // Parse the benefit metadata: benefit(<integer-value>) 1140 if (metadataStr == "benefit") { 1141 if (benefitLoc) { 1142 return emitErrorAndNote(metadataLoc, 1143 "pattern benefit has already been specified", 1144 *benefitLoc, "see previous definition here"); 1145 } 1146 if (failed(parseToken(Token::l_paren, 1147 "expected `(` before pattern benefit"))) 1148 return failure(); 1149 1150 uint16_t benefitValue = 0; 1151 if (curToken.isNot(Token::integer)) 1152 return emitError("expected integral pattern benefit"); 1153 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1154 return emitError( 1155 "expected pattern benefit to fit within a 16-bit integer"); 1156 consumeToken(Token::integer); 1157 1158 metadata.benefit = benefitValue; 1159 benefitLoc = metadataLoc; 1160 1161 if (failed( 1162 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1163 return failure(); 1164 continue; 1165 } 1166 1167 // Parse the bounded recursion metadata: recursion 1168 if (metadataStr == "recursion") { 1169 if (hasBoundedRecursionLoc) { 1170 return emitErrorAndNote( 1171 metadataLoc, 1172 "pattern recursion metadata has already been specified", 1173 *hasBoundedRecursionLoc, "see previous definition here"); 1174 } 1175 metadata.hasBoundedRecursion = true; 1176 hasBoundedRecursionLoc = metadataLoc; 1177 continue; 1178 } 1179 1180 return emitError(metadataLoc, "unknown pattern metadata"); 1181 } while (consumeIf(Token::comma)); 1182 1183 return success(); 1184 } 1185 1186 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1187 consumeToken(Token::less); 1188 1189 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1190 if (failed(typeExpr) || 1191 failed(parseToken(Token::greater, 1192 "expected `>` after variable type constraint"))) 1193 return failure(); 1194 return typeExpr; 1195 } 1196 1197 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1198 assert(curDeclScope && "defining decl outside of a decl scope"); 1199 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1200 return emitErrorAndNote( 1201 name.getLoc(), "`" + name.getName() + "` has already been defined", 1202 lastDecl->getName()->getLoc(), "see previous definition here"); 1203 } 1204 return success(); 1205 } 1206 1207 FailureOr<ast::VariableDecl *> 1208 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1209 ast::Expr *initExpr, 1210 ArrayRef<ast::ConstraintRef> constraints) { 1211 assert(curDeclScope && "defining variable outside of decl scope"); 1212 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1213 1214 // If the name of the variable indicates a special variable, we don't add it 1215 // to the scope. This variable is local to the definition point. 1216 if (name.empty() || name == "_") { 1217 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1218 constraints); 1219 } 1220 if (failed(checkDefineNamedDecl(nameDecl))) 1221 return failure(); 1222 1223 auto *varDecl = 1224 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1225 curDeclScope->add(varDecl); 1226 return varDecl; 1227 } 1228 1229 FailureOr<ast::VariableDecl *> 1230 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1231 ArrayRef<ast::ConstraintRef> constraints) { 1232 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1233 constraints); 1234 } 1235 1236 LogicalResult Parser::parseVariableDeclConstraintList( 1237 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1238 Optional<SMRange> typeConstraint; 1239 auto parseSingleConstraint = [&] { 1240 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1241 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); 1242 if (failed(constraint)) 1243 return failure(); 1244 constraints.push_back(*constraint); 1245 return success(); 1246 }; 1247 1248 // Check to see if this is a single constraint, or a list. 1249 if (!consumeIf(Token::l_square)) 1250 return parseSingleConstraint(); 1251 1252 do { 1253 if (failed(parseSingleConstraint())) 1254 return failure(); 1255 } while (consumeIf(Token::comma)); 1256 return parseToken(Token::r_square, "expected `]` after constraint list"); 1257 } 1258 1259 FailureOr<ast::ConstraintRef> 1260 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1261 ArrayRef<ast::ConstraintRef> existingConstraints, 1262 bool allowInlineTypeConstraints) { 1263 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1264 if (!allowInlineTypeConstraints) { 1265 return emitError( 1266 curToken.getLoc(), 1267 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1268 "permitted on arguments or results"); 1269 } 1270 if (typeConstraint) 1271 return emitErrorAndNote( 1272 curToken.getLoc(), 1273 "the type of this variable has already been constrained", 1274 *typeConstraint, "see previous constraint location here"); 1275 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1276 if (failed(constraintExpr)) 1277 return failure(); 1278 typeExpr = *constraintExpr; 1279 typeConstraint = typeExpr->getLoc(); 1280 return success(); 1281 }; 1282 1283 SMRange loc = curToken.getLoc(); 1284 switch (curToken.getKind()) { 1285 case Token::kw_Attr: { 1286 consumeToken(Token::kw_Attr); 1287 1288 // Check for a type constraint. 1289 ast::Expr *typeExpr = nullptr; 1290 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1291 return failure(); 1292 return ast::ConstraintRef( 1293 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1294 } 1295 case Token::kw_Op: { 1296 consumeToken(Token::kw_Op); 1297 1298 // Parse an optional operation name. If the name isn't provided, this refers 1299 // to "any" operation. 1300 FailureOr<ast::OpNameDecl *> opName = 1301 parseWrappedOperationName(/*allowEmptyName=*/true); 1302 if (failed(opName)) 1303 return failure(); 1304 1305 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1306 loc); 1307 } 1308 case Token::kw_Type: 1309 consumeToken(Token::kw_Type); 1310 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1311 case Token::kw_TypeRange: 1312 consumeToken(Token::kw_TypeRange); 1313 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1314 loc); 1315 case Token::kw_Value: { 1316 consumeToken(Token::kw_Value); 1317 1318 // Check for a type constraint. 1319 ast::Expr *typeExpr = nullptr; 1320 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1321 return failure(); 1322 1323 return ast::ConstraintRef( 1324 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1325 } 1326 case Token::kw_ValueRange: { 1327 consumeToken(Token::kw_ValueRange); 1328 1329 // Check for a type constraint. 1330 ast::Expr *typeExpr = nullptr; 1331 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1332 return failure(); 1333 1334 return ast::ConstraintRef( 1335 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1336 } 1337 1338 case Token::kw_Constraint: { 1339 // Handle an inline constraint. 1340 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1341 if (failed(decl)) 1342 return failure(); 1343 return ast::ConstraintRef(*decl, loc); 1344 } 1345 case Token::identifier: { 1346 StringRef constraintName = curToken.getSpelling(); 1347 consumeToken(Token::identifier); 1348 1349 // Lookup the referenced constraint. 1350 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1351 if (!cstDecl) { 1352 return emitError(loc, "unknown reference to constraint `" + 1353 constraintName + "`"); 1354 } 1355 1356 // Handle a reference to a proper constraint. 1357 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1358 return ast::ConstraintRef(cst, loc); 1359 1360 return emitErrorAndNote( 1361 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1362 "see the definition of `" + constraintName + "` here"); 1363 } 1364 default: 1365 break; 1366 } 1367 return emitError(loc, "expected identifier constraint"); 1368 } 1369 1370 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1371 Optional<SMRange> typeConstraint; 1372 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1373 /*allowInlineTypeConstraints=*/false); 1374 } 1375 1376 //===----------------------------------------------------------------------===// 1377 // Exprs 1378 1379 FailureOr<ast::Expr *> Parser::parseExpr() { 1380 if (curToken.is(Token::underscore)) 1381 return parseUnderscoreExpr(); 1382 1383 // Parse the LHS expression. 1384 FailureOr<ast::Expr *> lhsExpr; 1385 switch (curToken.getKind()) { 1386 case Token::kw_attr: 1387 lhsExpr = parseAttributeExpr(); 1388 break; 1389 case Token::kw_Constraint: 1390 lhsExpr = parseInlineConstraintLambdaExpr(); 1391 break; 1392 case Token::identifier: 1393 lhsExpr = parseIdentifierExpr(); 1394 break; 1395 case Token::kw_op: 1396 lhsExpr = parseOperationExpr(); 1397 break; 1398 case Token::kw_Rewrite: 1399 lhsExpr = parseInlineRewriteLambdaExpr(); 1400 break; 1401 case Token::kw_type: 1402 lhsExpr = parseTypeExpr(); 1403 break; 1404 case Token::l_paren: 1405 lhsExpr = parseTupleExpr(); 1406 break; 1407 default: 1408 return emitError("expected expression"); 1409 } 1410 if (failed(lhsExpr)) 1411 return failure(); 1412 1413 // Check for an operator expression. 1414 while (true) { 1415 switch (curToken.getKind()) { 1416 case Token::dot: 1417 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1418 break; 1419 case Token::l_paren: 1420 lhsExpr = parseCallExpr(*lhsExpr); 1421 break; 1422 default: 1423 return lhsExpr; 1424 } 1425 if (failed(lhsExpr)) 1426 return failure(); 1427 } 1428 } 1429 1430 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1431 SMRange loc = curToken.getLoc(); 1432 consumeToken(Token::kw_attr); 1433 1434 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1435 // identifier. 1436 if (!consumeIf(Token::less)) { 1437 resetToken(loc); 1438 return parseIdentifierExpr(); 1439 } 1440 1441 if (!curToken.isString()) 1442 return emitError("expected string literal containing MLIR attribute"); 1443 std::string attrExpr = curToken.getStringValue(); 1444 consumeToken(); 1445 1446 if (failed( 1447 parseToken(Token::greater, "expected `>` after attribute literal"))) 1448 return failure(); 1449 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1450 } 1451 1452 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1453 SMRange loc = curToken.getLoc(); 1454 consumeToken(Token::l_paren); 1455 1456 // Parse the arguments of the call. 1457 SmallVector<ast::Expr *> arguments; 1458 if (curToken.isNot(Token::r_paren)) { 1459 do { 1460 FailureOr<ast::Expr *> argument = parseExpr(); 1461 if (failed(argument)) 1462 return failure(); 1463 arguments.push_back(*argument); 1464 } while (consumeIf(Token::comma)); 1465 } 1466 loc.End = curToken.getEndLoc(); 1467 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1468 return failure(); 1469 1470 return createCallExpr(loc, parentExpr, arguments); 1471 } 1472 1473 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1474 ast::Decl *decl = curDeclScope->lookup(name); 1475 if (!decl) 1476 return emitError(loc, "undefined reference to `" + name + "`"); 1477 1478 return createDeclRefExpr(loc, decl); 1479 } 1480 1481 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1482 StringRef name = curToken.getSpelling(); 1483 SMRange nameLoc = curToken.getLoc(); 1484 consumeToken(); 1485 1486 // Check to see if this is a decl ref expression that defines a variable 1487 // inline. 1488 if (consumeIf(Token::colon)) { 1489 SmallVector<ast::ConstraintRef> constraints; 1490 if (failed(parseVariableDeclConstraintList(constraints))) 1491 return failure(); 1492 ast::Type type; 1493 if (failed(validateVariableConstraints(constraints, type))) 1494 return failure(); 1495 return createInlineVariableExpr(type, name, nameLoc, constraints); 1496 } 1497 1498 return parseDeclRefExpr(name, nameLoc); 1499 } 1500 1501 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1502 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1503 if (failed(decl)) 1504 return failure(); 1505 1506 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1507 ast::ConstraintType::get(ctx)); 1508 } 1509 1510 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1511 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1512 if (failed(decl)) 1513 return failure(); 1514 1515 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1516 ast::RewriteType::get(ctx)); 1517 } 1518 1519 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1520 SMRange loc = curToken.getLoc(); 1521 consumeToken(Token::dot); 1522 1523 // Parse the member name. 1524 Token memberNameTok = curToken; 1525 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1526 !memberNameTok.isKeyword()) 1527 return emitError(loc, "expected identifier or numeric member name"); 1528 StringRef memberName = memberNameTok.getSpelling(); 1529 consumeToken(); 1530 1531 return createMemberAccessExpr(parentExpr, memberName, loc); 1532 } 1533 1534 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1535 SMRange loc = curToken.getLoc(); 1536 1537 // Handle the case of an no operation name. 1538 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1539 if (allowEmptyName) 1540 return ast::OpNameDecl::create(ctx, SMRange()); 1541 return emitError("expected dialect namespace"); 1542 } 1543 StringRef name = curToken.getSpelling(); 1544 consumeToken(); 1545 1546 // Otherwise, this is a literal operation name. 1547 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1548 return failure(); 1549 1550 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1551 return emitError("expected operation name after dialect namespace"); 1552 1553 name = StringRef(name.data(), name.size() + 1); 1554 do { 1555 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1556 loc.End = curToken.getEndLoc(); 1557 consumeToken(); 1558 } while (curToken.isAny(Token::identifier, Token::dot) || 1559 curToken.isKeyword()); 1560 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1561 } 1562 1563 FailureOr<ast::OpNameDecl *> 1564 Parser::parseWrappedOperationName(bool allowEmptyName) { 1565 if (!consumeIf(Token::less)) 1566 return ast::OpNameDecl::create(ctx, SMRange()); 1567 1568 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1569 if (failed(opNameDecl)) 1570 return failure(); 1571 1572 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1573 return failure(); 1574 return opNameDecl; 1575 } 1576 1577 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1578 SMRange loc = curToken.getLoc(); 1579 consumeToken(Token::kw_op); 1580 1581 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1582 // identifier. 1583 if (curToken.isNot(Token::less)) { 1584 resetToken(loc); 1585 return parseIdentifierExpr(); 1586 } 1587 1588 // Parse the operation name. The name may be elided, in which case the 1589 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1590 // `Operation*`). Operation names within a rewrite context must be named. 1591 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1592 FailureOr<ast::OpNameDecl *> opNameDecl = 1593 parseWrappedOperationName(allowEmptyName); 1594 if (failed(opNameDecl)) 1595 return failure(); 1596 1597 // Functor used to create an implicit range variable, used for implicit "all" 1598 // operand or results variables. 1599 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1600 FailureOr<ast::VariableDecl *> rangeVar = 1601 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1602 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1603 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1604 }; 1605 1606 // Check for the optional list of operands. 1607 SmallVector<ast::Expr *> operands; 1608 if (!consumeIf(Token::l_paren)) { 1609 // If the operand list isn't specified and we are in a match context, define 1610 // an inplace unconstrained operand range corresponding to all of the 1611 // operands of the operation. This avoids treating zero operands the same 1612 // way as "unconstrained operands". 1613 if (parserContext != ParserContext::Rewrite) { 1614 operands.push_back(createImplicitRangeVar( 1615 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1616 } 1617 } else if (!consumeIf(Token::r_paren)) { 1618 // If the operand list was specified and non-empty, parse the operands. 1619 do { 1620 FailureOr<ast::Expr *> operand = parseExpr(); 1621 if (failed(operand)) 1622 return failure(); 1623 operands.push_back(*operand); 1624 } while (consumeIf(Token::comma)); 1625 1626 if (failed(parseToken(Token::r_paren, 1627 "expected `)` after operation operand list"))) 1628 return failure(); 1629 } 1630 1631 // Check for the optional list of attributes. 1632 SmallVector<ast::NamedAttributeDecl *> attributes; 1633 if (consumeIf(Token::l_brace)) { 1634 do { 1635 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1636 if (failed(decl)) 1637 return failure(); 1638 attributes.emplace_back(*decl); 1639 } while (consumeIf(Token::comma)); 1640 1641 if (failed(parseToken(Token::r_brace, 1642 "expected `}` after operation attribute list"))) 1643 return failure(); 1644 } 1645 1646 // Check for the optional list of result types. 1647 SmallVector<ast::Expr *> resultTypes; 1648 if (consumeIf(Token::arrow)) { 1649 if (failed(parseToken(Token::l_paren, 1650 "expected `(` before operation result type list"))) 1651 return failure(); 1652 1653 // Handle the case of an empty result list. 1654 if (!consumeIf(Token::r_paren)) { 1655 do { 1656 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1657 if (failed(resultTypeExpr)) 1658 return failure(); 1659 resultTypes.push_back(*resultTypeExpr); 1660 } while (consumeIf(Token::comma)); 1661 1662 if (failed(parseToken(Token::r_paren, 1663 "expected `)` after operation result type list"))) 1664 return failure(); 1665 } 1666 } else if (parserContext != ParserContext::Rewrite) { 1667 // If the result list isn't specified and we are in a match context, define 1668 // an inplace unconstrained result range corresponding to all of the results 1669 // of the operation. This avoids treating zero results the same way as 1670 // "unconstrained results". 1671 resultTypes.push_back(createImplicitRangeVar( 1672 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 1673 } 1674 1675 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1676 resultTypes); 1677 } 1678 1679 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1680 SMRange loc = curToken.getLoc(); 1681 consumeToken(Token::l_paren); 1682 1683 DenseMap<StringRef, SMRange> usedNames; 1684 SmallVector<StringRef> elementNames; 1685 SmallVector<ast::Expr *> elements; 1686 if (curToken.isNot(Token::r_paren)) { 1687 do { 1688 // Check for the optional element name assignment before the value. 1689 StringRef elementName; 1690 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1691 Token elementNameTok = curToken; 1692 consumeToken(); 1693 1694 // The element name is only present if followed by an `=`. 1695 if (consumeIf(Token::equal)) { 1696 elementName = elementNameTok.getSpelling(); 1697 1698 // Check to see if this name is already used. 1699 auto elementNameIt = 1700 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1701 if (!elementNameIt.second) { 1702 return emitErrorAndNote( 1703 elementNameTok.getLoc(), 1704 llvm::formatv("duplicate tuple element label `{0}`", 1705 elementName), 1706 elementNameIt.first->getSecond(), 1707 "see previous label use here"); 1708 } 1709 } else { 1710 // Otherwise, we treat this as part of an expression so reset the 1711 // lexer. 1712 resetToken(elementNameTok.getLoc()); 1713 } 1714 } 1715 elementNames.push_back(elementName); 1716 1717 // Parse the tuple element value. 1718 FailureOr<ast::Expr *> element = parseExpr(); 1719 if (failed(element)) 1720 return failure(); 1721 elements.push_back(*element); 1722 } while (consumeIf(Token::comma)); 1723 } 1724 loc.End = curToken.getEndLoc(); 1725 if (failed( 1726 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1727 return failure(); 1728 return createTupleExpr(loc, elements, elementNames); 1729 } 1730 1731 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1732 SMRange loc = curToken.getLoc(); 1733 consumeToken(Token::kw_type); 1734 1735 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1736 // identifier. 1737 if (!consumeIf(Token::less)) { 1738 resetToken(loc); 1739 return parseIdentifierExpr(); 1740 } 1741 1742 if (!curToken.isString()) 1743 return emitError("expected string literal containing MLIR type"); 1744 std::string attrExpr = curToken.getStringValue(); 1745 consumeToken(); 1746 1747 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1748 return failure(); 1749 return ast::TypeExpr::create(ctx, loc, attrExpr); 1750 } 1751 1752 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 1753 StringRef name = curToken.getSpelling(); 1754 SMRange nameLoc = curToken.getLoc(); 1755 consumeToken(Token::underscore); 1756 1757 // Underscore expressions require a constraint list. 1758 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 1759 return failure(); 1760 1761 // Parse the constraints for the expression. 1762 SmallVector<ast::ConstraintRef> constraints; 1763 if (failed(parseVariableDeclConstraintList(constraints))) 1764 return failure(); 1765 1766 ast::Type type; 1767 if (failed(validateVariableConstraints(constraints, type))) 1768 return failure(); 1769 return createInlineVariableExpr(type, name, nameLoc, constraints); 1770 } 1771 1772 //===----------------------------------------------------------------------===// 1773 // Stmts 1774 1775 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 1776 FailureOr<ast::Stmt *> stmt; 1777 switch (curToken.getKind()) { 1778 case Token::kw_erase: 1779 stmt = parseEraseStmt(); 1780 break; 1781 case Token::kw_let: 1782 stmt = parseLetStmt(); 1783 break; 1784 case Token::kw_replace: 1785 stmt = parseReplaceStmt(); 1786 break; 1787 case Token::kw_return: 1788 stmt = parseReturnStmt(); 1789 break; 1790 case Token::kw_rewrite: 1791 stmt = parseRewriteStmt(); 1792 break; 1793 default: 1794 stmt = parseExpr(); 1795 break; 1796 } 1797 if (failed(stmt) || 1798 (expectTerminalSemicolon && 1799 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 1800 return failure(); 1801 return stmt; 1802 } 1803 1804 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 1805 SMLoc startLoc = curToken.getStartLoc(); 1806 consumeToken(Token::l_brace); 1807 1808 // Push a new block scope and parse any nested statements. 1809 pushDeclScope(); 1810 SmallVector<ast::Stmt *> statements; 1811 while (curToken.isNot(Token::r_brace)) { 1812 FailureOr<ast::Stmt *> statement = parseStmt(); 1813 if (failed(statement)) 1814 return popDeclScope(), failure(); 1815 statements.push_back(*statement); 1816 } 1817 popDeclScope(); 1818 1819 // Consume the end brace. 1820 SMRange location(startLoc, curToken.getEndLoc()); 1821 consumeToken(Token::r_brace); 1822 1823 return ast::CompoundStmt::create(ctx, location, statements); 1824 } 1825 1826 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 1827 if (parserContext == ParserContext::Constraint) 1828 return emitError("`erase` cannot be used within a Constraint"); 1829 SMRange loc = curToken.getLoc(); 1830 consumeToken(Token::kw_erase); 1831 1832 // Parse the root operation expression. 1833 FailureOr<ast::Expr *> rootOp = parseExpr(); 1834 if (failed(rootOp)) 1835 return failure(); 1836 1837 return createEraseStmt(loc, *rootOp); 1838 } 1839 1840 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 1841 SMRange loc = curToken.getLoc(); 1842 consumeToken(Token::kw_let); 1843 1844 // Parse the name of the new variable. 1845 SMRange varLoc = curToken.getLoc(); 1846 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 1847 // `_` is a reserved variable name. 1848 if (curToken.is(Token::underscore)) { 1849 return emitError(varLoc, 1850 "`_` may only be used to define \"inline\" variables"); 1851 } 1852 return emitError(varLoc, 1853 "expected identifier after `let` to name a new variable"); 1854 } 1855 StringRef varName = curToken.getSpelling(); 1856 consumeToken(); 1857 1858 // Parse the optional set of constraints. 1859 SmallVector<ast::ConstraintRef> constraints; 1860 if (consumeIf(Token::colon) && 1861 failed(parseVariableDeclConstraintList(constraints))) 1862 return failure(); 1863 1864 // Parse the optional initializer expression. 1865 ast::Expr *initializer = nullptr; 1866 if (consumeIf(Token::equal)) { 1867 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 1868 if (failed(initOrFailure)) 1869 return failure(); 1870 initializer = *initOrFailure; 1871 1872 // Check that the constraints are compatible with having an initializer, 1873 // e.g. type constraints cannot be used with initializers. 1874 for (ast::ConstraintRef constraint : constraints) { 1875 LogicalResult result = 1876 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 1877 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 1878 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 1879 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 1880 return this->emitError( 1881 constraint.referenceLoc, 1882 "type constraints are not permitted on variables with " 1883 "initializers"); 1884 } 1885 return success(); 1886 }) 1887 .Default(success()); 1888 if (failed(result)) 1889 return failure(); 1890 } 1891 } 1892 1893 FailureOr<ast::VariableDecl *> varDecl = 1894 createVariableDecl(varName, varLoc, initializer, constraints); 1895 if (failed(varDecl)) 1896 return failure(); 1897 return ast::LetStmt::create(ctx, loc, *varDecl); 1898 } 1899 1900 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 1901 if (parserContext == ParserContext::Constraint) 1902 return emitError("`replace` cannot be used within a Constraint"); 1903 SMRange loc = curToken.getLoc(); 1904 consumeToken(Token::kw_replace); 1905 1906 // Parse the root operation expression. 1907 FailureOr<ast::Expr *> rootOp = parseExpr(); 1908 if (failed(rootOp)) 1909 return failure(); 1910 1911 if (failed( 1912 parseToken(Token::kw_with, "expected `with` after root operation"))) 1913 return failure(); 1914 1915 // The replacement portion of this statement is within a rewrite context. 1916 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1917 ParserContext::Rewrite); 1918 1919 // Parse the replacement values. 1920 SmallVector<ast::Expr *> replValues; 1921 if (consumeIf(Token::l_paren)) { 1922 if (consumeIf(Token::r_paren)) { 1923 return emitError( 1924 loc, "expected at least one replacement value, consider using " 1925 "`erase` if no replacement values are desired"); 1926 } 1927 1928 do { 1929 FailureOr<ast::Expr *> replExpr = parseExpr(); 1930 if (failed(replExpr)) 1931 return failure(); 1932 replValues.emplace_back(*replExpr); 1933 } while (consumeIf(Token::comma)); 1934 1935 if (failed(parseToken(Token::r_paren, 1936 "expected `)` after replacement values"))) 1937 return failure(); 1938 } else { 1939 FailureOr<ast::Expr *> replExpr = parseExpr(); 1940 if (failed(replExpr)) 1941 return failure(); 1942 replValues.emplace_back(*replExpr); 1943 } 1944 1945 return createReplaceStmt(loc, *rootOp, replValues); 1946 } 1947 1948 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 1949 SMRange loc = curToken.getLoc(); 1950 consumeToken(Token::kw_return); 1951 1952 // Parse the result value. 1953 FailureOr<ast::Expr *> resultExpr = parseExpr(); 1954 if (failed(resultExpr)) 1955 return failure(); 1956 1957 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 1958 } 1959 1960 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 1961 if (parserContext == ParserContext::Constraint) 1962 return emitError("`rewrite` cannot be used within a Constraint"); 1963 SMRange loc = curToken.getLoc(); 1964 consumeToken(Token::kw_rewrite); 1965 1966 // Parse the root operation. 1967 FailureOr<ast::Expr *> rootOp = parseExpr(); 1968 if (failed(rootOp)) 1969 return failure(); 1970 1971 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 1972 return failure(); 1973 1974 if (curToken.isNot(Token::l_brace)) 1975 return emitError("expected `{` to start rewrite body"); 1976 1977 // The rewrite body of this statement is within a rewrite context. 1978 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1979 ParserContext::Rewrite); 1980 1981 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 1982 if (failed(rewriteBody)) 1983 return failure(); 1984 1985 // Verify the rewrite body. 1986 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 1987 if (isa<ast::ReturnStmt>(stmt)) { 1988 return emitError(stmt->getLoc(), 1989 "`return` statements are only permitted within a " 1990 "`Constraint` or `Rewrite` body"); 1991 } 1992 } 1993 1994 return createRewriteStmt(loc, *rootOp, *rewriteBody); 1995 } 1996 1997 //===----------------------------------------------------------------------===// 1998 // Creation+Analysis 1999 //===----------------------------------------------------------------------===// 2000 2001 //===----------------------------------------------------------------------===// 2002 // Decls 2003 2004 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2005 // Unwrap reference expressions. 2006 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2007 node = init->getDecl(); 2008 return dyn_cast<ast::CallableDecl>(node); 2009 } 2010 2011 FailureOr<ast::PatternDecl *> 2012 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2013 const ParsedPatternMetadata &metadata, 2014 ast::CompoundStmt *body) { 2015 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2016 metadata.hasBoundedRecursion, body); 2017 } 2018 2019 ast::Type Parser::createUserConstraintRewriteResultType( 2020 ArrayRef<ast::VariableDecl *> results) { 2021 // Single result decls use the type of the single result. 2022 if (results.size() == 1) 2023 return results[0]->getType(); 2024 2025 // Multiple results use a tuple type, with the types and names grabbed from 2026 // the result variable decls. 2027 auto resultTypes = llvm::map_range( 2028 results, [&](const auto *result) { return result->getType(); }); 2029 auto resultNames = llvm::map_range( 2030 results, [&](const auto *result) { return result->getName().getName(); }); 2031 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2032 llvm::to_vector(resultNames)); 2033 } 2034 2035 template <typename T> 2036 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2037 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2038 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2039 ast::CompoundStmt *body) { 2040 if (!body->getChildren().empty()) { 2041 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2042 ast::Expr *resultExpr = retStmt->getResultExpr(); 2043 2044 // Process the result of the decl. If no explicit signature results 2045 // were provided, check for return type inference. Otherwise, check that 2046 // the return expression can be converted to the expected type. 2047 if (results.empty()) 2048 resultType = resultExpr->getType(); 2049 else if (failed(convertExpressionTo(resultExpr, resultType))) 2050 return failure(); 2051 else 2052 retStmt->setResultExpr(resultExpr); 2053 } 2054 } 2055 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2056 } 2057 2058 FailureOr<ast::VariableDecl *> 2059 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2060 ArrayRef<ast::ConstraintRef> constraints) { 2061 // The type of the variable, which is expected to be inferred by either a 2062 // constraint or an initializer expression. 2063 ast::Type type; 2064 if (failed(validateVariableConstraints(constraints, type))) 2065 return failure(); 2066 2067 if (initializer) { 2068 // Update the variable type based on the initializer, or try to convert the 2069 // initializer to the existing type. 2070 if (!type) 2071 type = initializer->getType(); 2072 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2073 type = mergedType; 2074 else if (failed(convertExpressionTo(initializer, type))) 2075 return failure(); 2076 2077 // Otherwise, if there is no initializer check that the type has already 2078 // been resolved from the constraint list. 2079 } else if (!type) { 2080 return emitErrorAndNote( 2081 loc, "unable to infer type for variable `" + name + "`", loc, 2082 "the type of a variable must be inferable from the constraint " 2083 "list or the initializer"); 2084 } 2085 2086 // Constraint types cannot be used when defining variables. 2087 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2088 return emitError( 2089 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2090 } 2091 2092 // Try to define a variable with the given name. 2093 FailureOr<ast::VariableDecl *> varDecl = 2094 defineVariableDecl(name, loc, type, initializer, constraints); 2095 if (failed(varDecl)) 2096 return failure(); 2097 2098 return *varDecl; 2099 } 2100 2101 FailureOr<ast::VariableDecl *> 2102 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2103 const ast::ConstraintRef &constraint) { 2104 // Constraint arguments may apply more complex constraints via the arguments. 2105 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2106 ast::Type argType; 2107 if (failed(validateVariableConstraint(constraint, argType, 2108 allowNonCoreConstraints))) 2109 return failure(); 2110 return defineVariableDecl(name, loc, argType, constraint); 2111 } 2112 2113 LogicalResult 2114 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2115 ast::Type &inferredType) { 2116 for (const ast::ConstraintRef &ref : constraints) 2117 if (failed(validateVariableConstraint(ref, inferredType))) 2118 return failure(); 2119 return success(); 2120 } 2121 2122 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2123 ast::Type &inferredType, 2124 bool allowNonCoreConstraints) { 2125 ast::Type constraintType; 2126 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2127 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2128 if (failed(validateTypeConstraintExpr(typeExpr))) 2129 return failure(); 2130 } 2131 constraintType = ast::AttributeType::get(ctx); 2132 } else if (const auto *cst = 2133 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2134 constraintType = ast::OperationType::get(ctx, cst->getName()); 2135 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2136 constraintType = typeTy; 2137 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2138 constraintType = typeRangeTy; 2139 } else if (const auto *cst = 2140 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2141 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2142 if (failed(validateTypeConstraintExpr(typeExpr))) 2143 return failure(); 2144 } 2145 constraintType = valueTy; 2146 } else if (const auto *cst = 2147 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2148 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2149 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2150 return failure(); 2151 } 2152 constraintType = valueRangeTy; 2153 } else if (const auto *cst = 2154 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2155 if (!allowNonCoreConstraints) { 2156 return emitError(ref.referenceLoc, 2157 "`Rewrite` arguments and results are only permitted to " 2158 "use core constraints, such as `Attr`, `Op`, `Type`, " 2159 "`TypeRange`, `Value`, `ValueRange`"); 2160 } 2161 2162 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2163 if (inputs.size() != 1) { 2164 return emitErrorAndNote(ref.referenceLoc, 2165 "`Constraint`s applied via a variable constraint " 2166 "list must take a single input, but got " + 2167 Twine(inputs.size()), 2168 cst->getLoc(), 2169 "see definition of constraint here"); 2170 } 2171 constraintType = inputs.front()->getType(); 2172 } else { 2173 llvm_unreachable("unknown constraint type"); 2174 } 2175 2176 // Check that the constraint type is compatible with the current inferred 2177 // type. 2178 if (!inferredType) { 2179 inferredType = constraintType; 2180 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2181 inferredType = mergedTy; 2182 } else { 2183 return emitError(ref.referenceLoc, 2184 llvm::formatv("constraint type `{0}` is incompatible " 2185 "with the previously inferred type `{1}`", 2186 constraintType, inferredType)); 2187 } 2188 return success(); 2189 } 2190 2191 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2192 ast::Type typeExprType = typeExpr->getType(); 2193 if (typeExprType != typeTy) { 2194 return emitError(typeExpr->getLoc(), 2195 "expected expression of `Type` in type constraint"); 2196 } 2197 return success(); 2198 } 2199 2200 LogicalResult 2201 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2202 ast::Type typeExprType = typeExpr->getType(); 2203 if (typeExprType != typeRangeTy) { 2204 return emitError(typeExpr->getLoc(), 2205 "expected expression of `TypeRange` in type constraint"); 2206 } 2207 return success(); 2208 } 2209 2210 //===----------------------------------------------------------------------===// 2211 // Exprs 2212 2213 FailureOr<ast::CallExpr *> 2214 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2215 MutableArrayRef<ast::Expr *> arguments) { 2216 ast::Type parentType = parentExpr->getType(); 2217 2218 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2219 if (!callableDecl) { 2220 return emitError(loc, 2221 llvm::formatv("expected a reference to a callable " 2222 "`Constraint` or `Rewrite`, but got: `{0}`", 2223 parentType)); 2224 } 2225 if (parserContext == ParserContext::Rewrite) { 2226 if (isa<ast::UserConstraintDecl>(callableDecl)) 2227 return emitError( 2228 loc, "unable to invoke `Constraint` within a rewrite section"); 2229 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2230 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2231 } 2232 2233 // Verify the arguments of the call. 2234 /// Handle size mismatch. 2235 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2236 if (callArgs.size() != arguments.size()) { 2237 return emitErrorAndNote( 2238 loc, 2239 llvm::formatv("invalid number of arguments for {0} call; expected " 2240 "{1}, but got {2}", 2241 callableDecl->getCallableType(), callArgs.size(), 2242 arguments.size()), 2243 callableDecl->getLoc(), 2244 llvm::formatv("see the definition of {0} here", 2245 callableDecl->getName()->getName())); 2246 } 2247 2248 /// Handle argument type mismatch. 2249 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2250 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2251 callableDecl->getName()->getName()), 2252 callableDecl->getLoc()); 2253 }; 2254 for (auto it : llvm::zip(callArgs, arguments)) { 2255 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2256 attachDiagFn))) 2257 return failure(); 2258 } 2259 2260 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2261 callableDecl->getResultType()); 2262 } 2263 2264 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2265 ast::Decl *decl) { 2266 // Check the type of decl being referenced. 2267 ast::Type declType; 2268 if (isa<ast::ConstraintDecl>(decl)) 2269 declType = ast::ConstraintType::get(ctx); 2270 else if (isa<ast::UserRewriteDecl>(decl)) 2271 declType = ast::RewriteType::get(ctx); 2272 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2273 declType = varDecl->getType(); 2274 else 2275 return emitError(loc, "invalid reference to `" + 2276 decl->getName()->getName() + "`"); 2277 2278 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2279 } 2280 2281 FailureOr<ast::DeclRefExpr *> 2282 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2283 ArrayRef<ast::ConstraintRef> constraints) { 2284 FailureOr<ast::VariableDecl *> decl = 2285 defineVariableDecl(name, loc, type, constraints); 2286 if (failed(decl)) 2287 return failure(); 2288 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2289 } 2290 2291 FailureOr<ast::MemberAccessExpr *> 2292 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2293 SMRange loc) { 2294 // Validate the member name for the given parent expression. 2295 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2296 if (failed(memberType)) 2297 return failure(); 2298 2299 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2300 } 2301 2302 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2303 StringRef name, SMRange loc) { 2304 ast::Type parentType = parentExpr->getType(); 2305 if (parentType.isa<ast::OperationType>()) { 2306 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2307 return valueRangeTy; 2308 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2309 // Handle indexed results. 2310 unsigned index = 0; 2311 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2312 index < tupleType.size()) { 2313 return tupleType.getElementTypes()[index]; 2314 } 2315 2316 // Handle named results. 2317 auto elementNames = tupleType.getElementNames(); 2318 const auto *it = llvm::find(elementNames, name); 2319 if (it != elementNames.end()) 2320 return tupleType.getElementTypes()[it - elementNames.begin()]; 2321 } 2322 return emitError( 2323 loc, 2324 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2325 name, parentType)); 2326 } 2327 2328 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2329 SMRange loc, const ast::OpNameDecl *name, 2330 MutableArrayRef<ast::Expr *> operands, 2331 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2332 MutableArrayRef<ast::Expr *> results) { 2333 Optional<StringRef> opNameRef = name->getName(); 2334 2335 // Verify the inputs operands. 2336 if (failed(validateOperationOperands(loc, opNameRef, operands))) 2337 return failure(); 2338 2339 // Verify the attribute list. 2340 for (ast::NamedAttributeDecl *attr : attributes) { 2341 // Check for an attribute type, or a type awaiting resolution. 2342 ast::Type attrType = attr->getValue()->getType(); 2343 if (!attrType.isa<ast::AttributeType>()) { 2344 return emitError( 2345 attr->getValue()->getLoc(), 2346 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2347 } 2348 } 2349 2350 // Verify the result types. 2351 if (failed(validateOperationResults(loc, opNameRef, results))) 2352 return failure(); 2353 2354 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2355 attributes); 2356 } 2357 2358 LogicalResult 2359 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2360 MutableArrayRef<ast::Expr *> operands) { 2361 return validateOperationOperandsOrResults(loc, name, operands, valueTy, 2362 valueRangeTy); 2363 } 2364 2365 LogicalResult 2366 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2367 MutableArrayRef<ast::Expr *> results) { 2368 return validateOperationOperandsOrResults(loc, name, results, typeTy, 2369 typeRangeTy); 2370 } 2371 2372 LogicalResult Parser::validateOperationOperandsOrResults( 2373 SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2374 ast::Type singleTy, ast::Type rangeTy) { 2375 // All operation types accept a single range parameter. 2376 if (values.size() == 1) { 2377 if (failed(convertExpressionTo(values[0], rangeTy))) 2378 return failure(); 2379 return success(); 2380 } 2381 2382 // Otherwise, accept the value groups as they have been defined and just 2383 // ensure they are one of the expected types. 2384 for (ast::Expr *&valueExpr : values) { 2385 ast::Type valueExprType = valueExpr->getType(); 2386 2387 // Check if this is one of the expected types. 2388 if (valueExprType == rangeTy || valueExprType == singleTy) 2389 continue; 2390 2391 // If the operand is an Operation, allow converting to a Value or 2392 // ValueRange. This situations arises quite often with nested operation 2393 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2394 if (singleTy == valueTy) { 2395 if (valueExprType.isa<ast::OperationType>()) { 2396 valueExpr = convertOpToValue(valueExpr); 2397 continue; 2398 } 2399 } 2400 2401 return emitError( 2402 valueExpr->getLoc(), 2403 llvm::formatv( 2404 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2405 singleTy, rangeTy, valueExprType)); 2406 } 2407 return success(); 2408 } 2409 2410 FailureOr<ast::TupleExpr *> 2411 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2412 ArrayRef<StringRef> elementNames) { 2413 for (const ast::Expr *element : elements) { 2414 ast::Type eleTy = element->getType(); 2415 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2416 return emitError( 2417 element->getLoc(), 2418 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2419 } 2420 } 2421 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2422 } 2423 2424 //===----------------------------------------------------------------------===// 2425 // Stmts 2426 2427 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2428 ast::Expr *rootOp) { 2429 // Check that root is an Operation. 2430 ast::Type rootType = rootOp->getType(); 2431 if (!rootType.isa<ast::OperationType>()) 2432 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2433 2434 return ast::EraseStmt::create(ctx, loc, rootOp); 2435 } 2436 2437 FailureOr<ast::ReplaceStmt *> 2438 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2439 MutableArrayRef<ast::Expr *> replValues) { 2440 // Check that root is an Operation. 2441 ast::Type rootType = rootOp->getType(); 2442 if (!rootType.isa<ast::OperationType>()) { 2443 return emitError( 2444 rootOp->getLoc(), 2445 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2446 } 2447 2448 // If there are multiple replacement values, we implicitly convert any Op 2449 // expressions to the value form. 2450 bool shouldConvertOpToValues = replValues.size() > 1; 2451 for (ast::Expr *&replExpr : replValues) { 2452 ast::Type replType = replExpr->getType(); 2453 2454 // Check that replExpr is an Operation, Value, or ValueRange. 2455 if (replType.isa<ast::OperationType>()) { 2456 if (shouldConvertOpToValues) 2457 replExpr = convertOpToValue(replExpr); 2458 continue; 2459 } 2460 2461 if (replType != valueTy && replType != valueRangeTy) { 2462 return emitError(replExpr->getLoc(), 2463 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2464 "expression, but got `{0}`", 2465 replType)); 2466 } 2467 } 2468 2469 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2470 } 2471 2472 FailureOr<ast::RewriteStmt *> 2473 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2474 ast::CompoundStmt *rewriteBody) { 2475 // Check that root is an Operation. 2476 ast::Type rootType = rootOp->getType(); 2477 if (!rootType.isa<ast::OperationType>()) { 2478 return emitError( 2479 rootOp->getLoc(), 2480 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2481 } 2482 2483 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2484 } 2485 2486 //===----------------------------------------------------------------------===// 2487 // Parser 2488 //===----------------------------------------------------------------------===// 2489 2490 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 2491 llvm::SourceMgr &sourceMgr) { 2492 Parser parser(ctx, sourceMgr); 2493 return parser.parseModule(); 2494 } 2495