1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/ManagedStatic.h" 28 #include "llvm/Support/SaveAndRestore.h" 29 #include "llvm/Support/ScopedPrinter.h" 30 #include "llvm/TableGen/Error.h" 31 #include "llvm/TableGen/Parser.h" 32 #include <string> 33 34 using namespace mlir; 35 using namespace mlir::pdll; 36 37 //===----------------------------------------------------------------------===// 38 // Parser 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 class Parser { 43 public: 44 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 45 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 46 curToken(lexer.lexToken()), curDeclScope(nullptr), 47 valueTy(ast::ValueType::get(ctx)), 48 valueRangeTy(ast::ValueRangeType::get(ctx)), 49 typeTy(ast::TypeType::get(ctx)), 50 typeRangeTy(ast::TypeRangeType::get(ctx)), 51 
attrTy(ast::AttributeType::get(ctx)) {} 52 53 /// Try to parse a new module. Returns nullptr in the case of failure. 54 FailureOr<ast::Module *> parseModule(); 55 56 private: 57 /// The current context of the parser. It allows for the parser to know a bit 58 /// about the construct it is nested within during parsing. This is used 59 /// specifically to provide additional verification during parsing, e.g. to 60 /// prevent using rewrites within a match context, matcher constraints within 61 /// a rewrite section, etc. 62 enum class ParserContext { 63 /// The parser is in the global context. 64 Global, 65 /// The parser is currently within a Constraint, which disallows all types 66 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 67 Constraint, 68 /// The parser is currently within the matcher portion of a Pattern, which 69 /// is allows a terminal operation rewrite statement but no other rewrite 70 /// transformations. 71 PatternMatch, 72 /// The parser is currently within a Rewrite, which disallows calls to 73 /// constraints, requires operation expressions to have names, etc. 74 Rewrite, 75 }; 76 77 //===--------------------------------------------------------------------===// 78 // Parsing 79 //===--------------------------------------------------------------------===// 80 81 /// Push a new decl scope onto the lexer. 82 ast::DeclScope *pushDeclScope() { 83 ast::DeclScope *newScope = 84 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 85 return (curDeclScope = newScope); 86 } 87 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 88 89 /// Pop the last decl scope from the lexer. 90 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 91 92 /// Parse the body of an AST module. 93 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 94 95 /// Try to convert the given expression to `type`. Returns failure and emits 96 /// an error if a conversion is not viable. 
On failure, `noteAttachFn` is 97 /// invoked to attach notes to the emitted error diagnostic. On success, 98 /// `expr` is updated to the expression used to convert to `type`. 99 LogicalResult convertExpressionTo( 100 ast::Expr *&expr, ast::Type type, 101 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 102 103 /// Given an operation expression, convert it to a Value or ValueRange 104 /// typed expression. 105 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 106 107 /// Lookup ODS information for the given operation, returns nullptr if no 108 /// information is found. 109 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 110 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 111 } 112 113 //===--------------------------------------------------------------------===// 114 // Directives 115 116 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 117 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 118 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 119 SmallVectorImpl<ast::Decl *> &decls); 120 121 /// Process the records of a parsed tablegen include file. 122 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 123 SmallVectorImpl<ast::Decl *> &decls); 124 125 /// Create a user defined native constraint for a constraint imported from 126 /// ODS. 127 template <typename ConstraintT> 128 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 129 StringRef codeBlock, SMRange loc, 130 ast::Type type); 131 template <typename ConstraintT> 132 ast::Decl * 133 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 134 SMRange loc, ast::Type type); 135 136 //===--------------------------------------------------------------------===// 137 // Decls 138 139 /// This structure contains the set of pattern metadata that may be parsed. 
140 struct ParsedPatternMetadata { 141 Optional<uint16_t> benefit; 142 bool hasBoundedRecursion = false; 143 }; 144 145 FailureOr<ast::Decl *> parseTopLevelDecl(); 146 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 147 148 /// Parse an argument variable as part of the signature of a 149 /// UserConstraintDecl or UserRewriteDecl. 150 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 151 152 /// Parse a result variable as part of the signature of a UserConstraintDecl 153 /// or UserRewriteDecl. 154 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 155 156 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 157 /// defined in a non-global context. 158 FailureOr<ast::UserConstraintDecl *> 159 parseUserConstraintDecl(bool isInline = false); 160 161 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 162 /// non-global context, such as within a Pattern/Constraint/etc. 163 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 164 165 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 166 /// PDLL constructs. 167 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 168 const ast::Name &name, bool isInline, 169 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 170 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 171 172 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 173 /// defined in a non-global context. 174 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 175 176 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 177 /// non-global context, such as within a Pattern/Rewrite/etc. 178 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 179 180 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 181 /// PDLL constructs. 
182 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 183 const ast::Name &name, bool isInline, 184 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 185 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 186 187 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 188 /// effectively the same syntax, and only differ on slight semantics (given 189 /// the different parsing contexts). 190 template <typename T, typename ParseUserPDLLDeclFnT> 191 FailureOr<T *> parseUserConstraintOrRewriteDecl( 192 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 193 StringRef anonymousNamePrefix, bool isInline); 194 195 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 196 /// These decls have effectively the same syntax. 197 template <typename T> 198 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 199 const ast::Name &name, bool isInline, 200 ArrayRef<ast::VariableDecl *> arguments, 201 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 202 203 /// Parse the functional signature (i.e. the arguments and results) of a 204 /// UserConstraintDecl or UserRewriteDecl. 205 LogicalResult parseUserConstraintOrRewriteSignature( 206 SmallVectorImpl<ast::VariableDecl *> &arguments, 207 SmallVectorImpl<ast::VariableDecl *> &results, 208 ast::DeclScope *&argumentScope, ast::Type &resultType); 209 210 /// Validate the return (which if present is specified by bodyIt) of a 211 /// UserConstraintDecl or UserRewriteDecl. 
212 LogicalResult validateUserConstraintOrRewriteReturn( 213 StringRef declType, ast::CompoundStmt *body, 214 ArrayRef<ast::Stmt *>::iterator bodyIt, 215 ArrayRef<ast::Stmt *>::iterator bodyE, 216 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 217 218 FailureOr<ast::CompoundStmt *> 219 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 220 bool expectTerminalSemicolon = true); 221 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 222 FailureOr<ast::Decl *> parsePatternDecl(); 223 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 224 225 /// Check to see if a decl has already been defined with the given name, if 226 /// one has emit and error and return failure. Returns success otherwise. 227 LogicalResult checkDefineNamedDecl(const ast::Name &name); 228 229 /// Try to define a variable decl with the given components, returns the 230 /// variable on success. 231 FailureOr<ast::VariableDecl *> 232 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 233 ast::Expr *initExpr, 234 ArrayRef<ast::ConstraintRef> constraints); 235 FailureOr<ast::VariableDecl *> 236 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 237 ArrayRef<ast::ConstraintRef> constraints); 238 239 /// Parse the constraint reference list for a variable decl. 240 LogicalResult parseVariableDeclConstraintList( 241 SmallVectorImpl<ast::ConstraintRef> &constraints); 242 243 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 244 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 245 246 /// Try to parse a single reference to a constraint. `typeConstraint` is the 247 /// location of a previously parsed type constraint for the entity that will 248 /// be constrained by the parsed constraint. `existingConstraints` are any 249 /// existing constraints that have already been parsed for the same entity 250 /// that will be constrained by this constraint. 
`allowInlineTypeConstraints` 251 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 252 FailureOr<ast::ConstraintRef> 253 parseConstraint(Optional<SMRange> &typeConstraint, 254 ArrayRef<ast::ConstraintRef> existingConstraints, 255 bool allowInlineTypeConstraints); 256 257 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 258 /// argument or result variable. The constraints for these variables do not 259 /// allow inline type constraints, and only permit a single constraint. 260 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 261 262 //===--------------------------------------------------------------------===// 263 // Exprs 264 265 FailureOr<ast::Expr *> parseExpr(); 266 267 /// Identifier expressions. 268 FailureOr<ast::Expr *> parseAttributeExpr(); 269 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 270 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 271 FailureOr<ast::Expr *> parseIdentifierExpr(); 272 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 273 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 274 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 275 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 276 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 277 FailureOr<ast::Expr *> parseOperationExpr(); 278 FailureOr<ast::Expr *> parseTupleExpr(); 279 FailureOr<ast::Expr *> parseTypeExpr(); 280 FailureOr<ast::Expr *> parseUnderscoreExpr(); 281 282 //===--------------------------------------------------------------------===// 283 // Stmts 284 285 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 286 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 287 FailureOr<ast::EraseStmt *> parseEraseStmt(); 288 FailureOr<ast::LetStmt *> parseLetStmt(); 289 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 290 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 291 
FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 292 293 //===--------------------------------------------------------------------===// 294 // Creation+Analysis 295 //===--------------------------------------------------------------------===// 296 297 //===--------------------------------------------------------------------===// 298 // Decls 299 300 /// Try to extract a callable from the given AST node. Returns nullptr on 301 /// failure. 302 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 303 304 /// Try to create a pattern decl with the given components, returning the 305 /// Pattern on success. 306 FailureOr<ast::PatternDecl *> 307 createPatternDecl(SMRange loc, const ast::Name *name, 308 const ParsedPatternMetadata &metadata, 309 ast::CompoundStmt *body); 310 311 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 312 /// of results, defined as part of the signature. 313 ast::Type 314 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 315 316 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 317 template <typename T> 318 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 319 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 320 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 321 ast::CompoundStmt *body); 322 323 /// Try to create a variable decl with the given components, returning the 324 /// Variable on success. 325 FailureOr<ast::VariableDecl *> 326 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 327 ArrayRef<ast::ConstraintRef> constraints); 328 329 /// Create a variable for an argument or result defined as part of the 330 /// signature of a UserConstraintDecl/UserRewriteDecl. 331 FailureOr<ast::VariableDecl *> 332 createArgOrResultVariableDecl(StringRef name, SMRange loc, 333 const ast::ConstraintRef &constraint); 334 335 /// Validate the constraints used to constraint a variable decl. 
336 /// `inferredType` is the type of the variable inferred by the constraints 337 /// within the list, and is updated to the most refined type as determined by 338 /// the constraints. Returns success if the constraint list is valid, failure 339 /// otherwise. 340 LogicalResult 341 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 342 ast::Type &inferredType); 343 /// Validate a single reference to a constraint. `inferredType` contains the 344 /// currently inferred variabled type and is refined within the type defined 345 /// by the constraint. Returns success if the constraint is valid, failure 346 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 347 /// defined constraints) may be used with the variable. 348 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 352 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 353 354 //===--------------------------------------------------------------------===// 355 // Exprs 356 357 FailureOr<ast::CallExpr *> 358 createCallExpr(SMRange loc, ast::Expr *parentExpr, 359 MutableArrayRef<ast::Expr *> arguments); 360 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 361 FailureOr<ast::DeclRefExpr *> 362 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 363 ArrayRef<ast::ConstraintRef> constraints); 364 FailureOr<ast::MemberAccessExpr *> 365 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 366 367 /// Validate the member access `name` into the given parent expression. On 368 /// success, this also returns the type of the member accessed. 
369 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 370 StringRef name, SMRange loc); 371 FailureOr<ast::OperationExpr *> 372 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 373 MutableArrayRef<ast::Expr *> operands, 374 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 375 MutableArrayRef<ast::Expr *> results); 376 LogicalResult 377 validateOperationOperands(SMRange loc, Optional<StringRef> name, 378 const ods::Operation *odsOp, 379 MutableArrayRef<ast::Expr *> operands); 380 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 381 const ods::Operation *odsOp, 382 MutableArrayRef<ast::Expr *> results); 383 LogicalResult validateOperationOperandsOrResults( 384 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 385 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 386 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 387 ast::Type rangeTy); 388 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 389 ArrayRef<ast::Expr *> elements, 390 ArrayRef<StringRef> elementNames); 391 392 //===--------------------------------------------------------------------===// 393 // Stmts 394 395 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 396 FailureOr<ast::ReplaceStmt *> 397 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 398 MutableArrayRef<ast::Expr *> replValues); 399 FailureOr<ast::RewriteStmt *> 400 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 401 ast::CompoundStmt *rewriteBody); 402 403 //===--------------------------------------------------------------------===// 404 // Lexer Utilities 405 //===--------------------------------------------------------------------===// 406 407 /// If the current token has the specified kind, consume it and return true. 408 /// If not, return false. 
409 bool consumeIf(Token::Kind kind) { 410 if (curToken.isNot(kind)) 411 return false; 412 consumeToken(kind); 413 return true; 414 } 415 416 /// Advance the current lexer onto the next token. 417 void consumeToken() { 418 assert(curToken.isNot(Token::eof, Token::error) && 419 "shouldn't advance past EOF or errors"); 420 curToken = lexer.lexToken(); 421 } 422 423 /// Advance the current lexer onto the next token, asserting what the expected 424 /// current token is. This is preferred to the above method because it leads 425 /// to more self-documenting code with better checking. 426 void consumeToken(Token::Kind kind) { 427 assert(curToken.is(kind) && "consumed an unexpected token"); 428 consumeToken(); 429 } 430 431 /// Reset the lexer to the location at the given position. 432 void resetToken(SMRange tokLoc) { 433 lexer.resetPointer(tokLoc.Start.getPointer()); 434 curToken = lexer.lexToken(); 435 } 436 437 /// Consume the specified token if present and return success. On failure, 438 /// output a diagnostic and return failure. 439 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 440 if (curToken.getKind() != kind) 441 return emitError(curToken.getLoc(), msg); 442 consumeToken(); 443 return success(); 444 } 445 LogicalResult emitError(SMRange loc, const Twine &msg) { 446 lexer.emitError(loc, msg); 447 return failure(); 448 } 449 LogicalResult emitError(const Twine &msg) { 450 return emitError(curToken.getLoc(), msg); 451 } 452 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 453 const Twine ¬e) { 454 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 455 return failure(); 456 } 457 458 //===--------------------------------------------------------------------===// 459 // Fields 460 //===--------------------------------------------------------------------===// 461 462 /// The owning AST context. 463 ast::Context &ctx; 464 465 /// The lexer of this parser. 466 Lexer lexer; 467 468 /// The current token within the lexer. 
469 Token curToken; 470 471 /// The most recently defined decl scope. 472 ast::DeclScope *curDeclScope; 473 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 474 475 /// The current context of the parser. 476 ParserContext parserContext = ParserContext::Global; 477 478 /// Cached types to simplify verification and expression creation. 479 ast::Type valueTy, valueRangeTy; 480 ast::Type typeTy, typeRangeTy; 481 ast::Type attrTy; 482 483 /// A counter used when naming anonymous constraints and rewrites. 484 unsigned anonymousDeclNameCounter = 0; 485 }; 486 } // namespace 487 488 FailureOr<ast::Module *> Parser::parseModule() { 489 SMLoc moduleLoc = curToken.getStartLoc(); 490 pushDeclScope(); 491 492 // Parse the top-level decls of the module. 493 SmallVector<ast::Decl *> decls; 494 if (failed(parseModuleBody(decls))) 495 return popDeclScope(), failure(); 496 497 popDeclScope(); 498 return ast::Module::create(ctx, moduleLoc, decls); 499 } 500 501 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 502 while (curToken.isNot(Token::eof)) { 503 if (curToken.is(Token::directive)) { 504 if (failed(parseDirective(decls))) 505 return failure(); 506 continue; 507 } 508 509 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 510 if (failed(decl)) 511 return failure(); 512 decls.push_back(*decl); 513 } 514 return success(); 515 } 516 517 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 518 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 519 valueRangeTy); 520 } 521 522 LogicalResult Parser::convertExpressionTo( 523 ast::Expr *&expr, ast::Type type, 524 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 525 ast::Type exprType = expr->getType(); 526 if (exprType == type) 527 return success(); 528 529 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 530 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 531 expr->getLoc(), llvm::formatv("unable to convert expression of type 
" 532 "`{0}` to the expected type of " 533 "`{1}`", 534 exprType, type)); 535 if (noteAttachFn) 536 noteAttachFn(*diag); 537 return diag; 538 }; 539 540 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 541 // Two operation types are compatible if they have the same name, or if the 542 // expected type is more general. 543 if (auto opType = type.dyn_cast<ast::OperationType>()) { 544 if (opType.getName()) 545 return emitConvertError(); 546 return success(); 547 } 548 549 // An operation can always convert to a ValueRange. 550 if (type == valueRangeTy) { 551 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 552 valueRangeTy); 553 return success(); 554 } 555 556 // Allow conversion to a single value by constraining the result range. 557 if (type == valueTy) { 558 // If the operation is registered, we can verify if it can ever have a 559 // single result. 560 Optional<StringRef> opName = exprOpType.getName(); 561 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 562 if (odsOp->getResults().empty()) { 563 return emitConvertError()->attachNote( 564 llvm::formatv("see the definition of `{0}`, which was defined " 565 "with zero results", 566 odsOp->getName()), 567 odsOp->getLoc()); 568 } 569 570 unsigned numSingleResults = llvm::count_if( 571 odsOp->getResults(), [](const ods::OperandOrResult &result) { 572 return result.getVariableLengthKind() == 573 ods::VariableLengthKind::Single; 574 }); 575 if (numSingleResults > 1) { 576 return emitConvertError()->attachNote( 577 llvm::formatv("see the definition of `{0}`, which was defined " 578 "with at least {1} results", 579 odsOp->getName(), numSingleResults), 580 odsOp->getLoc()); 581 } 582 } 583 584 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 585 valueTy); 586 return success(); 587 } 588 return emitConvertError(); 589 } 590 591 // FIXME: Decide how to allow/support converting a single result to multiple, 592 // and multiple to a single result. 
For now, we just allow Single->Range, 593 // but this isn't something really supported in the PDL dialect. We should 594 // figure out some way to support both. 595 if ((exprType == valueTy || exprType == valueRangeTy) && 596 (type == valueTy || type == valueRangeTy)) 597 return success(); 598 if ((exprType == typeTy || exprType == typeRangeTy) && 599 (type == typeTy || type == typeRangeTy)) 600 return success(); 601 602 // Handle tuple types. 603 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 604 auto tupleType = type.dyn_cast<ast::TupleType>(); 605 if (!tupleType || tupleType.size() != exprTupleType.size()) 606 return emitConvertError(); 607 608 // Build a new tuple expression using each of the elements of the current 609 // tuple. 610 SmallVector<ast::Expr *> newExprs; 611 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 612 newExprs.push_back(ast::MemberAccessExpr::create( 613 ctx, expr->getLoc(), expr, llvm::to_string(i), 614 exprTupleType.getElementTypes()[i])); 615 616 auto diagFn = [&](ast::Diagnostic &diag) { 617 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 618 i, exprTupleType)); 619 if (noteAttachFn) 620 noteAttachFn(diag); 621 }; 622 if (failed(convertExpressionTo(newExprs.back(), 623 tupleType.getElementTypes()[i], diagFn))) 624 return failure(); 625 } 626 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 627 tupleType.getElementNames()); 628 return success(); 629 } 630 631 return emitConvertError(); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // Directives 636 637 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 638 StringRef directive = curToken.getSpelling(); 639 if (directive == "#include") 640 return parseInclude(decls); 641 642 return emitError("unknown directive `" + directive + "`"); 643 } 644 645 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 646 SMRange loc = 
curToken.getLoc(); 647 consumeToken(Token::directive); 648 649 // Parse the file being included. 650 if (!curToken.isString()) 651 return emitError(loc, 652 "expected string file name after `include` directive"); 653 SMRange fileLoc = curToken.getLoc(); 654 std::string filenameStr = curToken.getStringValue(); 655 StringRef filename = filenameStr; 656 consumeToken(); 657 658 // Check the type of include. If ending with `.pdll`, this is another pdl file 659 // to be parsed along with the current module. 660 if (filename.endswith(".pdll")) { 661 if (failed(lexer.pushInclude(filename))) 662 return emitError(fileLoc, 663 "unable to open include file `" + filename + "`"); 664 665 // If we added the include successfully, parse it into the current module. 666 // Make sure to save the current token so that we can restore it when we 667 // finish parsing the nested file. 668 Token oldToken = curToken; 669 curToken = lexer.lexToken(); 670 LogicalResult result = parseModuleBody(decls); 671 curToken = oldToken; 672 return result; 673 } 674 675 // Otherwise, this must be a `.td` include. 676 if (filename.endswith(".td")) 677 return parseTdInclude(filename, fileLoc, decls); 678 679 return emitError(fileLoc, 680 "expected include filename to end with `.pdll` or `.td`"); 681 } 682 683 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 684 SmallVectorImpl<ast::Decl *> &decls) { 685 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 686 687 // This class provides a context argument for the llvm::SourceMgr diagnostic 688 // handler. 689 struct DiagHandlerContext { 690 Parser &parser; 691 StringRef filename; 692 llvm::SMRange loc; 693 } handlerContext{*this, filename, fileLoc}; 694 695 // Set the diagnostic handler for the tablegen source manager. 
  // Report any diagnostic that TableGen emits while parsing the include through
  // this parser's `emitError`, prefixed with the name of the include file.
  llvm::SrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Use the source manager to open the file, but don't yet add it.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Callback handed to `llvm::TableGenParseFile` to consume the parsed records.
  auto processFn = [&](llvm::RecordKeeper &records) {
    processTdIncludeRecords(records, decls);

    // After we are done processing, move all of the tablegen source buffers to
    // the main parser source mgr. This allows for directly using source
    // locations from the .td files without needing to remap them.
    parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
    return false;
  };
  if (llvm::TableGenParseFile(std::move(*includeBuffer),
                              parserSrcMgr.getIncludeDirs(), processFn))
    return failure();

  return success();
}

/// Convert the records of a parsed TableGen include file into PDLL/ODS
/// information:
///   * every `Op` definition not already known to the ODS context is
///     registered there, together with its attributes, operands, and results;
///   * every named `Attr` and `Type` definition, and every named
///     `OpInterface`, `AttrInterface`, or `TypeInterface` definition (excluding
///     `DeclareInterfaceMethods` subclasses), becomes a native PDLL constraint
///     decl, unless its name already resolves in the current decl scope.
/// Newly created constraint decls are appended to `decls`; they are also added
/// to the current scope by `createODSNativePDLLConstraintDecl`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Expand a single source location into a one-character source range.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) =
        odsContext.insertOperation(op.getOperationName(), op.getSummary(),
                                   op.getDescription(), op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  /// Attr constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    // Skip anonymous defs and names that would clash with an existing decl.
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              tblgen::AttrConstraint(def),
              convertLocToRange(def->getLoc().front()), attrTy));
    }
  }
  /// Type constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    // Skip anonymous defs and names that would clash with an existing decl.
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              tblgen::TypeConstraint(def),
              convertLocToRange(def->getLoc().front()), typeTy));
    }
  }
  /// Interfaces.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    // The native constraint body tests `self` with `llvm::isa` against the
    // interface's fully qualified C++ class.
    StringRef className = def->getValueAsString("cppClassName");
    StringRef cppNamespace = def->getValueAsString("cppNamespace");
    std::string codeBlock =
        llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
            .str();

    // The kind of interface selects the constraint applied to `self`.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy));
    }
  }
}

/// Create a native `UserConstraintDecl` named `name` with a single argument
/// `self` of type `type`, constrained by `ConstraintT`, and with `codeBlock` as
/// its native code body. The decl is added to the current decl scope and
/// returned.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
                                          SMRange loc, ast::Type type) {
  // Build the single input parameter. The pushed scope only holds `self` and
  // is popped immediately afterwards.
  ast::DeclScope *argScope = pushDeclScope();
  auto *paramVar = ast::VariableDecl::create(
      ctx, ast::Name::create(ctx, "self", loc), type,
      /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
  argScope->add(paramVar);
  popDeclScope();

  // Build the native constraint.
  auto *constraintDecl = ast::UserConstraintDecl::createNative(
      ctx, ast::Name::create(ctx, name, loc), paramVar,
      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
  curDeclScope->add(constraintDecl);
  return constraintDecl;
}

/// Create a native constraint decl from a TableGen constraint. The constraint's
/// condition template is formatted with the format context's "self"
/// substitution set to `self`, matching the argument name used above. The
/// constraint's def name becomes the decl name.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                          SMRange loc, ast::Type type) {
  // Format the condition template.
  tblgen::FmtContext fmtContext;
  fmtContext.withSelf("self");
  std::string codeBlock =
      tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);

  return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
                                                        codeBlock, loc, type);
}

//===----------------------------------------------------------------------===//
// Decls

/// Parse a top-level decl: a `Constraint`, `Pattern`, or `Rewrite`. Any other
/// token is reported as an error.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  FailureOr<ast::Decl *> decl;
  switch (curToken.getKind()) {
  case Token::kw_Constraint:
    decl = parseUserConstraintDecl();
    break;
  case Token::kw_Pattern:
    decl = parsePatternDecl();
    break;
  case Token::kw_Rewrite:
    decl = parseUserRewriteDecl();
    break;
  default:
    return emitError("expected top-level declaration, such as a `Pattern`");
  }
  if (failed(decl))
    return failure();

  // If the decl has a name, add it to the current scope.
886 if (const ast::Name *name = (*decl)->getName()) { 887 if (failed(checkDefineNamedDecl(*name))) 888 return failure(); 889 curDeclScope->add(*decl); 890 } 891 return decl; 892 } 893 894 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 895 std::string attrNameStr; 896 if (curToken.isString()) 897 attrNameStr = curToken.getStringValue(); 898 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 899 attrNameStr = curToken.getSpelling().str(); 900 else 901 return emitError("expected identifier or string attribute name"); 902 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 903 consumeToken(); 904 905 // Check for a value of the attribute. 906 ast::Expr *attrValue = nullptr; 907 if (consumeIf(Token::equal)) { 908 FailureOr<ast::Expr *> attrExpr = parseExpr(); 909 if (failed(attrExpr)) 910 return failure(); 911 attrValue = *attrExpr; 912 } else { 913 // If there isn't a concrete value, create an expression representing a 914 // UnitAttr. 915 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 916 } 917 918 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 919 } 920 921 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 922 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 923 bool expectTerminalSemicolon) { 924 consumeToken(Token::equal_arrow); 925 926 // Parse the single statement of the lambda body. 927 SMLoc bodyStartLoc = curToken.getStartLoc(); 928 pushDeclScope(); 929 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 930 bool failedToParse = 931 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 932 popDeclScope(); 933 if (failedToParse) 934 return failure(); 935 936 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 937 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 938 } 939 940 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 941 // Ensure that the argument is named. 
942 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 943 return emitError("expected identifier argument name"); 944 945 // Parse the argument similarly to a normal variable. 946 StringRef name = curToken.getSpelling(); 947 SMRange nameLoc = curToken.getLoc(); 948 consumeToken(); 949 950 if (failed( 951 parseToken(Token::colon, "expected `:` before argument constraint"))) 952 return failure(); 953 954 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 955 if (failed(cst)) 956 return failure(); 957 958 return createArgOrResultVariableDecl(name, nameLoc, *cst); 959 } 960 961 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 962 // Check to see if this result is named. 963 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 964 // Check to see if this name actually refers to a Constraint. 965 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 966 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 967 // If yes, and this is a Rewrite, give a nice error message as non-Core 968 // constraints are not supported on Rewrite results. 969 if (parserContext == ParserContext::Rewrite) { 970 return emitError( 971 "`Rewrite` results are only permitted to use core constraints, " 972 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 973 } 974 975 // Otherwise, parse this as an unnamed result variable. 976 } else { 977 // If it wasn't a constraint, parse the result similarly to a variable. If 978 // there is already an existing decl, we will emit an error when defining 979 // this variable later. 
980 StringRef name = curToken.getSpelling(); 981 SMRange nameLoc = curToken.getLoc(); 982 consumeToken(); 983 984 if (failed(parseToken(Token::colon, 985 "expected `:` before result constraint"))) 986 return failure(); 987 988 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 989 if (failed(cst)) 990 return failure(); 991 992 return createArgOrResultVariableDecl(name, nameLoc, *cst); 993 } 994 } 995 996 // If it isn't named, we parse the constraint directly and create an unnamed 997 // result variable. 998 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 999 if (failed(cst)) 1000 return failure(); 1001 1002 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1003 } 1004 1005 FailureOr<ast::UserConstraintDecl *> 1006 Parser::parseUserConstraintDecl(bool isInline) { 1007 // Constraints and rewrites have very similar formats, dispatch to a shared 1008 // interface for parsing. 1009 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1010 [&](auto &&...args) { 1011 return this->parseUserPDLLConstraintDecl(args...); 1012 }, 1013 ParserContext::Constraint, "constraint", isInline); 1014 } 1015 1016 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1017 FailureOr<ast::UserConstraintDecl *> decl = 1018 parseUserConstraintDecl(/*isInline=*/true); 1019 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1020 return failure(); 1021 1022 curDeclScope->add(*decl); 1023 return decl; 1024 } 1025 1026 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1027 const ast::Name &name, bool isInline, 1028 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1029 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1030 // Push the argument scope back onto the list, so that the body can 1031 // reference arguments. 1032 pushDeclScope(argumentScope); 1033 1034 // Parse the body of the constraint. 
The body is either defined as a compound 1035 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1036 ast::CompoundStmt *body; 1037 if (curToken.is(Token::equal_arrow)) { 1038 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1039 [&](ast::Stmt *&stmt) -> LogicalResult { 1040 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1041 if (!stmtExpr) { 1042 return emitError(stmt->getLoc(), 1043 "expected `Constraint` lambda body to contain a " 1044 "single expression"); 1045 } 1046 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1047 return success(); 1048 }, 1049 /*expectTerminalSemicolon=*/!isInline); 1050 if (failed(bodyResult)) 1051 return failure(); 1052 body = *bodyResult; 1053 } else { 1054 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1055 if (failed(bodyResult)) 1056 return failure(); 1057 body = *bodyResult; 1058 1059 // Verify the structure of the body. 1060 auto bodyIt = body->begin(), bodyE = body->end(); 1061 for (; bodyIt != bodyE; ++bodyIt) 1062 if (isa<ast::ReturnStmt>(*bodyIt)) 1063 break; 1064 if (failed(validateUserConstraintOrRewriteReturn( 1065 "Constraint", body, bodyIt, bodyE, results, resultType))) 1066 return failure(); 1067 } 1068 popDeclScope(); 1069 1070 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1071 name, arguments, results, resultType, body); 1072 } 1073 1074 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1075 // Constraints and rewrites have very similar formats, dispatch to a shared 1076 // interface for parsing. 
1077 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1078 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1079 ParserContext::Rewrite, "rewrite", isInline); 1080 } 1081 1082 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1083 FailureOr<ast::UserRewriteDecl *> decl = 1084 parseUserRewriteDecl(/*isInline=*/true); 1085 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1086 return failure(); 1087 1088 curDeclScope->add(*decl); 1089 return decl; 1090 } 1091 1092 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1093 const ast::Name &name, bool isInline, 1094 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1095 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1096 // Push the argument scope back onto the list, so that the body can 1097 // reference arguments. 1098 curDeclScope = argumentScope; 1099 ast::CompoundStmt *body; 1100 if (curToken.is(Token::equal_arrow)) { 1101 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1102 [&](ast::Stmt *&statement) -> LogicalResult { 1103 if (isa<ast::OpRewriteStmt>(statement)) 1104 return success(); 1105 1106 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1107 if (!statementExpr) { 1108 return emitError( 1109 statement->getLoc(), 1110 "expected `Rewrite` lambda body to contain a single expression " 1111 "or an operation rewrite statement; such as `erase`, " 1112 "`replace`, or `rewrite`"); 1113 } 1114 statement = 1115 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1116 return success(); 1117 }, 1118 /*expectTerminalSemicolon=*/!isInline); 1119 if (failed(bodyResult)) 1120 return failure(); 1121 body = *bodyResult; 1122 } else { 1123 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1124 if (failed(bodyResult)) 1125 return failure(); 1126 body = *bodyResult; 1127 } 1128 popDeclScope(); 1129 1130 // Verify the structure of the body. 
1131 auto bodyIt = body->begin(), bodyE = body->end(); 1132 for (; bodyIt != bodyE; ++bodyIt) 1133 if (isa<ast::ReturnStmt>(*bodyIt)) 1134 break; 1135 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1136 bodyE, results, resultType))) 1137 return failure(); 1138 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1139 name, arguments, results, resultType, body); 1140 } 1141 1142 template <typename T, typename ParseUserPDLLDeclFnT> 1143 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1144 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1145 StringRef anonymousNamePrefix, bool isInline) { 1146 SMRange loc = curToken.getLoc(); 1147 consumeToken(); 1148 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1149 1150 // Parse the name of the decl. 1151 const ast::Name *name = nullptr; 1152 if (curToken.isNot(Token::identifier)) { 1153 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1154 // in C++, so being unnamed is fine. 1155 if (!isInline) 1156 return emitError("expected identifier name"); 1157 1158 // Create a unique anonymous name to use, as the name for this decl is not 1159 // important. 1160 std::string anonName = 1161 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1162 anonymousDeclNameCounter++) 1163 .str(); 1164 name = &ast::Name::create(ctx, anonName, loc); 1165 } else { 1166 // If a name was provided, we can use it directly. 1167 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1168 consumeToken(Token::identifier); 1169 } 1170 1171 // Parse the functional signature of the decl. 1172 SmallVector<ast::VariableDecl *> arguments, results; 1173 ast::DeclScope *argumentScope; 1174 ast::Type resultType; 1175 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1176 argumentScope, resultType))) 1177 return failure(); 1178 1179 // Check to see which type of constraint this is. 
If the constraint contains a 1180 // compound body, this is a PDLL decl. 1181 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1182 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1183 resultType); 1184 1185 // Otherwise, this is a native decl. 1186 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1187 results, resultType); 1188 } 1189 1190 template <typename T> 1191 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1192 const ast::Name &name, bool isInline, 1193 ArrayRef<ast::VariableDecl *> arguments, 1194 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1195 // If followed by a string, the native code body has also been specified. 1196 std::string codeStrStorage; 1197 Optional<StringRef> optCodeStr; 1198 if (curToken.isString()) { 1199 codeStrStorage = curToken.getStringValue(); 1200 optCodeStr = codeStrStorage; 1201 consumeToken(); 1202 } else if (isInline) { 1203 return emitError(name.getLoc(), 1204 "external declarations must be declared in global scope"); 1205 } 1206 if (failed(parseToken(Token::semicolon, 1207 "expected `;` after native declaration"))) 1208 return failure(); 1209 // TODO: PDL should be able to support constraint results in certain 1210 // situations, we should revise this. 1211 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1212 return emitError( 1213 "native Constraints currently do not support returning results"); 1214 } 1215 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1216 } 1217 1218 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1219 SmallVectorImpl<ast::VariableDecl *> &arguments, 1220 SmallVectorImpl<ast::VariableDecl *> &results, 1221 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1222 // Parse the argument list of the decl. 
1223 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1224 return failure(); 1225 1226 argumentScope = pushDeclScope(); 1227 if (curToken.isNot(Token::r_paren)) { 1228 do { 1229 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1230 if (failed(argument)) 1231 return failure(); 1232 arguments.emplace_back(*argument); 1233 } while (consumeIf(Token::comma)); 1234 } 1235 popDeclScope(); 1236 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1237 return failure(); 1238 1239 // Parse the results of the decl. 1240 pushDeclScope(); 1241 if (consumeIf(Token::arrow)) { 1242 auto parseResultFn = [&]() -> LogicalResult { 1243 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1244 if (failed(result)) 1245 return failure(); 1246 results.emplace_back(*result); 1247 return success(); 1248 }; 1249 1250 // Check for a list of results. 1251 if (consumeIf(Token::l_paren)) { 1252 do { 1253 if (failed(parseResultFn())) 1254 return failure(); 1255 } while (consumeIf(Token::comma)); 1256 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1257 return failure(); 1258 1259 // Otherwise, there is only one result. 1260 } else if (failed(parseResultFn())) { 1261 return failure(); 1262 } 1263 } 1264 popDeclScope(); 1265 1266 // Compute the result type of the decl. 1267 resultType = createUserConstraintRewriteResultType(results); 1268 1269 // Verify that results are only named if there are more than one. 
1270 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1271 return emitError( 1272 results.front()->getLoc(), 1273 "cannot create a single-element tuple with an element label"); 1274 } 1275 return success(); 1276 } 1277 1278 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1279 StringRef declType, ast::CompoundStmt *body, 1280 ArrayRef<ast::Stmt *>::iterator bodyIt, 1281 ArrayRef<ast::Stmt *>::iterator bodyE, 1282 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1283 // Handle if a `return` was provided. 1284 if (bodyIt != bodyE) { 1285 // Emit an error if we have trailing statements after the return. 1286 if (std::next(bodyIt) != bodyE) { 1287 return emitError( 1288 (*std::next(bodyIt))->getLoc(), 1289 llvm::formatv("`return` terminated the `{0}` body, but found " 1290 "trailing statements afterwards", 1291 declType)); 1292 } 1293 1294 // Otherwise if a return wasn't provided, check that no results are 1295 // expected. 1296 } else if (!results.empty()) { 1297 return emitError( 1298 {body->getLoc().End, body->getLoc().End}, 1299 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1300 declType, resultType)); 1301 } 1302 return success(); 1303 } 1304 1305 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1306 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1307 if (isa<ast::OpRewriteStmt>(statement)) 1308 return success(); 1309 return emitError( 1310 statement->getLoc(), 1311 "expected Pattern lambda body to contain a single operation " 1312 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1313 }); 1314 } 1315 1316 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1317 SMRange loc = curToken.getLoc(); 1318 consumeToken(Token::kw_Pattern); 1319 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1320 ParserContext::PatternMatch); 1321 1322 // Check for an optional identifier for the pattern name. 
1323 const ast::Name *name = nullptr; 1324 if (curToken.is(Token::identifier)) { 1325 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1326 consumeToken(Token::identifier); 1327 } 1328 1329 // Parse any pattern metadata. 1330 ParsedPatternMetadata metadata; 1331 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1332 return failure(); 1333 1334 // Parse the pattern body. 1335 ast::CompoundStmt *body; 1336 1337 // Handle a lambda body. 1338 if (curToken.is(Token::equal_arrow)) { 1339 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1340 if (failed(bodyResult)) 1341 return failure(); 1342 body = *bodyResult; 1343 } else { 1344 if (curToken.isNot(Token::l_brace)) 1345 return emitError("expected `{` or `=>` to start pattern body"); 1346 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1347 if (failed(bodyResult)) 1348 return failure(); 1349 body = *bodyResult; 1350 1351 // Verify the body of the pattern. 1352 auto bodyIt = body->begin(), bodyE = body->end(); 1353 for (; bodyIt != bodyE; ++bodyIt) { 1354 if (isa<ast::ReturnStmt>(*bodyIt)) { 1355 return emitError((*bodyIt)->getLoc(), 1356 "`return` statements are only permitted within a " 1357 "`Constraint` or `Rewrite` body"); 1358 } 1359 // Break when we've found the rewrite statement. 
1360 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1361 break; 1362 } 1363 if (bodyIt == bodyE) { 1364 return emitError(loc, 1365 "expected Pattern body to terminate with an operation " 1366 "rewrite statement, such as `erase`"); 1367 } 1368 if (std::next(bodyIt) != bodyE) { 1369 return emitError((*std::next(bodyIt))->getLoc(), 1370 "Pattern body was terminated by an operation " 1371 "rewrite statement, but found trailing statements"); 1372 } 1373 } 1374 1375 return createPatternDecl(loc, name, metadata, body); 1376 } 1377 1378 LogicalResult 1379 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1380 Optional<SMRange> benefitLoc; 1381 Optional<SMRange> hasBoundedRecursionLoc; 1382 1383 do { 1384 if (curToken.isNot(Token::identifier)) 1385 return emitError("expected pattern metadata identifier"); 1386 StringRef metadataStr = curToken.getSpelling(); 1387 SMRange metadataLoc = curToken.getLoc(); 1388 consumeToken(Token::identifier); 1389 1390 // Parse the benefit metadata: benefit(<integer-value>) 1391 if (metadataStr == "benefit") { 1392 if (benefitLoc) { 1393 return emitErrorAndNote(metadataLoc, 1394 "pattern benefit has already been specified", 1395 *benefitLoc, "see previous definition here"); 1396 } 1397 if (failed(parseToken(Token::l_paren, 1398 "expected `(` before pattern benefit"))) 1399 return failure(); 1400 1401 uint16_t benefitValue = 0; 1402 if (curToken.isNot(Token::integer)) 1403 return emitError("expected integral pattern benefit"); 1404 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1405 return emitError( 1406 "expected pattern benefit to fit within a 16-bit integer"); 1407 consumeToken(Token::integer); 1408 1409 metadata.benefit = benefitValue; 1410 benefitLoc = metadataLoc; 1411 1412 if (failed( 1413 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1414 return failure(); 1415 continue; 1416 } 1417 1418 // Parse the bounded recursion metadata: recursion 1419 if (metadataStr == "recursion") { 1420 if 
(hasBoundedRecursionLoc) { 1421 return emitErrorAndNote( 1422 metadataLoc, 1423 "pattern recursion metadata has already been specified", 1424 *hasBoundedRecursionLoc, "see previous definition here"); 1425 } 1426 metadata.hasBoundedRecursion = true; 1427 hasBoundedRecursionLoc = metadataLoc; 1428 continue; 1429 } 1430 1431 return emitError(metadataLoc, "unknown pattern metadata"); 1432 } while (consumeIf(Token::comma)); 1433 1434 return success(); 1435 } 1436 1437 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1438 consumeToken(Token::less); 1439 1440 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1441 if (failed(typeExpr) || 1442 failed(parseToken(Token::greater, 1443 "expected `>` after variable type constraint"))) 1444 return failure(); 1445 return typeExpr; 1446 } 1447 1448 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1449 assert(curDeclScope && "defining decl outside of a decl scope"); 1450 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1451 return emitErrorAndNote( 1452 name.getLoc(), "`" + name.getName() + "` has already been defined", 1453 lastDecl->getName()->getLoc(), "see previous definition here"); 1454 } 1455 return success(); 1456 } 1457 1458 FailureOr<ast::VariableDecl *> 1459 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1460 ast::Expr *initExpr, 1461 ArrayRef<ast::ConstraintRef> constraints) { 1462 assert(curDeclScope && "defining variable outside of decl scope"); 1463 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1464 1465 // If the name of the variable indicates a special variable, we don't add it 1466 // to the scope. This variable is local to the definition point. 
1467 if (name.empty() || name == "_") { 1468 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1469 constraints); 1470 } 1471 if (failed(checkDefineNamedDecl(nameDecl))) 1472 return failure(); 1473 1474 auto *varDecl = 1475 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1476 curDeclScope->add(varDecl); 1477 return varDecl; 1478 } 1479 1480 FailureOr<ast::VariableDecl *> 1481 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1482 ArrayRef<ast::ConstraintRef> constraints) { 1483 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1484 constraints); 1485 } 1486 1487 LogicalResult Parser::parseVariableDeclConstraintList( 1488 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1489 Optional<SMRange> typeConstraint; 1490 auto parseSingleConstraint = [&] { 1491 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1492 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); 1493 if (failed(constraint)) 1494 return failure(); 1495 constraints.push_back(*constraint); 1496 return success(); 1497 }; 1498 1499 // Check to see if this is a single constraint, or a list. 
  // A `[` introduces a comma-separated constraint list terminated by `]`.
  if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference. It is one of:
///   * a core constraint keyword: `Attr`, `Op`, `Type`, `TypeRange`, `Value`,
///     or `ValueRange`, where `Attr`, `Value`, and `ValueRange` accept an
///     optional inline `<type>` constraint and `Op` accepts an optional
///     operation name;
///   * an inline `Constraint` decl;
///   * the name of a previously defined constraint decl.
/// `typeConstraint` holds the location of an inline type constraint already
/// parsed for the same variable, so that a second one can be diagnosed. Inline
/// type constraints are rejected when `allowInlineTypeConstraints` is false.
/// NOTE(review): `existingConstraints` is not read in this function.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints) {
  // Parse `< <expr> >` into `typeExpr` and record its location in
  // `typeConstraint`.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but not to a constraint decl.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

/// Parse the constraint of an argument or result. Inline type constraints are
/// disallowed here, and a fresh type-constraint tracker is used for each call.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  Optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
                         /*allowInlineTypeConstraints=*/false);
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression: either `_`, or a primary expression followed by any
/// number of postfix `.member` accesses and `(...)` calls.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
  FailureOr<ast::Expr *> lhsExpr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    lhsExpr = parseAttributeExpr();
    break;
  case Token::kw_Constraint:
    lhsExpr = parseInlineConstraintLambdaExpr();
    break;
  case Token::identifier:
    lhsExpr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    lhsExpr = parseOperationExpr();
    break;
  case Token::kw_Rewrite:
    lhsExpr = parseInlineRewriteLambdaExpr();
    break;
  case Token::kw_type:
    lhsExpr = parseTypeExpr();
    break;
  case Token::l_paren:
    lhsExpr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }
  if (failed(lhsExpr))
    return failure();

  // Check for an operator expression.
1665 while (true) { 1666 switch (curToken.getKind()) { 1667 case Token::dot: 1668 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1669 break; 1670 case Token::l_paren: 1671 lhsExpr = parseCallExpr(*lhsExpr); 1672 break; 1673 default: 1674 return lhsExpr; 1675 } 1676 if (failed(lhsExpr)) 1677 return failure(); 1678 } 1679 } 1680 1681 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1682 SMRange loc = curToken.getLoc(); 1683 consumeToken(Token::kw_attr); 1684 1685 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1686 // identifier. 1687 if (!consumeIf(Token::less)) { 1688 resetToken(loc); 1689 return parseIdentifierExpr(); 1690 } 1691 1692 if (!curToken.isString()) 1693 return emitError("expected string literal containing MLIR attribute"); 1694 std::string attrExpr = curToken.getStringValue(); 1695 consumeToken(); 1696 1697 if (failed( 1698 parseToken(Token::greater, "expected `>` after attribute literal"))) 1699 return failure(); 1700 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1701 } 1702 1703 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1704 SMRange loc = curToken.getLoc(); 1705 consumeToken(Token::l_paren); 1706 1707 // Parse the arguments of the call. 
1708 SmallVector<ast::Expr *> arguments; 1709 if (curToken.isNot(Token::r_paren)) { 1710 do { 1711 FailureOr<ast::Expr *> argument = parseExpr(); 1712 if (failed(argument)) 1713 return failure(); 1714 arguments.push_back(*argument); 1715 } while (consumeIf(Token::comma)); 1716 } 1717 loc.End = curToken.getEndLoc(); 1718 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1719 return failure(); 1720 1721 return createCallExpr(loc, parentExpr, arguments); 1722 } 1723 1724 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1725 ast::Decl *decl = curDeclScope->lookup(name); 1726 if (!decl) 1727 return emitError(loc, "undefined reference to `" + name + "`"); 1728 1729 return createDeclRefExpr(loc, decl); 1730 } 1731 1732 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1733 StringRef name = curToken.getSpelling(); 1734 SMRange nameLoc = curToken.getLoc(); 1735 consumeToken(); 1736 1737 // Check to see if this is a decl ref expression that defines a variable 1738 // inline. 
1739 if (consumeIf(Token::colon)) { 1740 SmallVector<ast::ConstraintRef> constraints; 1741 if (failed(parseVariableDeclConstraintList(constraints))) 1742 return failure(); 1743 ast::Type type; 1744 if (failed(validateVariableConstraints(constraints, type))) 1745 return failure(); 1746 return createInlineVariableExpr(type, name, nameLoc, constraints); 1747 } 1748 1749 return parseDeclRefExpr(name, nameLoc); 1750 } 1751 1752 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1753 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1754 if (failed(decl)) 1755 return failure(); 1756 1757 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1758 ast::ConstraintType::get(ctx)); 1759 } 1760 1761 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1762 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1763 if (failed(decl)) 1764 return failure(); 1765 1766 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1767 ast::RewriteType::get(ctx)); 1768 } 1769 1770 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1771 SMRange loc = curToken.getLoc(); 1772 consumeToken(Token::dot); 1773 1774 // Parse the member name. 1775 Token memberNameTok = curToken; 1776 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1777 !memberNameTok.isKeyword()) 1778 return emitError(loc, "expected identifier or numeric member name"); 1779 StringRef memberName = memberNameTok.getSpelling(); 1780 consumeToken(); 1781 1782 return createMemberAccessExpr(parentExpr, memberName, loc); 1783 } 1784 1785 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1786 SMRange loc = curToken.getLoc(); 1787 1788 // Handle the case of an no operation name. 
1789 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1790 if (allowEmptyName) 1791 return ast::OpNameDecl::create(ctx, SMRange()); 1792 return emitError("expected dialect namespace"); 1793 } 1794 StringRef name = curToken.getSpelling(); 1795 consumeToken(); 1796 1797 // Otherwise, this is a literal operation name. 1798 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1799 return failure(); 1800 1801 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1802 return emitError("expected operation name after dialect namespace"); 1803 1804 name = StringRef(name.data(), name.size() + 1); 1805 do { 1806 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1807 loc.End = curToken.getEndLoc(); 1808 consumeToken(); 1809 } while (curToken.isAny(Token::identifier, Token::dot) || 1810 curToken.isKeyword()); 1811 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1812 } 1813 1814 FailureOr<ast::OpNameDecl *> 1815 Parser::parseWrappedOperationName(bool allowEmptyName) { 1816 if (!consumeIf(Token::less)) 1817 return ast::OpNameDecl::create(ctx, SMRange()); 1818 1819 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1820 if (failed(opNameDecl)) 1821 return failure(); 1822 1823 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1824 return failure(); 1825 return opNameDecl; 1826 } 1827 1828 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1829 SMRange loc = curToken.getLoc(); 1830 consumeToken(Token::kw_op); 1831 1832 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1833 // identifier. 1834 if (curToken.isNot(Token::less)) { 1835 resetToken(loc); 1836 return parseIdentifierExpr(); 1837 } 1838 1839 // Parse the operation name. The name may be elided, in which case the 1840 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1841 // `Operation*`). 
Operation names within a rewrite context must be named. 1842 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1843 FailureOr<ast::OpNameDecl *> opNameDecl = 1844 parseWrappedOperationName(allowEmptyName); 1845 if (failed(opNameDecl)) 1846 return failure(); 1847 1848 // Functor used to create an implicit range variable, used for implicit "all" 1849 // operand or results variables. 1850 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1851 FailureOr<ast::VariableDecl *> rangeVar = 1852 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1853 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1854 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1855 }; 1856 1857 // Check for the optional list of operands. 1858 SmallVector<ast::Expr *> operands; 1859 if (!consumeIf(Token::l_paren)) { 1860 // If the operand list isn't specified and we are in a match context, define 1861 // an inplace unconstrained operand range corresponding to all of the 1862 // operands of the operation. This avoids treating zero operands the same 1863 // way as "unconstrained operands". 1864 if (parserContext != ParserContext::Rewrite) { 1865 operands.push_back(createImplicitRangeVar( 1866 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1867 } 1868 } else if (!consumeIf(Token::r_paren)) { 1869 // If the operand list was specified and non-empty, parse the operands. 1870 do { 1871 FailureOr<ast::Expr *> operand = parseExpr(); 1872 if (failed(operand)) 1873 return failure(); 1874 operands.push_back(*operand); 1875 } while (consumeIf(Token::comma)); 1876 1877 if (failed(parseToken(Token::r_paren, 1878 "expected `)` after operation operand list"))) 1879 return failure(); 1880 } 1881 1882 // Check for the optional list of attributes. 
1883 SmallVector<ast::NamedAttributeDecl *> attributes; 1884 if (consumeIf(Token::l_brace)) { 1885 do { 1886 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1887 if (failed(decl)) 1888 return failure(); 1889 attributes.emplace_back(*decl); 1890 } while (consumeIf(Token::comma)); 1891 1892 if (failed(parseToken(Token::r_brace, 1893 "expected `}` after operation attribute list"))) 1894 return failure(); 1895 } 1896 1897 // Check for the optional list of result types. 1898 SmallVector<ast::Expr *> resultTypes; 1899 if (consumeIf(Token::arrow)) { 1900 if (failed(parseToken(Token::l_paren, 1901 "expected `(` before operation result type list"))) 1902 return failure(); 1903 1904 // Handle the case of an empty result list. 1905 if (!consumeIf(Token::r_paren)) { 1906 do { 1907 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1908 if (failed(resultTypeExpr)) 1909 return failure(); 1910 resultTypes.push_back(*resultTypeExpr); 1911 } while (consumeIf(Token::comma)); 1912 1913 if (failed(parseToken(Token::r_paren, 1914 "expected `)` after operation result type list"))) 1915 return failure(); 1916 } 1917 } else if (parserContext != ParserContext::Rewrite) { 1918 // If the result list isn't specified and we are in a match context, define 1919 // an inplace unconstrained result range corresponding to all of the results 1920 // of the operation. This avoids treating zero results the same way as 1921 // "unconstrained results". 
1922 resultTypes.push_back(createImplicitRangeVar( 1923 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 1924 } 1925 1926 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1927 resultTypes); 1928 } 1929 1930 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1931 SMRange loc = curToken.getLoc(); 1932 consumeToken(Token::l_paren); 1933 1934 DenseMap<StringRef, SMRange> usedNames; 1935 SmallVector<StringRef> elementNames; 1936 SmallVector<ast::Expr *> elements; 1937 if (curToken.isNot(Token::r_paren)) { 1938 do { 1939 // Check for the optional element name assignment before the value. 1940 StringRef elementName; 1941 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1942 Token elementNameTok = curToken; 1943 consumeToken(); 1944 1945 // The element name is only present if followed by an `=`. 1946 if (consumeIf(Token::equal)) { 1947 elementName = elementNameTok.getSpelling(); 1948 1949 // Check to see if this name is already used. 1950 auto elementNameIt = 1951 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1952 if (!elementNameIt.second) { 1953 return emitErrorAndNote( 1954 elementNameTok.getLoc(), 1955 llvm::formatv("duplicate tuple element label `{0}`", 1956 elementName), 1957 elementNameIt.first->getSecond(), 1958 "see previous label use here"); 1959 } 1960 } else { 1961 // Otherwise, we treat this as part of an expression so reset the 1962 // lexer. 1963 resetToken(elementNameTok.getLoc()); 1964 } 1965 } 1966 elementNames.push_back(elementName); 1967 1968 // Parse the tuple element value. 
1969 FailureOr<ast::Expr *> element = parseExpr(); 1970 if (failed(element)) 1971 return failure(); 1972 elements.push_back(*element); 1973 } while (consumeIf(Token::comma)); 1974 } 1975 loc.End = curToken.getEndLoc(); 1976 if (failed( 1977 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1978 return failure(); 1979 return createTupleExpr(loc, elements, elementNames); 1980 } 1981 1982 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1983 SMRange loc = curToken.getLoc(); 1984 consumeToken(Token::kw_type); 1985 1986 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1987 // identifier. 1988 if (!consumeIf(Token::less)) { 1989 resetToken(loc); 1990 return parseIdentifierExpr(); 1991 } 1992 1993 if (!curToken.isString()) 1994 return emitError("expected string literal containing MLIR type"); 1995 std::string attrExpr = curToken.getStringValue(); 1996 consumeToken(); 1997 1998 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1999 return failure(); 2000 return ast::TypeExpr::create(ctx, loc, attrExpr); 2001 } 2002 2003 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2004 StringRef name = curToken.getSpelling(); 2005 SMRange nameLoc = curToken.getLoc(); 2006 consumeToken(Token::underscore); 2007 2008 // Underscore expressions require a constraint list. 2009 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2010 return failure(); 2011 2012 // Parse the constraints for the expression. 
2013 SmallVector<ast::ConstraintRef> constraints; 2014 if (failed(parseVariableDeclConstraintList(constraints))) 2015 return failure(); 2016 2017 ast::Type type; 2018 if (failed(validateVariableConstraints(constraints, type))) 2019 return failure(); 2020 return createInlineVariableExpr(type, name, nameLoc, constraints); 2021 } 2022 2023 //===----------------------------------------------------------------------===// 2024 // Stmts 2025 2026 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2027 FailureOr<ast::Stmt *> stmt; 2028 switch (curToken.getKind()) { 2029 case Token::kw_erase: 2030 stmt = parseEraseStmt(); 2031 break; 2032 case Token::kw_let: 2033 stmt = parseLetStmt(); 2034 break; 2035 case Token::kw_replace: 2036 stmt = parseReplaceStmt(); 2037 break; 2038 case Token::kw_return: 2039 stmt = parseReturnStmt(); 2040 break; 2041 case Token::kw_rewrite: 2042 stmt = parseRewriteStmt(); 2043 break; 2044 default: 2045 stmt = parseExpr(); 2046 break; 2047 } 2048 if (failed(stmt) || 2049 (expectTerminalSemicolon && 2050 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2051 return failure(); 2052 return stmt; 2053 } 2054 2055 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2056 SMLoc startLoc = curToken.getStartLoc(); 2057 consumeToken(Token::l_brace); 2058 2059 // Push a new block scope and parse any nested statements. 2060 pushDeclScope(); 2061 SmallVector<ast::Stmt *> statements; 2062 while (curToken.isNot(Token::r_brace)) { 2063 FailureOr<ast::Stmt *> statement = parseStmt(); 2064 if (failed(statement)) 2065 return popDeclScope(), failure(); 2066 statements.push_back(*statement); 2067 } 2068 popDeclScope(); 2069 2070 // Consume the end brace. 
2071 SMRange location(startLoc, curToken.getEndLoc()); 2072 consumeToken(Token::r_brace); 2073 2074 return ast::CompoundStmt::create(ctx, location, statements); 2075 } 2076 2077 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2078 if (parserContext == ParserContext::Constraint) 2079 return emitError("`erase` cannot be used within a Constraint"); 2080 SMRange loc = curToken.getLoc(); 2081 consumeToken(Token::kw_erase); 2082 2083 // Parse the root operation expression. 2084 FailureOr<ast::Expr *> rootOp = parseExpr(); 2085 if (failed(rootOp)) 2086 return failure(); 2087 2088 return createEraseStmt(loc, *rootOp); 2089 } 2090 2091 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2092 SMRange loc = curToken.getLoc(); 2093 consumeToken(Token::kw_let); 2094 2095 // Parse the name of the new variable. 2096 SMRange varLoc = curToken.getLoc(); 2097 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2098 // `_` is a reserved variable name. 2099 if (curToken.is(Token::underscore)) { 2100 return emitError(varLoc, 2101 "`_` may only be used to define \"inline\" variables"); 2102 } 2103 return emitError(varLoc, 2104 "expected identifier after `let` to name a new variable"); 2105 } 2106 StringRef varName = curToken.getSpelling(); 2107 consumeToken(); 2108 2109 // Parse the optional set of constraints. 2110 SmallVector<ast::ConstraintRef> constraints; 2111 if (consumeIf(Token::colon) && 2112 failed(parseVariableDeclConstraintList(constraints))) 2113 return failure(); 2114 2115 // Parse the optional initializer expression. 2116 ast::Expr *initializer = nullptr; 2117 if (consumeIf(Token::equal)) { 2118 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2119 if (failed(initOrFailure)) 2120 return failure(); 2121 initializer = *initOrFailure; 2122 2123 // Check that the constraints are compatible with having an initializer, 2124 // e.g. type constraints cannot be used with initializers. 
2125 for (ast::ConstraintRef constraint : constraints) { 2126 LogicalResult result = 2127 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2128 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2129 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2130 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2131 return this->emitError( 2132 constraint.referenceLoc, 2133 "type constraints are not permitted on variables with " 2134 "initializers"); 2135 } 2136 return success(); 2137 }) 2138 .Default(success()); 2139 if (failed(result)) 2140 return failure(); 2141 } 2142 } 2143 2144 FailureOr<ast::VariableDecl *> varDecl = 2145 createVariableDecl(varName, varLoc, initializer, constraints); 2146 if (failed(varDecl)) 2147 return failure(); 2148 return ast::LetStmt::create(ctx, loc, *varDecl); 2149 } 2150 2151 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2152 if (parserContext == ParserContext::Constraint) 2153 return emitError("`replace` cannot be used within a Constraint"); 2154 SMRange loc = curToken.getLoc(); 2155 consumeToken(Token::kw_replace); 2156 2157 // Parse the root operation expression. 2158 FailureOr<ast::Expr *> rootOp = parseExpr(); 2159 if (failed(rootOp)) 2160 return failure(); 2161 2162 if (failed( 2163 parseToken(Token::kw_with, "expected `with` after root operation"))) 2164 return failure(); 2165 2166 // The replacement portion of this statement is within a rewrite context. 2167 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2168 ParserContext::Rewrite); 2169 2170 // Parse the replacement values. 
2171 SmallVector<ast::Expr *> replValues; 2172 if (consumeIf(Token::l_paren)) { 2173 if (consumeIf(Token::r_paren)) { 2174 return emitError( 2175 loc, "expected at least one replacement value, consider using " 2176 "`erase` if no replacement values are desired"); 2177 } 2178 2179 do { 2180 FailureOr<ast::Expr *> replExpr = parseExpr(); 2181 if (failed(replExpr)) 2182 return failure(); 2183 replValues.emplace_back(*replExpr); 2184 } while (consumeIf(Token::comma)); 2185 2186 if (failed(parseToken(Token::r_paren, 2187 "expected `)` after replacement values"))) 2188 return failure(); 2189 } else { 2190 FailureOr<ast::Expr *> replExpr = parseExpr(); 2191 if (failed(replExpr)) 2192 return failure(); 2193 replValues.emplace_back(*replExpr); 2194 } 2195 2196 return createReplaceStmt(loc, *rootOp, replValues); 2197 } 2198 2199 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2200 SMRange loc = curToken.getLoc(); 2201 consumeToken(Token::kw_return); 2202 2203 // Parse the result value. 2204 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2205 if (failed(resultExpr)) 2206 return failure(); 2207 2208 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2209 } 2210 2211 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2212 if (parserContext == ParserContext::Constraint) 2213 return emitError("`rewrite` cannot be used within a Constraint"); 2214 SMRange loc = curToken.getLoc(); 2215 consumeToken(Token::kw_rewrite); 2216 2217 // Parse the root operation. 2218 FailureOr<ast::Expr *> rootOp = parseExpr(); 2219 if (failed(rootOp)) 2220 return failure(); 2221 2222 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2223 return failure(); 2224 2225 if (curToken.isNot(Token::l_brace)) 2226 return emitError("expected `{` to start rewrite body"); 2227 2228 // The rewrite body of this statement is within a rewrite context. 
2229 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2230 ParserContext::Rewrite); 2231 2232 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2233 if (failed(rewriteBody)) 2234 return failure(); 2235 2236 // Verify the rewrite body. 2237 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2238 if (isa<ast::ReturnStmt>(stmt)) { 2239 return emitError(stmt->getLoc(), 2240 "`return` statements are only permitted within a " 2241 "`Constraint` or `Rewrite` body"); 2242 } 2243 } 2244 2245 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2246 } 2247 2248 //===----------------------------------------------------------------------===// 2249 // Creation+Analysis 2250 //===----------------------------------------------------------------------===// 2251 2252 //===----------------------------------------------------------------------===// 2253 // Decls 2254 2255 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2256 // Unwrap reference expressions. 2257 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2258 node = init->getDecl(); 2259 return dyn_cast<ast::CallableDecl>(node); 2260 } 2261 2262 FailureOr<ast::PatternDecl *> 2263 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2264 const ParsedPatternMetadata &metadata, 2265 ast::CompoundStmt *body) { 2266 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2267 metadata.hasBoundedRecursion, body); 2268 } 2269 2270 ast::Type Parser::createUserConstraintRewriteResultType( 2271 ArrayRef<ast::VariableDecl *> results) { 2272 // Single result decls use the type of the single result. 2273 if (results.size() == 1) 2274 return results[0]->getType(); 2275 2276 // Multiple results use a tuple type, with the types and names grabbed from 2277 // the result variable decls. 
2278 auto resultTypes = llvm::map_range( 2279 results, [&](const auto *result) { return result->getType(); }); 2280 auto resultNames = llvm::map_range( 2281 results, [&](const auto *result) { return result->getName().getName(); }); 2282 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2283 llvm::to_vector(resultNames)); 2284 } 2285 2286 template <typename T> 2287 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2288 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2289 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2290 ast::CompoundStmt *body) { 2291 if (!body->getChildren().empty()) { 2292 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2293 ast::Expr *resultExpr = retStmt->getResultExpr(); 2294 2295 // Process the result of the decl. If no explicit signature results 2296 // were provided, check for return type inference. Otherwise, check that 2297 // the return expression can be converted to the expected type. 2298 if (results.empty()) 2299 resultType = resultExpr->getType(); 2300 else if (failed(convertExpressionTo(resultExpr, resultType))) 2301 return failure(); 2302 else 2303 retStmt->setResultExpr(resultExpr); 2304 } 2305 } 2306 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2307 } 2308 2309 FailureOr<ast::VariableDecl *> 2310 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2311 ArrayRef<ast::ConstraintRef> constraints) { 2312 // The type of the variable, which is expected to be inferred by either a 2313 // constraint or an initializer expression. 2314 ast::Type type; 2315 if (failed(validateVariableConstraints(constraints, type))) 2316 return failure(); 2317 2318 if (initializer) { 2319 // Update the variable type based on the initializer, or try to convert the 2320 // initializer to the existing type. 
2321 if (!type) 2322 type = initializer->getType(); 2323 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2324 type = mergedType; 2325 else if (failed(convertExpressionTo(initializer, type))) 2326 return failure(); 2327 2328 // Otherwise, if there is no initializer check that the type has already 2329 // been resolved from the constraint list. 2330 } else if (!type) { 2331 return emitErrorAndNote( 2332 loc, "unable to infer type for variable `" + name + "`", loc, 2333 "the type of a variable must be inferable from the constraint " 2334 "list or the initializer"); 2335 } 2336 2337 // Constraint types cannot be used when defining variables. 2338 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2339 return emitError( 2340 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2341 } 2342 2343 // Try to define a variable with the given name. 2344 FailureOr<ast::VariableDecl *> varDecl = 2345 defineVariableDecl(name, loc, type, initializer, constraints); 2346 if (failed(varDecl)) 2347 return failure(); 2348 2349 return *varDecl; 2350 } 2351 2352 FailureOr<ast::VariableDecl *> 2353 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2354 const ast::ConstraintRef &constraint) { 2355 // Constraint arguments may apply more complex constraints via the arguments. 
2356 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2357 ast::Type argType; 2358 if (failed(validateVariableConstraint(constraint, argType, 2359 allowNonCoreConstraints))) 2360 return failure(); 2361 return defineVariableDecl(name, loc, argType, constraint); 2362 } 2363 2364 LogicalResult 2365 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2366 ast::Type &inferredType) { 2367 for (const ast::ConstraintRef &ref : constraints) 2368 if (failed(validateVariableConstraint(ref, inferredType))) 2369 return failure(); 2370 return success(); 2371 } 2372 2373 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2374 ast::Type &inferredType, 2375 bool allowNonCoreConstraints) { 2376 ast::Type constraintType; 2377 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2378 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2379 if (failed(validateTypeConstraintExpr(typeExpr))) 2380 return failure(); 2381 } 2382 constraintType = ast::AttributeType::get(ctx); 2383 } else if (const auto *cst = 2384 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2385 constraintType = ast::OperationType::get(ctx, cst->getName()); 2386 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2387 constraintType = typeTy; 2388 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2389 constraintType = typeRangeTy; 2390 } else if (const auto *cst = 2391 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2392 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2393 if (failed(validateTypeConstraintExpr(typeExpr))) 2394 return failure(); 2395 } 2396 constraintType = valueTy; 2397 } else if (const auto *cst = 2398 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2399 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2400 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2401 return failure(); 2402 } 2403 constraintType = valueRangeTy; 2404 } else if (const 
auto *cst = 2405 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2406 if (!allowNonCoreConstraints) { 2407 return emitError(ref.referenceLoc, 2408 "`Rewrite` arguments and results are only permitted to " 2409 "use core constraints, such as `Attr`, `Op`, `Type`, " 2410 "`TypeRange`, `Value`, `ValueRange`"); 2411 } 2412 2413 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2414 if (inputs.size() != 1) { 2415 return emitErrorAndNote(ref.referenceLoc, 2416 "`Constraint`s applied via a variable constraint " 2417 "list must take a single input, but got " + 2418 Twine(inputs.size()), 2419 cst->getLoc(), 2420 "see definition of constraint here"); 2421 } 2422 constraintType = inputs.front()->getType(); 2423 } else { 2424 llvm_unreachable("unknown constraint type"); 2425 } 2426 2427 // Check that the constraint type is compatible with the current inferred 2428 // type. 2429 if (!inferredType) { 2430 inferredType = constraintType; 2431 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2432 inferredType = mergedTy; 2433 } else { 2434 return emitError(ref.referenceLoc, 2435 llvm::formatv("constraint type `{0}` is incompatible " 2436 "with the previously inferred type `{1}`", 2437 constraintType, inferredType)); 2438 } 2439 return success(); 2440 } 2441 2442 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2443 ast::Type typeExprType = typeExpr->getType(); 2444 if (typeExprType != typeTy) { 2445 return emitError(typeExpr->getLoc(), 2446 "expected expression of `Type` in type constraint"); 2447 } 2448 return success(); 2449 } 2450 2451 LogicalResult 2452 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2453 ast::Type typeExprType = typeExpr->getType(); 2454 if (typeExprType != typeRangeTy) { 2455 return emitError(typeExpr->getLoc(), 2456 "expected expression of `TypeRange` in type constraint"); 2457 } 2458 return success(); 2459 } 2460 2461 
//===----------------------------------------------------------------------===// 2462 // Exprs 2463 2464 FailureOr<ast::CallExpr *> 2465 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2466 MutableArrayRef<ast::Expr *> arguments) { 2467 ast::Type parentType = parentExpr->getType(); 2468 2469 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2470 if (!callableDecl) { 2471 return emitError(loc, 2472 llvm::formatv("expected a reference to a callable " 2473 "`Constraint` or `Rewrite`, but got: `{0}`", 2474 parentType)); 2475 } 2476 if (parserContext == ParserContext::Rewrite) { 2477 if (isa<ast::UserConstraintDecl>(callableDecl)) 2478 return emitError( 2479 loc, "unable to invoke `Constraint` within a rewrite section"); 2480 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2481 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2482 } 2483 2484 // Verify the arguments of the call. 2485 /// Handle size mismatch. 2486 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2487 if (callArgs.size() != arguments.size()) { 2488 return emitErrorAndNote( 2489 loc, 2490 llvm::formatv("invalid number of arguments for {0} call; expected " 2491 "{1}, but got {2}", 2492 callableDecl->getCallableType(), callArgs.size(), 2493 arguments.size()), 2494 callableDecl->getLoc(), 2495 llvm::formatv("see the definition of {0} here", 2496 callableDecl->getName()->getName())); 2497 } 2498 2499 /// Handle argument type mismatch. 
2500 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2501 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2502 callableDecl->getName()->getName()), 2503 callableDecl->getLoc()); 2504 }; 2505 for (auto it : llvm::zip(callArgs, arguments)) { 2506 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2507 attachDiagFn))) 2508 return failure(); 2509 } 2510 2511 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2512 callableDecl->getResultType()); 2513 } 2514 2515 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2516 ast::Decl *decl) { 2517 // Check the type of decl being referenced. 2518 ast::Type declType; 2519 if (isa<ast::ConstraintDecl>(decl)) 2520 declType = ast::ConstraintType::get(ctx); 2521 else if (isa<ast::UserRewriteDecl>(decl)) 2522 declType = ast::RewriteType::get(ctx); 2523 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2524 declType = varDecl->getType(); 2525 else 2526 return emitError(loc, "invalid reference to `" + 2527 decl->getName()->getName() + "`"); 2528 2529 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2530 } 2531 2532 FailureOr<ast::DeclRefExpr *> 2533 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2534 ArrayRef<ast::ConstraintRef> constraints) { 2535 FailureOr<ast::VariableDecl *> decl = 2536 defineVariableDecl(name, loc, type, constraints); 2537 if (failed(decl)) 2538 return failure(); 2539 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2540 } 2541 2542 FailureOr<ast::MemberAccessExpr *> 2543 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2544 SMRange loc) { 2545 // Validate the member name for the given parent expression. 
2546 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2547 if (failed(memberType)) 2548 return failure(); 2549 2550 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2551 } 2552 2553 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2554 StringRef name, SMRange loc) { 2555 ast::Type parentType = parentExpr->getType(); 2556 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2557 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2558 return valueRangeTy; 2559 2560 // Verify member access based on the operation type. 2561 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2562 auto results = odsOp->getResults(); 2563 2564 // Handle indexed results. 2565 unsigned index = 0; 2566 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2567 index < results.size()) { 2568 return results[index].isVariadic() ? valueRangeTy : valueTy; 2569 } 2570 2571 // Handle named results. 2572 const auto *it = llvm::find_if(results, [&](const auto &result) { 2573 return result.getName() == name; 2574 }); 2575 if (it != results.end()) 2576 return it->isVariadic() ? valueRangeTy : valueTy; 2577 } 2578 2579 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2580 // Handle indexed results. 2581 unsigned index = 0; 2582 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2583 index < tupleType.size()) { 2584 return tupleType.getElementTypes()[index]; 2585 } 2586 2587 // Handle named results. 
2588 auto elementNames = tupleType.getElementNames(); 2589 const auto *it = llvm::find(elementNames, name); 2590 if (it != elementNames.end()) 2591 return tupleType.getElementTypes()[it - elementNames.begin()]; 2592 } 2593 return emitError( 2594 loc, 2595 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2596 name, parentType)); 2597 } 2598 2599 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2600 SMRange loc, const ast::OpNameDecl *name, 2601 MutableArrayRef<ast::Expr *> operands, 2602 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2603 MutableArrayRef<ast::Expr *> results) { 2604 Optional<StringRef> opNameRef = name->getName(); 2605 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2606 2607 // Verify the inputs operands. 2608 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2609 return failure(); 2610 2611 // Verify the attribute list. 2612 for (ast::NamedAttributeDecl *attr : attributes) { 2613 // Check for an attribute type, or a type awaiting resolution. 2614 ast::Type attrType = attr->getValue()->getType(); 2615 if (!attrType.isa<ast::AttributeType>()) { 2616 return emitError( 2617 attr->getValue()->getLoc(), 2618 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2619 } 2620 } 2621 2622 // Verify the result types. 2623 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2624 return failure(); 2625 2626 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2627 attributes); 2628 } 2629 2630 LogicalResult 2631 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2632 const ods::Operation *odsOp, 2633 MutableArrayRef<ast::Expr *> operands) { 2634 return validateOperationOperandsOrResults( 2635 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2636 operands, odsOp ? 
odsOp->getOperands() : llvm::None, valueTy, 2637 valueRangeTy); 2638 } 2639 2640 LogicalResult 2641 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2642 const ods::Operation *odsOp, 2643 MutableArrayRef<ast::Expr *> results) { 2644 return validateOperationOperandsOrResults( 2645 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2646 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2647 } 2648 2649 LogicalResult Parser::validateOperationOperandsOrResults( 2650 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2651 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2652 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2653 ast::Type rangeTy) { 2654 // All operation types accept a single range parameter. 2655 if (values.size() == 1) { 2656 if (failed(convertExpressionTo(values[0], rangeTy))) 2657 return failure(); 2658 return success(); 2659 } 2660 2661 /// If the operation has ODS information, we can more accurately verify the 2662 /// values. 2663 if (odsOpLoc) { 2664 if (odsValues.size() != values.size()) { 2665 return emitErrorAndNote( 2666 loc, 2667 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2668 "{2}, but got {3}", 2669 groupName, *name, odsValues.size(), values.size()), 2670 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2671 } 2672 auto diagFn = [&](ast::Diagnostic &diag) { 2673 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2674 *odsOpLoc); 2675 }; 2676 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2677 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2678 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2679 return failure(); 2680 } 2681 return success(); 2682 } 2683 2684 // Otherwise, accept the value groups as they have been defined and just 2685 // ensure they are one of the expected types. 
2686 for (ast::Expr *&valueExpr : values) { 2687 ast::Type valueExprType = valueExpr->getType(); 2688 2689 // Check if this is one of the expected types. 2690 if (valueExprType == rangeTy || valueExprType == singleTy) 2691 continue; 2692 2693 // If the operand is an Operation, allow converting to a Value or 2694 // ValueRange. This situations arises quite often with nested operation 2695 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2696 if (singleTy == valueTy) { 2697 if (valueExprType.isa<ast::OperationType>()) { 2698 valueExpr = convertOpToValue(valueExpr); 2699 continue; 2700 } 2701 } 2702 2703 return emitError( 2704 valueExpr->getLoc(), 2705 llvm::formatv( 2706 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2707 singleTy, rangeTy, valueExprType)); 2708 } 2709 return success(); 2710 } 2711 2712 FailureOr<ast::TupleExpr *> 2713 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2714 ArrayRef<StringRef> elementNames) { 2715 for (const ast::Expr *element : elements) { 2716 ast::Type eleTy = element->getType(); 2717 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2718 return emitError( 2719 element->getLoc(), 2720 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2721 } 2722 } 2723 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2724 } 2725 2726 //===----------------------------------------------------------------------===// 2727 // Stmts 2728 2729 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2730 ast::Expr *rootOp) { 2731 // Check that root is an Operation. 
2732 ast::Type rootType = rootOp->getType(); 2733 if (!rootType.isa<ast::OperationType>()) 2734 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2735 2736 return ast::EraseStmt::create(ctx, loc, rootOp); 2737 } 2738 2739 FailureOr<ast::ReplaceStmt *> 2740 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2741 MutableArrayRef<ast::Expr *> replValues) { 2742 // Check that root is an Operation. 2743 ast::Type rootType = rootOp->getType(); 2744 if (!rootType.isa<ast::OperationType>()) { 2745 return emitError( 2746 rootOp->getLoc(), 2747 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2748 } 2749 2750 // If there are multiple replacement values, we implicitly convert any Op 2751 // expressions to the value form. 2752 bool shouldConvertOpToValues = replValues.size() > 1; 2753 for (ast::Expr *&replExpr : replValues) { 2754 ast::Type replType = replExpr->getType(); 2755 2756 // Check that replExpr is an Operation, Value, or ValueRange. 2757 if (replType.isa<ast::OperationType>()) { 2758 if (shouldConvertOpToValues) 2759 replExpr = convertOpToValue(replExpr); 2760 continue; 2761 } 2762 2763 if (replType != valueTy && replType != valueRangeTy) { 2764 return emitError(replExpr->getLoc(), 2765 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2766 "expression, but got `{0}`", 2767 replType)); 2768 } 2769 } 2770 2771 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2772 } 2773 2774 FailureOr<ast::RewriteStmt *> 2775 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2776 ast::CompoundStmt *rewriteBody) { 2777 // Check that root is an Operation. 
2778 ast::Type rootType = rootOp->getType(); 2779 if (!rootType.isa<ast::OperationType>()) { 2780 return emitError( 2781 rootOp->getLoc(), 2782 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2783 } 2784 2785 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2786 } 2787 2788 //===----------------------------------------------------------------------===// 2789 // Parser 2790 //===----------------------------------------------------------------------===// 2791 2792 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 2793 llvm::SourceMgr &sourceMgr) { 2794 Parser parser(ctx, sourceMgr); 2795 return parser.parseModule(); 2796 } 2797