1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 /// The current specification context of an operations result type. This 80 /// indicates how the result types of an operation may be inferred. 81 enum class OpResultTypeContext { 82 /// The result types of the operation are not known to be inferred. 83 Explicit, 84 /// The result types of the operation are inferred from the root input of a 85 /// `replace` statement. 86 Replacement, 87 /// The result types of the operation are inferred by using the 88 /// `InferTypeOpInterface` interface provided by the operation. 89 Interface, 90 }; 91 92 //===--------------------------------------------------------------------===// 93 // Parsing 94 //===--------------------------------------------------------------------===// 95 96 /// Push a new decl scope onto the lexer. 97 ast::DeclScope *pushDeclScope() { 98 ast::DeclScope *newScope = 99 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 100 return (curDeclScope = newScope); 101 } 102 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 103 104 /// Pop the last decl scope from the lexer. 105 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 106 107 /// Parse the body of an AST module. 108 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 109 110 /// Try to convert the given expression to `type`. Returns failure and emits 111 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 112 /// invoked to attach notes to the emitted error diagnostic. On success, 113 /// `expr` is updated to the expression used to convert to `type`. 114 LogicalResult convertExpressionTo( 115 ast::Expr *&expr, ast::Type type, 116 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 117 118 /// Given an operation expression, convert it to a Value or ValueRange 119 /// typed expression. 120 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 121 122 /// Lookup ODS information for the given operation, returns nullptr if no 123 /// information is found. 124 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 125 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 126 } 127 128 //===--------------------------------------------------------------------===// 129 // Directives 130 131 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 132 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 133 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 134 SmallVectorImpl<ast::Decl *> &decls); 135 136 /// Process the records of a parsed tablegen include file. 137 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 138 SmallVectorImpl<ast::Decl *> &decls); 139 140 /// Create a user defined native constraint for a constraint imported from 141 /// ODS. 142 template <typename ConstraintT> 143 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 144 StringRef codeBlock, SMRange loc, 145 ast::Type type, 146 StringRef nativeType); 147 template <typename ConstraintT> 148 ast::Decl * 149 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 150 SMRange loc, ast::Type type, 151 StringRef nativeType); 152 153 //===--------------------------------------------------------------------===// 154 // Decls 155 156 /// This structure contains the set of pattern metadata that may be parsed. 157 struct ParsedPatternMetadata { 158 Optional<uint16_t> benefit; 159 bool hasBoundedRecursion = false; 160 }; 161 162 FailureOr<ast::Decl *> parseTopLevelDecl(); 163 FailureOr<ast::NamedAttributeDecl *> 164 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 165 166 /// Parse an argument variable as part of the signature of a 167 /// UserConstraintDecl or UserRewriteDecl. 168 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 169 170 /// Parse a result variable as part of the signature of a UserConstraintDecl 171 /// or UserRewriteDecl. 172 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 173 174 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 175 /// defined in a non-global context. 176 FailureOr<ast::UserConstraintDecl *> 177 parseUserConstraintDecl(bool isInline = false); 178 179 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Constraint/etc. 181 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 191 /// defined in a non-global context. 192 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 193 194 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 195 /// non-global context, such as within a Pattern/Rewrite/etc. 196 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 197 198 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 199 /// PDLL constructs. 200 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 201 const ast::Name &name, bool isInline, 202 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 203 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 204 205 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 206 /// effectively the same syntax, and only differ on slight semantics (given 207 /// the different parsing contexts). 208 template <typename T, typename ParseUserPDLLDeclFnT> 209 FailureOr<T *> parseUserConstraintOrRewriteDecl( 210 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 211 StringRef anonymousNamePrefix, bool isInline); 212 213 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 214 /// These decls have effectively the same syntax. 215 template <typename T> 216 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 217 const ast::Name &name, bool isInline, 218 ArrayRef<ast::VariableDecl *> arguments, 219 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 220 221 /// Parse the functional signature (i.e. the arguments and results) of a 222 /// UserConstraintDecl or UserRewriteDecl. 223 LogicalResult parseUserConstraintOrRewriteSignature( 224 SmallVectorImpl<ast::VariableDecl *> &arguments, 225 SmallVectorImpl<ast::VariableDecl *> &results, 226 ast::DeclScope *&argumentScope, ast::Type &resultType); 227 228 /// Validate the return (which if present is specified by bodyIt) of a 229 /// UserConstraintDecl or UserRewriteDecl. 230 LogicalResult validateUserConstraintOrRewriteReturn( 231 StringRef declType, ast::CompoundStmt *body, 232 ArrayRef<ast::Stmt *>::iterator bodyIt, 233 ArrayRef<ast::Stmt *>::iterator bodyE, 234 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 235 236 FailureOr<ast::CompoundStmt *> 237 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 238 bool expectTerminalSemicolon = true); 239 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 240 FailureOr<ast::Decl *> parsePatternDecl(); 241 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 242 243 /// Check to see if a decl has already been defined with the given name, if 244 /// one has emit and error and return failure. Returns success otherwise. 245 LogicalResult checkDefineNamedDecl(const ast::Name &name); 246 247 /// Try to define a variable decl with the given components, returns the 248 /// variable on success. 249 FailureOr<ast::VariableDecl *> 250 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 251 ast::Expr *initExpr, 252 ArrayRef<ast::ConstraintRef> constraints); 253 FailureOr<ast::VariableDecl *> 254 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 255 ArrayRef<ast::ConstraintRef> constraints); 256 257 /// Parse the constraint reference list for a variable decl. 258 LogicalResult parseVariableDeclConstraintList( 259 SmallVectorImpl<ast::ConstraintRef> &constraints); 260 261 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 262 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 263 264 /// Try to parse a single reference to a constraint. `typeConstraint` is the 265 /// location of a previously parsed type constraint for the entity that will 266 /// be constrained by the parsed constraint. `existingConstraints` are any 267 /// existing constraints that have already been parsed for the same entity 268 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 269 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 270 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 271 /// constraints) may be used with the variable. 272 FailureOr<ast::ConstraintRef> 273 parseConstraint(Optional<SMRange> &typeConstraint, 274 ArrayRef<ast::ConstraintRef> existingConstraints, 275 bool allowInlineTypeConstraints, 276 bool allowNonCoreConstraints); 277 278 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 279 /// argument or result variable. The constraints for these variables do not 280 /// allow inline type constraints, and only permit a single constraint. 281 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 282 283 //===--------------------------------------------------------------------===// 284 // Exprs 285 286 FailureOr<ast::Expr *> parseExpr(); 287 288 /// Identifier expressions. 289 FailureOr<ast::Expr *> parseAttributeExpr(); 290 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 291 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 292 FailureOr<ast::Expr *> parseIdentifierExpr(); 293 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 294 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 295 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 296 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 297 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 298 FailureOr<ast::Expr *> 299 parseOperationExpr(OpResultTypeContext inputResultTypeContext = 300 OpResultTypeContext::Explicit); 301 FailureOr<ast::Expr *> parseTupleExpr(); 302 FailureOr<ast::Expr *> parseTypeExpr(); 303 FailureOr<ast::Expr *> parseUnderscoreExpr(); 304 305 //===--------------------------------------------------------------------===// 306 // Stmts 307 308 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 309 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 310 FailureOr<ast::EraseStmt *> parseEraseStmt(); 311 FailureOr<ast::LetStmt *> parseLetStmt(); 312 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 313 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 314 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 315 316 //===--------------------------------------------------------------------===// 317 // Creation+Analysis 318 //===--------------------------------------------------------------------===// 319 320 //===--------------------------------------------------------------------===// 321 // Decls 322 323 /// Try to extract a callable from the given AST node. Returns nullptr on 324 /// failure. 325 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 326 327 /// Try to create a pattern decl with the given components, returning the 328 /// Pattern on success. 329 FailureOr<ast::PatternDecl *> 330 createPatternDecl(SMRange loc, const ast::Name *name, 331 const ParsedPatternMetadata &metadata, 332 ast::CompoundStmt *body); 333 334 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 335 /// of results, defined as part of the signature. 336 ast::Type 337 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 338 339 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 340 template <typename T> 341 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 342 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 343 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 344 ast::CompoundStmt *body); 345 346 /// Try to create a variable decl with the given components, returning the 347 /// Variable on success. 348 FailureOr<ast::VariableDecl *> 349 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 350 ArrayRef<ast::ConstraintRef> constraints); 351 352 /// Create a variable for an argument or result defined as part of the 353 /// signature of a UserConstraintDecl/UserRewriteDecl. 354 FailureOr<ast::VariableDecl *> 355 createArgOrResultVariableDecl(StringRef name, SMRange loc, 356 const ast::ConstraintRef &constraint); 357 358 /// Validate the constraints used to constraint a variable decl. 359 /// `inferredType` is the type of the variable inferred by the constraints 360 /// within the list, and is updated to the most refined type as determined by 361 /// the constraints. Returns success if the constraint list is valid, failure 362 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 363 /// defined constraints) may be used with the variable. 364 LogicalResult 365 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 366 ast::Type &inferredType, 367 bool allowNonCoreConstraints = true); 368 /// Validate a single reference to a constraint. `inferredType` contains the 369 /// currently inferred variabled type and is refined within the type defined 370 /// by the constraint. Returns success if the constraint is valid, failure 371 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 372 /// defined constraints) may be used with the variable. 373 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 374 ast::Type &inferredType, 375 bool allowNonCoreConstraints = true); 376 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 377 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 378 379 //===--------------------------------------------------------------------===// 380 // Exprs 381 382 FailureOr<ast::CallExpr *> 383 createCallExpr(SMRange loc, ast::Expr *parentExpr, 384 MutableArrayRef<ast::Expr *> arguments); 385 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 386 FailureOr<ast::DeclRefExpr *> 387 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 388 ArrayRef<ast::ConstraintRef> constraints); 389 FailureOr<ast::MemberAccessExpr *> 390 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 391 392 /// Validate the member access `name` into the given parent expression. On 393 /// success, this also returns the type of the member accessed. 394 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 395 StringRef name, SMRange loc); 396 FailureOr<ast::OperationExpr *> 397 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 398 OpResultTypeContext resultTypeContext, 399 MutableArrayRef<ast::Expr *> operands, 400 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 401 MutableArrayRef<ast::Expr *> results); 402 LogicalResult 403 validateOperationOperands(SMRange loc, Optional<StringRef> name, 404 const ods::Operation *odsOp, 405 MutableArrayRef<ast::Expr *> operands); 406 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 407 const ods::Operation *odsOp, 408 MutableArrayRef<ast::Expr *> results); 409 void checkOperationResultTypeInferrence(SMRange loc, StringRef name, 410 const ods::Operation *odsOp); 411 LogicalResult validateOperationOperandsOrResults( 412 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 413 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 414 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 415 ast::Type rangeTy); 416 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 417 ArrayRef<ast::Expr *> elements, 418 ArrayRef<StringRef> elementNames); 419 420 //===--------------------------------------------------------------------===// 421 // Stmts 422 423 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 424 FailureOr<ast::ReplaceStmt *> 425 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 426 MutableArrayRef<ast::Expr *> replValues); 427 FailureOr<ast::RewriteStmt *> 428 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 429 ast::CompoundStmt *rewriteBody); 430 431 //===--------------------------------------------------------------------===// 432 // Code Completion 433 //===--------------------------------------------------------------------===// 434 435 /// The set of various code completion methods. Every completion method 436 /// returns `failure` to stop the parsing process after providing completion 437 /// results. 438 439 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 440 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 441 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 442 bool allowNonCoreConstraints, 443 bool allowInlineTypeConstraints); 444 LogicalResult codeCompleteDialectName(); 445 LogicalResult codeCompleteOperationName(StringRef dialectName); 446 LogicalResult codeCompletePatternMetadata(); 447 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 448 449 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 450 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 451 unsigned currentNumOperands); 452 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 453 unsigned currentNumResults); 454 455 //===--------------------------------------------------------------------===// 456 // Lexer Utilities 457 //===--------------------------------------------------------------------===// 458 459 /// If the current token has the specified kind, consume it and return true. 460 /// If not, return false. 461 bool consumeIf(Token::Kind kind) { 462 if (curToken.isNot(kind)) 463 return false; 464 consumeToken(kind); 465 return true; 466 } 467 468 /// Advance the current lexer onto the next token. 469 void consumeToken() { 470 assert(curToken.isNot(Token::eof, Token::error) && 471 "shouldn't advance past EOF or errors"); 472 curToken = lexer.lexToken(); 473 } 474 475 /// Advance the current lexer onto the next token, asserting what the expected 476 /// current token is. This is preferred to the above method because it leads 477 /// to more self-documenting code with better checking. 478 void consumeToken(Token::Kind kind) { 479 assert(curToken.is(kind) && "consumed an unexpected token"); 480 consumeToken(); 481 } 482 483 /// Reset the lexer to the location at the given position. 484 void resetToken(SMRange tokLoc) { 485 lexer.resetPointer(tokLoc.Start.getPointer()); 486 curToken = lexer.lexToken(); 487 } 488 489 /// Consume the specified token if present and return success. On failure, 490 /// output a diagnostic and return failure. 491 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 492 if (curToken.getKind() != kind) 493 return emitError(curToken.getLoc(), msg); 494 consumeToken(); 495 return success(); 496 } 497 LogicalResult emitError(SMRange loc, const Twine &msg) { 498 lexer.emitError(loc, msg); 499 return failure(); 500 } 501 LogicalResult emitError(const Twine &msg) { 502 return emitError(curToken.getLoc(), msg); 503 } 504 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 505 const Twine ¬e) { 506 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 507 return failure(); 508 } 509 510 //===--------------------------------------------------------------------===// 511 // Fields 512 //===--------------------------------------------------------------------===// 513 514 /// The owning AST context. 515 ast::Context &ctx; 516 517 /// The lexer of this parser. 518 Lexer lexer; 519 520 /// The current token within the lexer. 521 Token curToken; 522 523 /// The most recently defined decl scope. 524 ast::DeclScope *curDeclScope = nullptr; 525 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 526 527 /// The current context of the parser. 528 ParserContext parserContext = ParserContext::Global; 529 530 /// Cached types to simplify verification and expression creation. 531 ast::Type valueTy, valueRangeTy; 532 ast::Type typeTy, typeRangeTy; 533 ast::Type attrTy; 534 535 /// A counter used when naming anonymous constraints and rewrites. 536 unsigned anonymousDeclNameCounter = 0; 537 538 /// The optional code completion context. 539 CodeCompleteContext *codeCompleteContext; 540 }; 541 } // namespace 542 543 FailureOr<ast::Module *> Parser::parseModule() { 544 SMLoc moduleLoc = curToken.getStartLoc(); 545 pushDeclScope(); 546 547 // Parse the top-level decls of the module. 548 SmallVector<ast::Decl *> decls; 549 if (failed(parseModuleBody(decls))) 550 return popDeclScope(), failure(); 551 552 popDeclScope(); 553 return ast::Module::create(ctx, moduleLoc, decls); 554 } 555 556 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 557 while (curToken.isNot(Token::eof)) { 558 if (curToken.is(Token::directive)) { 559 if (failed(parseDirective(decls))) 560 return failure(); 561 continue; 562 } 563 564 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 565 if (failed(decl)) 566 return failure(); 567 decls.push_back(*decl); 568 } 569 return success(); 570 } 571 572 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 573 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 574 valueRangeTy); 575 } 576 577 LogicalResult Parser::convertExpressionTo( 578 ast::Expr *&expr, ast::Type type, 579 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 580 ast::Type exprType = expr->getType(); 581 if (exprType == type) 582 return success(); 583 584 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 585 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 586 expr->getLoc(), llvm::formatv("unable to convert expression of type " 587 "`{0}` to the expected type of " 588 "`{1}`", 589 exprType, type)); 590 if (noteAttachFn) 591 noteAttachFn(*diag); 592 return diag; 593 }; 594 595 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 596 // Two operation types are compatible if they have the same name, or if the 597 // expected type is more general. 598 if (auto opType = type.dyn_cast<ast::OperationType>()) { 599 if (opType.getName()) 600 return emitConvertError(); 601 return success(); 602 } 603 604 // An operation can always convert to a ValueRange. 605 if (type == valueRangeTy) { 606 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 607 valueRangeTy); 608 return success(); 609 } 610 611 // Allow conversion to a single value by constraining the result range. 612 if (type == valueTy) { 613 // If the operation is registered, we can verify if it can ever have a 614 // single result. 615 if (const ods::Operation *odsOp = exprOpType.getODSOperation()) { 616 if (odsOp->getResults().empty()) { 617 return emitConvertError()->attachNote( 618 llvm::formatv("see the definition of `{0}`, which was defined " 619 "with zero results", 620 odsOp->getName()), 621 odsOp->getLoc()); 622 } 623 624 unsigned numSingleResults = llvm::count_if( 625 odsOp->getResults(), [](const ods::OperandOrResult &result) { 626 return result.getVariableLengthKind() == 627 ods::VariableLengthKind::Single; 628 }); 629 if (numSingleResults > 1) { 630 return emitConvertError()->attachNote( 631 llvm::formatv("see the definition of `{0}`, which was defined " 632 "with at least {1} results", 633 odsOp->getName(), numSingleResults), 634 odsOp->getLoc()); 635 } 636 } 637 638 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 639 valueTy); 640 return success(); 641 } 642 return emitConvertError(); 643 } 644 645 // FIXME: Decide how to allow/support converting a single result to multiple, 646 // and multiple to a single result. For now, we just allow Single->Range, 647 // but this isn't something really supported in the PDL dialect. We should 648 // figure out some way to support both. 649 if ((exprType == valueTy || exprType == valueRangeTy) && 650 (type == valueTy || type == valueRangeTy)) 651 return success(); 652 if ((exprType == typeTy || exprType == typeRangeTy) && 653 (type == typeTy || type == typeRangeTy)) 654 return success(); 655 656 // Handle tuple types. 657 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 658 auto tupleType = type.dyn_cast<ast::TupleType>(); 659 if (!tupleType || tupleType.size() != exprTupleType.size()) 660 return emitConvertError(); 661 662 // Build a new tuple expression using each of the elements of the current 663 // tuple. 664 SmallVector<ast::Expr *> newExprs; 665 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 666 newExprs.push_back(ast::MemberAccessExpr::create( 667 ctx, expr->getLoc(), expr, llvm::to_string(i), 668 exprTupleType.getElementTypes()[i])); 669 670 auto diagFn = [&](ast::Diagnostic &diag) { 671 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 672 i, exprTupleType)); 673 if (noteAttachFn) 674 noteAttachFn(diag); 675 }; 676 if (failed(convertExpressionTo(newExprs.back(), 677 tupleType.getElementTypes()[i], diagFn))) 678 return failure(); 679 } 680 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 681 tupleType.getElementNames()); 682 return success(); 683 } 684 685 return emitConvertError(); 686 } 687 688 //===----------------------------------------------------------------------===// 689 // Directives 690 691 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 692 StringRef directive = curToken.getSpelling(); 693 if (directive == "#include") 694 return parseInclude(decls); 695 696 return emitError("unknown directive `" + directive + "`"); 697 } 698 699 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 700 SMRange loc = curToken.getLoc(); 701 consumeToken(Token::directive); 702 703 // Handle code completion of the include file path. 704 if (curToken.is(Token::code_complete_string)) 705 return codeCompleteIncludeFilename(curToken.getStringValue()); 706 707 // Parse the file being included. 708 if (!curToken.isString()) 709 return emitError(loc, 710 "expected string file name after `include` directive"); 711 SMRange fileLoc = curToken.getLoc(); 712 std::string filenameStr = curToken.getStringValue(); 713 StringRef filename = filenameStr; 714 consumeToken(); 715 716 // Check the type of include. If ending with `.pdll`, this is another pdl file 717 // to be parsed along with the current module. 718 if (filename.endswith(".pdll")) { 719 if (failed(lexer.pushInclude(filename, fileLoc))) 720 return emitError(fileLoc, 721 "unable to open include file `" + filename + "`"); 722 723 // If we added the include successfully, parse it into the current module. 724 // Make sure to update to the next token after we finish parsing the nested 725 // file. 726 curToken = lexer.lexToken(); 727 LogicalResult result = parseModuleBody(decls); 728 curToken = lexer.lexToken(); 729 return result; 730 } 731 732 // Otherwise, this must be a `.td` include. 733 if (filename.endswith(".td")) 734 return parseTdInclude(filename, fileLoc, decls); 735 736 return emitError(fileLoc, 737 "expected include filename to end with `.pdll` or `.td`"); 738 } 739 740 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 741 SmallVectorImpl<ast::Decl *> &decls) { 742 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 743 744 // Use the source manager to open the file, but don't yet add it. 745 std::string includedFile; 746 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 747 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 748 if (!includeBuffer) 749 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 750 751 // Setup the source manager for parsing the tablegen file. 752 llvm::SourceMgr tdSrcMgr; 753 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); 754 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); 755 756 // This class provides a context argument for the llvm::SourceMgr diagnostic 757 // handler. 758 struct DiagHandlerContext { 759 Parser &parser; 760 StringRef filename; 761 llvm::SMRange loc; 762 } handlerContext{*this, filename, fileLoc}; 763 764 // Set the diagnostic handler for the tablegen source manager. 765 tdSrcMgr.setDiagHandler( 766 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 767 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 768 (void)ctx->parser.emitError( 769 ctx->loc, 770 llvm::formatv("error while processing include file `{0}`: {1}", 771 ctx->filename, diag.getMessage())); 772 }, 773 &handlerContext); 774 775 // Parse the tablegen file. 776 llvm::RecordKeeper tdRecords; 777 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords)) 778 return failure(); 779 780 // Process the parsed records. 781 processTdIncludeRecords(tdRecords, decls); 782 783 // After we are done processing, move all of the tablegen source buffers to 784 // the main parser source mgr. This allows for directly using source locations 785 // from the .td files without needing to remap them. 786 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End); 787 return success(); 788 } 789 790 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 791 SmallVectorImpl<ast::Decl *> &decls) { 792 // Return the length kind of the given value. 793 auto getLengthKind = [](const auto &value) { 794 if (value.isOptional()) 795 return ods::VariableLengthKind::Optional; 796 return value.isVariadic() ? ods::VariableLengthKind::Variadic 797 : ods::VariableLengthKind::Single; 798 }; 799 800 // Insert a type constraint into the ODS context. 801 ods::Context &odsContext = ctx.getODSContext(); 802 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 803 -> const ods::TypeConstraint & { 804 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 805 cst.constraint.getSummary(), 806 cst.constraint.getCPPClassName()); 807 }; 808 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 809 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 810 }; 811 812 // Process the parsed tablegen records to build ODS information. 813 /// Operations. 814 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 815 tblgen::Operator op(def); 816 817 // Check to see if this operation is known to support type inferrence. 818 bool supportsResultTypeInferrence = 819 op.getTrait("::mlir::InferTypeOpInterface::Trait"); 820 821 bool inserted = false; 822 ods::Operation *odsOp = nullptr; 823 std::tie(odsOp, inserted) = odsContext.insertOperation( 824 op.getOperationName(), op.getSummary(), op.getDescription(), 825 op.getQualCppClassName(), supportsResultTypeInferrence, 826 op.getLoc().front()); 827 828 // Ignore operations that have already been added. 829 if (!inserted) 830 continue; 831 832 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 833 odsOp->appendAttribute( 834 attr.name, attr.attr.isOptional(), 835 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 836 attr.attr.getSummary(), 837 attr.attr.getStorageType())); 838 } 839 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 840 odsOp->appendOperand(operand.name, getLengthKind(operand), 841 addTypeConstraint(operand)); 842 } 843 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 844 odsOp->appendResult(result.name, getLengthKind(result), 845 addTypeConstraint(result)); 846 } 847 } 848 /// Attr constraints. 849 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 850 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 851 tblgen::Attribute constraint(def); 852 decls.push_back( 853 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 854 constraint, convertLocToRange(def->getLoc().front()), attrTy, 855 constraint.getStorageType())); 856 } 857 } 858 /// Type constraints. 859 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 860 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 861 tblgen::TypeConstraint constraint(def); 862 decls.push_back( 863 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 864 constraint, convertLocToRange(def->getLoc().front()), typeTy, 865 constraint.getCPPClassName())); 866 } 867 } 868 /// Interfaces. 869 ast::Type opTy = ast::OperationType::get(ctx); 870 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 871 StringRef name = def->getName(); 872 if (def->isAnonymous() || curDeclScope->lookup(name) || 873 def->isSubClassOf("DeclareInterfaceMethods")) 874 continue; 875 SMRange loc = convertLocToRange(def->getLoc().front()); 876 877 std::string cppClassName = 878 llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"), 879 def->getValueAsString("cppClassName")) 880 .str(); 881 std::string codeBlock = 882 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));", 883 cppClassName) 884 .str(); 885 886 if (def->isSubClassOf("OpInterface")) { 887 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 888 name, codeBlock, loc, opTy, cppClassName)); 889 } else if (def->isSubClassOf("AttrInterface")) { 890 decls.push_back( 891 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 892 name, codeBlock, loc, attrTy, cppClassName)); 893 } else if (def->isSubClassOf("TypeInterface")) { 894 decls.push_back( 895 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 896 name, codeBlock, loc, typeTy, cppClassName)); 897 } 898 } 899 } 900 901 template <typename ConstraintT> 902 ast::Decl * 903 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 904 SMRange loc, ast::Type type, 905 StringRef nativeType) { 906 // Build the single input parameter. 907 ast::DeclScope *argScope = pushDeclScope(); 908 auto *paramVar = ast::VariableDecl::create( 909 ctx, ast::Name::create(ctx, "self", loc), type, 910 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 911 argScope->add(paramVar); 912 popDeclScope(); 913 914 // Build the native constraint. 915 auto *constraintDecl = ast::UserConstraintDecl::createNative( 916 ctx, ast::Name::create(ctx, name, loc), paramVar, 917 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType); 918 curDeclScope->add(constraintDecl); 919 return constraintDecl; 920 } 921 922 template <typename ConstraintT> 923 ast::Decl * 924 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 925 SMRange loc, ast::Type type, 926 StringRef nativeType) { 927 // Format the condition template. 928 tblgen::FmtContext fmtContext; 929 fmtContext.withSelf("self"); 930 std::string codeBlock = tblgen::tgfmt( 931 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 932 &fmtContext); 933 934 return createODSNativePDLLConstraintDecl<ConstraintT>( 935 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType); 936 } 937 938 //===----------------------------------------------------------------------===// 939 // Decls 940 941 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 942 FailureOr<ast::Decl *> decl; 943 switch (curToken.getKind()) { 944 case Token::kw_Constraint: 945 decl = parseUserConstraintDecl(); 946 break; 947 case Token::kw_Pattern: 948 decl = parsePatternDecl(); 949 break; 950 case Token::kw_Rewrite: 951 decl = parseUserRewriteDecl(); 952 break; 953 default: 954 return emitError("expected top-level declaration, such as a `Pattern`"); 955 } 956 if (failed(decl)) 957 return failure(); 958 959 // If the decl has a name, add it to the current scope. 960 if (const ast::Name *name = (*decl)->getName()) { 961 if (failed(checkDefineNamedDecl(*name))) 962 return failure(); 963 curDeclScope->add(*decl); 964 } 965 return decl; 966 } 967 968 FailureOr<ast::NamedAttributeDecl *> 969 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 970 // Check for name code completion. 971 if (curToken.is(Token::code_complete)) 972 return codeCompleteAttributeName(parentOpName); 973 974 std::string attrNameStr; 975 if (curToken.isString()) 976 attrNameStr = curToken.getStringValue(); 977 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 978 attrNameStr = curToken.getSpelling().str(); 979 else 980 return emitError("expected identifier or string attribute name"); 981 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 982 consumeToken(); 983 984 // Check for a value of the attribute. 985 ast::Expr *attrValue = nullptr; 986 if (consumeIf(Token::equal)) { 987 FailureOr<ast::Expr *> attrExpr = parseExpr(); 988 if (failed(attrExpr)) 989 return failure(); 990 attrValue = *attrExpr; 991 } else { 992 // If there isn't a concrete value, create an expression representing a 993 // UnitAttr. 994 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 995 } 996 997 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 998 } 999 1000 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 1001 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 1002 bool expectTerminalSemicolon) { 1003 consumeToken(Token::equal_arrow); 1004 1005 // Parse the single statement of the lambda body. 1006 SMLoc bodyStartLoc = curToken.getStartLoc(); 1007 pushDeclScope(); 1008 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 1009 bool failedToParse = 1010 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 1011 popDeclScope(); 1012 if (failedToParse) 1013 return failure(); 1014 1015 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 1016 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 1017 } 1018 1019 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 1020 // Ensure that the argument is named. 1021 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 1022 return emitError("expected identifier argument name"); 1023 1024 // Parse the argument similarly to a normal variable. 1025 StringRef name = curToken.getSpelling(); 1026 SMRange nameLoc = curToken.getLoc(); 1027 consumeToken(); 1028 1029 if (failed( 1030 parseToken(Token::colon, "expected `:` before argument constraint"))) 1031 return failure(); 1032 1033 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1034 if (failed(cst)) 1035 return failure(); 1036 1037 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1038 } 1039 1040 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1041 // Check to see if this result is named. 1042 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1043 // Check to see if this name actually refers to a Constraint. 1044 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1045 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1046 // If yes, and this is a Rewrite, give a nice error message as non-Core 1047 // constraints are not supported on Rewrite results. 1048 if (parserContext == ParserContext::Rewrite) { 1049 return emitError( 1050 "`Rewrite` results are only permitted to use core constraints, " 1051 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1052 } 1053 1054 // Otherwise, parse this as an unnamed result variable. 1055 } else { 1056 // If it wasn't a constraint, parse the result similarly to a variable. If 1057 // there is already an existing decl, we will emit an error when defining 1058 // this variable later. 1059 StringRef name = curToken.getSpelling(); 1060 SMRange nameLoc = curToken.getLoc(); 1061 consumeToken(); 1062 1063 if (failed(parseToken(Token::colon, 1064 "expected `:` before result constraint"))) 1065 return failure(); 1066 1067 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1068 if (failed(cst)) 1069 return failure(); 1070 1071 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1072 } 1073 } 1074 1075 // If it isn't named, we parse the constraint directly and create an unnamed 1076 // result variable. 1077 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1078 if (failed(cst)) 1079 return failure(); 1080 1081 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1082 } 1083 1084 FailureOr<ast::UserConstraintDecl *> 1085 Parser::parseUserConstraintDecl(bool isInline) { 1086 // Constraints and rewrites have very similar formats, dispatch to a shared 1087 // interface for parsing. 1088 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1089 [&](auto &&...args) { 1090 return this->parseUserPDLLConstraintDecl(args...); 1091 }, 1092 ParserContext::Constraint, "constraint", isInline); 1093 } 1094 1095 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1096 FailureOr<ast::UserConstraintDecl *> decl = 1097 parseUserConstraintDecl(/*isInline=*/true); 1098 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1099 return failure(); 1100 1101 curDeclScope->add(*decl); 1102 return decl; 1103 } 1104 1105 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1106 const ast::Name &name, bool isInline, 1107 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1108 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1109 // Push the argument scope back onto the list, so that the body can 1110 // reference arguments. 1111 pushDeclScope(argumentScope); 1112 1113 // Parse the body of the constraint. The body is either defined as a compound 1114 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1115 ast::CompoundStmt *body; 1116 if (curToken.is(Token::equal_arrow)) { 1117 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1118 [&](ast::Stmt *&stmt) -> LogicalResult { 1119 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1120 if (!stmtExpr) { 1121 return emitError(stmt->getLoc(), 1122 "expected `Constraint` lambda body to contain a " 1123 "single expression"); 1124 } 1125 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1126 return success(); 1127 }, 1128 /*expectTerminalSemicolon=*/!isInline); 1129 if (failed(bodyResult)) 1130 return failure(); 1131 body = *bodyResult; 1132 } else { 1133 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1134 if (failed(bodyResult)) 1135 return failure(); 1136 body = *bodyResult; 1137 1138 // Verify the structure of the body. 1139 auto bodyIt = body->begin(), bodyE = body->end(); 1140 for (; bodyIt != bodyE; ++bodyIt) 1141 if (isa<ast::ReturnStmt>(*bodyIt)) 1142 break; 1143 if (failed(validateUserConstraintOrRewriteReturn( 1144 "Constraint", body, bodyIt, bodyE, results, resultType))) 1145 return failure(); 1146 } 1147 popDeclScope(); 1148 1149 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1150 name, arguments, results, resultType, body); 1151 } 1152 1153 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1154 // Constraints and rewrites have very similar formats, dispatch to a shared 1155 // interface for parsing. 1156 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1157 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1158 ParserContext::Rewrite, "rewrite", isInline); 1159 } 1160 1161 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1162 FailureOr<ast::UserRewriteDecl *> decl = 1163 parseUserRewriteDecl(/*isInline=*/true); 1164 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1165 return failure(); 1166 1167 curDeclScope->add(*decl); 1168 return decl; 1169 } 1170 1171 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1172 const ast::Name &name, bool isInline, 1173 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1174 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1175 // Push the argument scope back onto the list, so that the body can 1176 // reference arguments. 1177 curDeclScope = argumentScope; 1178 ast::CompoundStmt *body; 1179 if (curToken.is(Token::equal_arrow)) { 1180 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1181 [&](ast::Stmt *&statement) -> LogicalResult { 1182 if (isa<ast::OpRewriteStmt>(statement)) 1183 return success(); 1184 1185 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1186 if (!statementExpr) { 1187 return emitError( 1188 statement->getLoc(), 1189 "expected `Rewrite` lambda body to contain a single expression " 1190 "or an operation rewrite statement; such as `erase`, " 1191 "`replace`, or `rewrite`"); 1192 } 1193 statement = 1194 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1195 return success(); 1196 }, 1197 /*expectTerminalSemicolon=*/!isInline); 1198 if (failed(bodyResult)) 1199 return failure(); 1200 body = *bodyResult; 1201 } else { 1202 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1203 if (failed(bodyResult)) 1204 return failure(); 1205 body = *bodyResult; 1206 } 1207 popDeclScope(); 1208 1209 // Verify the structure of the body. 1210 auto bodyIt = body->begin(), bodyE = body->end(); 1211 for (; bodyIt != bodyE; ++bodyIt) 1212 if (isa<ast::ReturnStmt>(*bodyIt)) 1213 break; 1214 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1215 bodyE, results, resultType))) 1216 return failure(); 1217 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1218 name, arguments, results, resultType, body); 1219 } 1220 1221 template <typename T, typename ParseUserPDLLDeclFnT> 1222 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1223 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1224 StringRef anonymousNamePrefix, bool isInline) { 1225 SMRange loc = curToken.getLoc(); 1226 consumeToken(); 1227 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1228 1229 // Parse the name of the decl. 1230 const ast::Name *name = nullptr; 1231 if (curToken.isNot(Token::identifier)) { 1232 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1233 // in C++, so being unnamed is fine. 1234 if (!isInline) 1235 return emitError("expected identifier name"); 1236 1237 // Create a unique anonymous name to use, as the name for this decl is not 1238 // important. 1239 std::string anonName = 1240 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1241 anonymousDeclNameCounter++) 1242 .str(); 1243 name = &ast::Name::create(ctx, anonName, loc); 1244 } else { 1245 // If a name was provided, we can use it directly. 1246 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1247 consumeToken(Token::identifier); 1248 } 1249 1250 // Parse the functional signature of the decl. 1251 SmallVector<ast::VariableDecl *> arguments, results; 1252 ast::DeclScope *argumentScope; 1253 ast::Type resultType; 1254 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1255 argumentScope, resultType))) 1256 return failure(); 1257 1258 // Check to see which type of constraint this is. If the constraint contains a 1259 // compound body, this is a PDLL decl. 1260 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1261 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1262 resultType); 1263 1264 // Otherwise, this is a native decl. 1265 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1266 results, resultType); 1267 } 1268 1269 template <typename T> 1270 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1271 const ast::Name &name, bool isInline, 1272 ArrayRef<ast::VariableDecl *> arguments, 1273 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1274 // If followed by a string, the native code body has also been specified. 1275 std::string codeStrStorage; 1276 Optional<StringRef> optCodeStr; 1277 if (curToken.isString()) { 1278 codeStrStorage = curToken.getStringValue(); 1279 optCodeStr = codeStrStorage; 1280 consumeToken(); 1281 } else if (isInline) { 1282 return emitError(name.getLoc(), 1283 "external declarations must be declared in global scope"); 1284 } else if (curToken.is(Token::error)) { 1285 return failure(); 1286 } 1287 if (failed(parseToken(Token::semicolon, 1288 "expected `;` after native declaration"))) 1289 return failure(); 1290 // TODO: PDL should be able to support constraint results in certain 1291 // situations, we should revise this. 1292 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1293 return emitError( 1294 "native Constraints currently do not support returning results"); 1295 } 1296 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1297 } 1298 1299 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1300 SmallVectorImpl<ast::VariableDecl *> &arguments, 1301 SmallVectorImpl<ast::VariableDecl *> &results, 1302 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1303 // Parse the argument list of the decl. 1304 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1305 return failure(); 1306 1307 argumentScope = pushDeclScope(); 1308 if (curToken.isNot(Token::r_paren)) { 1309 do { 1310 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1311 if (failed(argument)) 1312 return failure(); 1313 arguments.emplace_back(*argument); 1314 } while (consumeIf(Token::comma)); 1315 } 1316 popDeclScope(); 1317 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1318 return failure(); 1319 1320 // Parse the results of the decl. 1321 pushDeclScope(); 1322 if (consumeIf(Token::arrow)) { 1323 auto parseResultFn = [&]() -> LogicalResult { 1324 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1325 if (failed(result)) 1326 return failure(); 1327 results.emplace_back(*result); 1328 return success(); 1329 }; 1330 1331 // Check for a list of results. 1332 if (consumeIf(Token::l_paren)) { 1333 do { 1334 if (failed(parseResultFn())) 1335 return failure(); 1336 } while (consumeIf(Token::comma)); 1337 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1338 return failure(); 1339 1340 // Otherwise, there is only one result. 1341 } else if (failed(parseResultFn())) { 1342 return failure(); 1343 } 1344 } 1345 popDeclScope(); 1346 1347 // Compute the result type of the decl. 1348 resultType = createUserConstraintRewriteResultType(results); 1349 1350 // Verify that results are only named if there are more than one. 1351 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1352 return emitError( 1353 results.front()->getLoc(), 1354 "cannot create a single-element tuple with an element label"); 1355 } 1356 return success(); 1357 } 1358 1359 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1360 StringRef declType, ast::CompoundStmt *body, 1361 ArrayRef<ast::Stmt *>::iterator bodyIt, 1362 ArrayRef<ast::Stmt *>::iterator bodyE, 1363 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1364 // Handle if a `return` was provided. 1365 if (bodyIt != bodyE) { 1366 // Emit an error if we have trailing statements after the return. 1367 if (std::next(bodyIt) != bodyE) { 1368 return emitError( 1369 (*std::next(bodyIt))->getLoc(), 1370 llvm::formatv("`return` terminated the `{0}` body, but found " 1371 "trailing statements afterwards", 1372 declType)); 1373 } 1374 1375 // Otherwise if a return wasn't provided, check that no results are 1376 // expected. 1377 } else if (!results.empty()) { 1378 return emitError( 1379 {body->getLoc().End, body->getLoc().End}, 1380 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1381 declType, resultType)); 1382 } 1383 return success(); 1384 } 1385 1386 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1387 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1388 if (isa<ast::OpRewriteStmt>(statement)) 1389 return success(); 1390 return emitError( 1391 statement->getLoc(), 1392 "expected Pattern lambda body to contain a single operation " 1393 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1394 }); 1395 } 1396 1397 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1398 SMRange loc = curToken.getLoc(); 1399 consumeToken(Token::kw_Pattern); 1400 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1401 ParserContext::PatternMatch); 1402 1403 // Check for an optional identifier for the pattern name. 1404 const ast::Name *name = nullptr; 1405 if (curToken.is(Token::identifier)) { 1406 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1407 consumeToken(Token::identifier); 1408 } 1409 1410 // Parse any pattern metadata. 1411 ParsedPatternMetadata metadata; 1412 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1413 return failure(); 1414 1415 // Parse the pattern body. 1416 ast::CompoundStmt *body; 1417 1418 // Handle a lambda body. 1419 if (curToken.is(Token::equal_arrow)) { 1420 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1421 if (failed(bodyResult)) 1422 return failure(); 1423 body = *bodyResult; 1424 } else { 1425 if (curToken.isNot(Token::l_brace)) 1426 return emitError("expected `{` or `=>` to start pattern body"); 1427 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1428 if (failed(bodyResult)) 1429 return failure(); 1430 body = *bodyResult; 1431 1432 // Verify the body of the pattern. 1433 auto bodyIt = body->begin(), bodyE = body->end(); 1434 for (; bodyIt != bodyE; ++bodyIt) { 1435 if (isa<ast::ReturnStmt>(*bodyIt)) { 1436 return emitError((*bodyIt)->getLoc(), 1437 "`return` statements are only permitted within a " 1438 "`Constraint` or `Rewrite` body"); 1439 } 1440 // Break when we've found the rewrite statement. 1441 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1442 break; 1443 } 1444 if (bodyIt == bodyE) { 1445 return emitError(loc, 1446 "expected Pattern body to terminate with an operation " 1447 "rewrite statement, such as `erase`"); 1448 } 1449 if (std::next(bodyIt) != bodyE) { 1450 return emitError((*std::next(bodyIt))->getLoc(), 1451 "Pattern body was terminated by an operation " 1452 "rewrite statement, but found trailing statements"); 1453 } 1454 } 1455 1456 return createPatternDecl(loc, name, metadata, body); 1457 } 1458 1459 LogicalResult 1460 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1461 Optional<SMRange> benefitLoc; 1462 Optional<SMRange> hasBoundedRecursionLoc; 1463 1464 do { 1465 // Handle metadata code completion. 1466 if (curToken.is(Token::code_complete)) 1467 return codeCompletePatternMetadata(); 1468 1469 if (curToken.isNot(Token::identifier)) 1470 return emitError("expected pattern metadata identifier"); 1471 StringRef metadataStr = curToken.getSpelling(); 1472 SMRange metadataLoc = curToken.getLoc(); 1473 consumeToken(Token::identifier); 1474 1475 // Parse the benefit metadata: benefit(<integer-value>) 1476 if (metadataStr == "benefit") { 1477 if (benefitLoc) { 1478 return emitErrorAndNote(metadataLoc, 1479 "pattern benefit has already been specified", 1480 *benefitLoc, "see previous definition here"); 1481 } 1482 if (failed(parseToken(Token::l_paren, 1483 "expected `(` before pattern benefit"))) 1484 return failure(); 1485 1486 uint16_t benefitValue = 0; 1487 if (curToken.isNot(Token::integer)) 1488 return emitError("expected integral pattern benefit"); 1489 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1490 return emitError( 1491 "expected pattern benefit to fit within a 16-bit integer"); 1492 consumeToken(Token::integer); 1493 1494 metadata.benefit = benefitValue; 1495 benefitLoc = metadataLoc; 1496 1497 if (failed( 1498 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1499 return failure(); 1500 continue; 1501 } 1502 1503 // Parse the bounded recursion metadata: recursion 1504 if (metadataStr == "recursion") { 1505 if (hasBoundedRecursionLoc) { 1506 return emitErrorAndNote( 1507 metadataLoc, 1508 "pattern recursion metadata has already been specified", 1509 *hasBoundedRecursionLoc, "see previous definition here"); 1510 } 1511 metadata.hasBoundedRecursion = true; 1512 hasBoundedRecursionLoc = metadataLoc; 1513 continue; 1514 } 1515 1516 return emitError(metadataLoc, "unknown pattern metadata"); 1517 } while (consumeIf(Token::comma)); 1518 1519 return success(); 1520 } 1521 1522 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1523 consumeToken(Token::less); 1524 1525 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1526 if (failed(typeExpr) || 1527 failed(parseToken(Token::greater, 1528 "expected `>` after variable type constraint"))) 1529 return failure(); 1530 return typeExpr; 1531 } 1532 1533 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1534 assert(curDeclScope && "defining decl outside of a decl scope"); 1535 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1536 return emitErrorAndNote( 1537 name.getLoc(), "`" + name.getName() + "` has already been defined", 1538 lastDecl->getName()->getLoc(), "see previous definition here"); 1539 } 1540 return success(); 1541 } 1542 1543 FailureOr<ast::VariableDecl *> 1544 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1545 ast::Expr *initExpr, 1546 ArrayRef<ast::ConstraintRef> constraints) { 1547 assert(curDeclScope && "defining variable outside of decl scope"); 1548 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1549 1550 // If the name of the variable indicates a special variable, we don't add it 1551 // to the scope. This variable is local to the definition point. 1552 if (name.empty() || name == "_") { 1553 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1554 constraints); 1555 } 1556 if (failed(checkDefineNamedDecl(nameDecl))) 1557 return failure(); 1558 1559 auto *varDecl = 1560 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1561 curDeclScope->add(varDecl); 1562 return varDecl; 1563 } 1564 1565 FailureOr<ast::VariableDecl *> 1566 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1567 ArrayRef<ast::ConstraintRef> constraints) { 1568 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1569 constraints); 1570 } 1571 1572 LogicalResult Parser::parseVariableDeclConstraintList( 1573 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1574 Optional<SMRange> typeConstraint; 1575 auto parseSingleConstraint = [&] { 1576 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1577 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1578 /*allowNonCoreConstraints=*/true); 1579 if (failed(constraint)) 1580 return failure(); 1581 constraints.push_back(*constraint); 1582 return success(); 1583 }; 1584 1585 // Check to see if this is a single constraint, or a list. 1586 if (!consumeIf(Token::l_square)) 1587 return parseSingleConstraint(); 1588 1589 do { 1590 if (failed(parseSingleConstraint())) 1591 return failure(); 1592 } while (consumeIf(Token::comma)); 1593 return parseToken(Token::r_square, "expected `]` after constraint list"); 1594 } 1595 1596 FailureOr<ast::ConstraintRef> 1597 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1598 ArrayRef<ast::ConstraintRef> existingConstraints, 1599 bool allowInlineTypeConstraints, 1600 bool allowNonCoreConstraints) { 1601 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1602 if (!allowInlineTypeConstraints) { 1603 return emitError( 1604 curToken.getLoc(), 1605 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1606 "permitted on arguments or results"); 1607 } 1608 if (typeConstraint) 1609 return emitErrorAndNote( 1610 curToken.getLoc(), 1611 "the type of this variable has already been constrained", 1612 *typeConstraint, "see previous constraint location here"); 1613 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1614 if (failed(constraintExpr)) 1615 return failure(); 1616 typeExpr = *constraintExpr; 1617 typeConstraint = typeExpr->getLoc(); 1618 return success(); 1619 }; 1620 1621 SMRange loc = curToken.getLoc(); 1622 switch (curToken.getKind()) { 1623 case Token::kw_Attr: { 1624 consumeToken(Token::kw_Attr); 1625 1626 // Check for a type constraint. 1627 ast::Expr *typeExpr = nullptr; 1628 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1629 return failure(); 1630 return ast::ConstraintRef( 1631 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1632 } 1633 case Token::kw_Op: { 1634 consumeToken(Token::kw_Op); 1635 1636 // Parse an optional operation name. If the name isn't provided, this refers 1637 // to "any" operation. 1638 FailureOr<ast::OpNameDecl *> opName = 1639 parseWrappedOperationName(/*allowEmptyName=*/true); 1640 if (failed(opName)) 1641 return failure(); 1642 1643 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1644 loc); 1645 } 1646 case Token::kw_Type: 1647 consumeToken(Token::kw_Type); 1648 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1649 case Token::kw_TypeRange: 1650 consumeToken(Token::kw_TypeRange); 1651 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1652 loc); 1653 case Token::kw_Value: { 1654 consumeToken(Token::kw_Value); 1655 1656 // Check for a type constraint. 1657 ast::Expr *typeExpr = nullptr; 1658 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1659 return failure(); 1660 1661 return ast::ConstraintRef( 1662 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1663 } 1664 case Token::kw_ValueRange: { 1665 consumeToken(Token::kw_ValueRange); 1666 1667 // Check for a type constraint. 1668 ast::Expr *typeExpr = nullptr; 1669 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1670 return failure(); 1671 1672 return ast::ConstraintRef( 1673 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1674 } 1675 1676 case Token::kw_Constraint: { 1677 // Handle an inline constraint. 1678 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1679 if (failed(decl)) 1680 return failure(); 1681 return ast::ConstraintRef(*decl, loc); 1682 } 1683 case Token::identifier: { 1684 StringRef constraintName = curToken.getSpelling(); 1685 consumeToken(Token::identifier); 1686 1687 // Lookup the referenced constraint. 1688 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1689 if (!cstDecl) { 1690 return emitError(loc, "unknown reference to constraint `" + 1691 constraintName + "`"); 1692 } 1693 1694 // Handle a reference to a proper constraint. 1695 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1696 return ast::ConstraintRef(cst, loc); 1697 1698 return emitErrorAndNote( 1699 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1700 "see the definition of `" + constraintName + "` here"); 1701 } 1702 // Handle single entity constraint code completion. 1703 case Token::code_complete: { 1704 // Try to infer the current type for use by code completion. 1705 ast::Type inferredType; 1706 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1707 allowNonCoreConstraints))) 1708 return failure(); 1709 1710 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1711 allowInlineTypeConstraints); 1712 } 1713 default: 1714 break; 1715 } 1716 return emitError(loc, "expected identifier constraint"); 1717 } 1718 1719 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1720 // Constraint arguments may apply more complex constraints via the arguments. 1721 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1722 1723 Optional<SMRange> typeConstraint; 1724 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1725 /*allowInlineTypeConstraints=*/false, 1726 allowNonCoreConstraints); 1727 } 1728 1729 //===----------------------------------------------------------------------===// 1730 // Exprs 1731 1732 FailureOr<ast::Expr *> Parser::parseExpr() { 1733 if (curToken.is(Token::underscore)) 1734 return parseUnderscoreExpr(); 1735 1736 // Parse the LHS expression. 1737 FailureOr<ast::Expr *> lhsExpr; 1738 switch (curToken.getKind()) { 1739 case Token::kw_attr: 1740 lhsExpr = parseAttributeExpr(); 1741 break; 1742 case Token::kw_Constraint: 1743 lhsExpr = parseInlineConstraintLambdaExpr(); 1744 break; 1745 case Token::identifier: 1746 lhsExpr = parseIdentifierExpr(); 1747 break; 1748 case Token::kw_op: 1749 lhsExpr = parseOperationExpr(); 1750 break; 1751 case Token::kw_Rewrite: 1752 lhsExpr = parseInlineRewriteLambdaExpr(); 1753 break; 1754 case Token::kw_type: 1755 lhsExpr = parseTypeExpr(); 1756 break; 1757 case Token::l_paren: 1758 lhsExpr = parseTupleExpr(); 1759 break; 1760 default: 1761 return emitError("expected expression"); 1762 } 1763 if (failed(lhsExpr)) 1764 return failure(); 1765 1766 // Check for an operator expression. 1767 while (true) { 1768 switch (curToken.getKind()) { 1769 case Token::dot: 1770 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1771 break; 1772 case Token::l_paren: 1773 lhsExpr = parseCallExpr(*lhsExpr); 1774 break; 1775 default: 1776 return lhsExpr; 1777 } 1778 if (failed(lhsExpr)) 1779 return failure(); 1780 } 1781 } 1782 1783 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1784 SMRange loc = curToken.getLoc(); 1785 consumeToken(Token::kw_attr); 1786 1787 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1788 // identifier. 1789 if (!consumeIf(Token::less)) { 1790 resetToken(loc); 1791 return parseIdentifierExpr(); 1792 } 1793 1794 if (!curToken.isString()) 1795 return emitError("expected string literal containing MLIR attribute"); 1796 std::string attrExpr = curToken.getStringValue(); 1797 consumeToken(); 1798 1799 loc.End = curToken.getEndLoc(); 1800 if (failed( 1801 parseToken(Token::greater, "expected `>` after attribute literal"))) 1802 return failure(); 1803 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1804 } 1805 1806 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1807 consumeToken(Token::l_paren); 1808 1809 // Parse the arguments of the call. 1810 SmallVector<ast::Expr *> arguments; 1811 if (curToken.isNot(Token::r_paren)) { 1812 do { 1813 // Handle code completion for the call arguments. 1814 if (curToken.is(Token::code_complete)) { 1815 codeCompleteCallSignature(parentExpr, arguments.size()); 1816 return failure(); 1817 } 1818 1819 FailureOr<ast::Expr *> argument = parseExpr(); 1820 if (failed(argument)) 1821 return failure(); 1822 arguments.push_back(*argument); 1823 } while (consumeIf(Token::comma)); 1824 } 1825 1826 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1827 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1828 return failure(); 1829 1830 return createCallExpr(loc, parentExpr, arguments); 1831 } 1832 1833 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1834 ast::Decl *decl = curDeclScope->lookup(name); 1835 if (!decl) 1836 return emitError(loc, "undefined reference to `" + name + "`"); 1837 1838 return createDeclRefExpr(loc, decl); 1839 } 1840 1841 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1842 StringRef name = curToken.getSpelling(); 1843 SMRange nameLoc = curToken.getLoc(); 1844 consumeToken(); 1845 1846 // Check to see if this is a decl ref expression that defines a variable 1847 // inline. 1848 if (consumeIf(Token::colon)) { 1849 SmallVector<ast::ConstraintRef> constraints; 1850 if (failed(parseVariableDeclConstraintList(constraints))) 1851 return failure(); 1852 ast::Type type; 1853 if (failed(validateVariableConstraints(constraints, type))) 1854 return failure(); 1855 return createInlineVariableExpr(type, name, nameLoc, constraints); 1856 } 1857 1858 return parseDeclRefExpr(name, nameLoc); 1859 } 1860 1861 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1862 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1863 if (failed(decl)) 1864 return failure(); 1865 1866 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1867 ast::ConstraintType::get(ctx)); 1868 } 1869 1870 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1871 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1872 if (failed(decl)) 1873 return failure(); 1874 1875 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1876 ast::RewriteType::get(ctx)); 1877 } 1878 1879 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1880 SMRange dotLoc = curToken.getLoc(); 1881 consumeToken(Token::dot); 1882 1883 // Check for code completion of the member name. 1884 if (curToken.is(Token::code_complete)) 1885 return codeCompleteMemberAccess(parentExpr); 1886 1887 // Parse the member name. 1888 Token memberNameTok = curToken; 1889 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1890 !memberNameTok.isKeyword()) 1891 return emitError(dotLoc, "expected identifier or numeric member name"); 1892 StringRef memberName = memberNameTok.getSpelling(); 1893 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1894 consumeToken(); 1895 1896 return createMemberAccessExpr(parentExpr, memberName, loc); 1897 } 1898 1899 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1900 SMRange loc = curToken.getLoc(); 1901 1902 // Check for code completion for the dialect name. 1903 if (curToken.is(Token::code_complete)) 1904 return codeCompleteDialectName(); 1905 1906 // Handle the case of an no operation name. 1907 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1908 if (allowEmptyName) 1909 return ast::OpNameDecl::create(ctx, SMRange()); 1910 return emitError("expected dialect namespace"); 1911 } 1912 StringRef name = curToken.getSpelling(); 1913 consumeToken(); 1914 1915 // Otherwise, this is a literal operation name. 1916 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1917 return failure(); 1918 1919 // Check for code completion for the operation name. 1920 if (curToken.is(Token::code_complete)) 1921 return codeCompleteOperationName(name); 1922 1923 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1924 return emitError("expected operation name after dialect namespace"); 1925 1926 name = StringRef(name.data(), name.size() + 1); 1927 do { 1928 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1929 loc.End = curToken.getEndLoc(); 1930 consumeToken(); 1931 } while (curToken.isAny(Token::identifier, Token::dot) || 1932 curToken.isKeyword()); 1933 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1934 } 1935 1936 FailureOr<ast::OpNameDecl *> 1937 Parser::parseWrappedOperationName(bool allowEmptyName) { 1938 if (!consumeIf(Token::less)) 1939 return ast::OpNameDecl::create(ctx, SMRange()); 1940 1941 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1942 if (failed(opNameDecl)) 1943 return failure(); 1944 1945 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1946 return failure(); 1947 return opNameDecl; 1948 } 1949 1950 FailureOr<ast::Expr *> 1951 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { 1952 SMRange loc = curToken.getLoc(); 1953 consumeToken(Token::kw_op); 1954 1955 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1956 // identifier. 1957 if (curToken.isNot(Token::less)) { 1958 resetToken(loc); 1959 return parseIdentifierExpr(); 1960 } 1961 1962 // Parse the operation name. The name may be elided, in which case the 1963 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1964 // `Operation*`). Operation names within a rewrite context must be named. 1965 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1966 FailureOr<ast::OpNameDecl *> opNameDecl = 1967 parseWrappedOperationName(allowEmptyName); 1968 if (failed(opNameDecl)) 1969 return failure(); 1970 Optional<StringRef> opName = (*opNameDecl)->getName(); 1971 1972 // Functor used to create an implicit range variable, used for implicit "all" 1973 // operand or results variables. 1974 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1975 FailureOr<ast::VariableDecl *> rangeVar = 1976 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1977 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1978 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1979 }; 1980 1981 // Check for the optional list of operands. 1982 SmallVector<ast::Expr *> operands; 1983 if (!consumeIf(Token::l_paren)) { 1984 // If the operand list isn't specified and we are in a match context, define 1985 // an inplace unconstrained operand range corresponding to all of the 1986 // operands of the operation. This avoids treating zero operands the same 1987 // way as "unconstrained operands". 1988 if (parserContext != ParserContext::Rewrite) { 1989 operands.push_back(createImplicitRangeVar( 1990 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1991 } 1992 } else if (!consumeIf(Token::r_paren)) { 1993 // If the operand list was specified and non-empty, parse the operands. 1994 do { 1995 // Check for operand signature code completion. 1996 if (curToken.is(Token::code_complete)) { 1997 codeCompleteOperationOperandsSignature(opName, operands.size()); 1998 return failure(); 1999 } 2000 2001 FailureOr<ast::Expr *> operand = parseExpr(); 2002 if (failed(operand)) 2003 return failure(); 2004 operands.push_back(*operand); 2005 } while (consumeIf(Token::comma)); 2006 2007 if (failed(parseToken(Token::r_paren, 2008 "expected `)` after operation operand list"))) 2009 return failure(); 2010 } 2011 2012 // Check for the optional list of attributes. 2013 SmallVector<ast::NamedAttributeDecl *> attributes; 2014 if (consumeIf(Token::l_brace)) { 2015 do { 2016 FailureOr<ast::NamedAttributeDecl *> decl = 2017 parseNamedAttributeDecl(opName); 2018 if (failed(decl)) 2019 return failure(); 2020 attributes.emplace_back(*decl); 2021 } while (consumeIf(Token::comma)); 2022 2023 if (failed(parseToken(Token::r_brace, 2024 "expected `}` after operation attribute list"))) 2025 return failure(); 2026 } 2027 2028 // Handle the result types of the operation. 2029 SmallVector<ast::Expr *> resultTypes; 2030 OpResultTypeContext resultTypeContext = inputResultTypeContext; 2031 2032 // Check for an explicit list of result types. 2033 if (consumeIf(Token::arrow)) { 2034 if (failed(parseToken(Token::l_paren, 2035 "expected `(` before operation result type list"))) 2036 return failure(); 2037 2038 // If result types are provided, initially assume that the operation does 2039 // not rely on type inferrence. We don't assert that it isn't, because we 2040 // may be inferring the value of some type/type range variables, but given 2041 // that these variables may be defined in calls we can't always discern when 2042 // this is the case. 2043 resultTypeContext = OpResultTypeContext::Explicit; 2044 2045 // Handle the case of an empty result list. 2046 if (!consumeIf(Token::r_paren)) { 2047 do { 2048 // Check for result signature code completion. 2049 if (curToken.is(Token::code_complete)) { 2050 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2051 return failure(); 2052 } 2053 2054 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2055 if (failed(resultTypeExpr)) 2056 return failure(); 2057 resultTypes.push_back(*resultTypeExpr); 2058 } while (consumeIf(Token::comma)); 2059 2060 if (failed(parseToken(Token::r_paren, 2061 "expected `)` after operation result type list"))) 2062 return failure(); 2063 } 2064 } else if (parserContext != ParserContext::Rewrite) { 2065 // If the result list isn't specified and we are in a match context, define 2066 // an inplace unconstrained result range corresponding to all of the results 2067 // of the operation. This avoids treating zero results the same way as 2068 // "unconstrained results". 2069 resultTypes.push_back(createImplicitRangeVar( 2070 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2071 } else if (resultTypeContext == OpResultTypeContext::Explicit) { 2072 // If the result list isn't specified and we are in a rewrite, try to infer 2073 // them at runtime instead. 2074 resultTypeContext = OpResultTypeContext::Interface; 2075 } 2076 2077 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, 2078 attributes, resultTypes); 2079 } 2080 2081 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2082 SMRange loc = curToken.getLoc(); 2083 consumeToken(Token::l_paren); 2084 2085 DenseMap<StringRef, SMRange> usedNames; 2086 SmallVector<StringRef> elementNames; 2087 SmallVector<ast::Expr *> elements; 2088 if (curToken.isNot(Token::r_paren)) { 2089 do { 2090 // Check for the optional element name assignment before the value. 2091 StringRef elementName; 2092 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2093 Token elementNameTok = curToken; 2094 consumeToken(); 2095 2096 // The element name is only present if followed by an `=`. 2097 if (consumeIf(Token::equal)) { 2098 elementName = elementNameTok.getSpelling(); 2099 2100 // Check to see if this name is already used. 2101 auto elementNameIt = 2102 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2103 if (!elementNameIt.second) { 2104 return emitErrorAndNote( 2105 elementNameTok.getLoc(), 2106 llvm::formatv("duplicate tuple element label `{0}`", 2107 elementName), 2108 elementNameIt.first->getSecond(), 2109 "see previous label use here"); 2110 } 2111 } else { 2112 // Otherwise, we treat this as part of an expression so reset the 2113 // lexer. 2114 resetToken(elementNameTok.getLoc()); 2115 } 2116 } 2117 elementNames.push_back(elementName); 2118 2119 // Parse the tuple element value. 2120 FailureOr<ast::Expr *> element = parseExpr(); 2121 if (failed(element)) 2122 return failure(); 2123 elements.push_back(*element); 2124 } while (consumeIf(Token::comma)); 2125 } 2126 loc.End = curToken.getEndLoc(); 2127 if (failed( 2128 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2129 return failure(); 2130 return createTupleExpr(loc, elements, elementNames); 2131 } 2132 2133 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2134 SMRange loc = curToken.getLoc(); 2135 consumeToken(Token::kw_type); 2136 2137 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2138 // identifier. 2139 if (!consumeIf(Token::less)) { 2140 resetToken(loc); 2141 return parseIdentifierExpr(); 2142 } 2143 2144 if (!curToken.isString()) 2145 return emitError("expected string literal containing MLIR type"); 2146 std::string attrExpr = curToken.getStringValue(); 2147 consumeToken(); 2148 2149 loc.End = curToken.getEndLoc(); 2150 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2151 return failure(); 2152 return ast::TypeExpr::create(ctx, loc, attrExpr); 2153 } 2154 2155 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2156 StringRef name = curToken.getSpelling(); 2157 SMRange nameLoc = curToken.getLoc(); 2158 consumeToken(Token::underscore); 2159 2160 // Underscore expressions require a constraint list. 2161 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2162 return failure(); 2163 2164 // Parse the constraints for the expression. 2165 SmallVector<ast::ConstraintRef> constraints; 2166 if (failed(parseVariableDeclConstraintList(constraints))) 2167 return failure(); 2168 2169 ast::Type type; 2170 if (failed(validateVariableConstraints(constraints, type))) 2171 return failure(); 2172 return createInlineVariableExpr(type, name, nameLoc, constraints); 2173 } 2174 2175 //===----------------------------------------------------------------------===// 2176 // Stmts 2177 2178 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2179 FailureOr<ast::Stmt *> stmt; 2180 switch (curToken.getKind()) { 2181 case Token::kw_erase: 2182 stmt = parseEraseStmt(); 2183 break; 2184 case Token::kw_let: 2185 stmt = parseLetStmt(); 2186 break; 2187 case Token::kw_replace: 2188 stmt = parseReplaceStmt(); 2189 break; 2190 case Token::kw_return: 2191 stmt = parseReturnStmt(); 2192 break; 2193 case Token::kw_rewrite: 2194 stmt = parseRewriteStmt(); 2195 break; 2196 default: 2197 stmt = parseExpr(); 2198 break; 2199 } 2200 if (failed(stmt) || 2201 (expectTerminalSemicolon && 2202 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2203 return failure(); 2204 return stmt; 2205 } 2206 2207 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2208 SMLoc startLoc = curToken.getStartLoc(); 2209 consumeToken(Token::l_brace); 2210 2211 // Push a new block scope and parse any nested statements. 2212 pushDeclScope(); 2213 SmallVector<ast::Stmt *> statements; 2214 while (curToken.isNot(Token::r_brace)) { 2215 FailureOr<ast::Stmt *> statement = parseStmt(); 2216 if (failed(statement)) 2217 return popDeclScope(), failure(); 2218 statements.push_back(*statement); 2219 } 2220 popDeclScope(); 2221 2222 // Consume the end brace. 2223 SMRange location(startLoc, curToken.getEndLoc()); 2224 consumeToken(Token::r_brace); 2225 2226 return ast::CompoundStmt::create(ctx, location, statements); 2227 } 2228 2229 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2230 if (parserContext == ParserContext::Constraint) 2231 return emitError("`erase` cannot be used within a Constraint"); 2232 SMRange loc = curToken.getLoc(); 2233 consumeToken(Token::kw_erase); 2234 2235 // Parse the root operation expression. 2236 FailureOr<ast::Expr *> rootOp = parseExpr(); 2237 if (failed(rootOp)) 2238 return failure(); 2239 2240 return createEraseStmt(loc, *rootOp); 2241 } 2242 2243 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2244 SMRange loc = curToken.getLoc(); 2245 consumeToken(Token::kw_let); 2246 2247 // Parse the name of the new variable. 2248 SMRange varLoc = curToken.getLoc(); 2249 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2250 // `_` is a reserved variable name. 2251 if (curToken.is(Token::underscore)) { 2252 return emitError(varLoc, 2253 "`_` may only be used to define \"inline\" variables"); 2254 } 2255 return emitError(varLoc, 2256 "expected identifier after `let` to name a new variable"); 2257 } 2258 StringRef varName = curToken.getSpelling(); 2259 consumeToken(); 2260 2261 // Parse the optional set of constraints. 2262 SmallVector<ast::ConstraintRef> constraints; 2263 if (consumeIf(Token::colon) && 2264 failed(parseVariableDeclConstraintList(constraints))) 2265 return failure(); 2266 2267 // Parse the optional initializer expression. 2268 ast::Expr *initializer = nullptr; 2269 if (consumeIf(Token::equal)) { 2270 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2271 if (failed(initOrFailure)) 2272 return failure(); 2273 initializer = *initOrFailure; 2274 2275 // Check that the constraints are compatible with having an initializer, 2276 // e.g. type constraints cannot be used with initializers. 2277 for (ast::ConstraintRef constraint : constraints) { 2278 LogicalResult result = 2279 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2280 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2281 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2282 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2283 return this->emitError( 2284 constraint.referenceLoc, 2285 "type constraints are not permitted on variables with " 2286 "initializers"); 2287 } 2288 return success(); 2289 }) 2290 .Default(success()); 2291 if (failed(result)) 2292 return failure(); 2293 } 2294 } 2295 2296 FailureOr<ast::VariableDecl *> varDecl = 2297 createVariableDecl(varName, varLoc, initializer, constraints); 2298 if (failed(varDecl)) 2299 return failure(); 2300 return ast::LetStmt::create(ctx, loc, *varDecl); 2301 } 2302 2303 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2304 if (parserContext == ParserContext::Constraint) 2305 return emitError("`replace` cannot be used within a Constraint"); 2306 SMRange loc = curToken.getLoc(); 2307 consumeToken(Token::kw_replace); 2308 2309 // Parse the root operation expression. 2310 FailureOr<ast::Expr *> rootOp = parseExpr(); 2311 if (failed(rootOp)) 2312 return failure(); 2313 2314 if (failed( 2315 parseToken(Token::kw_with, "expected `with` after root operation"))) 2316 return failure(); 2317 2318 // The replacement portion of this statement is within a rewrite context. 2319 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2320 ParserContext::Rewrite); 2321 2322 // Parse the replacement values. 2323 SmallVector<ast::Expr *> replValues; 2324 if (consumeIf(Token::l_paren)) { 2325 if (consumeIf(Token::r_paren)) { 2326 return emitError( 2327 loc, "expected at least one replacement value, consider using " 2328 "`erase` if no replacement values are desired"); 2329 } 2330 2331 do { 2332 FailureOr<ast::Expr *> replExpr = parseExpr(); 2333 if (failed(replExpr)) 2334 return failure(); 2335 replValues.emplace_back(*replExpr); 2336 } while (consumeIf(Token::comma)); 2337 2338 if (failed(parseToken(Token::r_paren, 2339 "expected `)` after replacement values"))) 2340 return failure(); 2341 } else { 2342 // Handle replacement with an operation uniquely, as the replacement 2343 // operation supports type inferrence from the root operation. 2344 FailureOr<ast::Expr *> replExpr; 2345 if (curToken.is(Token::kw_op)) 2346 replExpr = parseOperationExpr(OpResultTypeContext::Replacement); 2347 else 2348 replExpr = parseExpr(); 2349 if (failed(replExpr)) 2350 return failure(); 2351 replValues.emplace_back(*replExpr); 2352 } 2353 2354 return createReplaceStmt(loc, *rootOp, replValues); 2355 } 2356 2357 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2358 SMRange loc = curToken.getLoc(); 2359 consumeToken(Token::kw_return); 2360 2361 // Parse the result value. 2362 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2363 if (failed(resultExpr)) 2364 return failure(); 2365 2366 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2367 } 2368 2369 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2370 if (parserContext == ParserContext::Constraint) 2371 return emitError("`rewrite` cannot be used within a Constraint"); 2372 SMRange loc = curToken.getLoc(); 2373 consumeToken(Token::kw_rewrite); 2374 2375 // Parse the root operation. 2376 FailureOr<ast::Expr *> rootOp = parseExpr(); 2377 if (failed(rootOp)) 2378 return failure(); 2379 2380 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2381 return failure(); 2382 2383 if (curToken.isNot(Token::l_brace)) 2384 return emitError("expected `{` to start rewrite body"); 2385 2386 // The rewrite body of this statement is within a rewrite context. 2387 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2388 ParserContext::Rewrite); 2389 2390 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2391 if (failed(rewriteBody)) 2392 return failure(); 2393 2394 // Verify the rewrite body. 2395 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2396 if (isa<ast::ReturnStmt>(stmt)) { 2397 return emitError(stmt->getLoc(), 2398 "`return` statements are only permitted within a " 2399 "`Constraint` or `Rewrite` body"); 2400 } 2401 } 2402 2403 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2404 } 2405 2406 //===----------------------------------------------------------------------===// 2407 // Creation+Analysis 2408 //===----------------------------------------------------------------------===// 2409 2410 //===----------------------------------------------------------------------===// 2411 // Decls 2412 2413 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2414 // Unwrap reference expressions. 2415 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2416 node = init->getDecl(); 2417 return dyn_cast<ast::CallableDecl>(node); 2418 } 2419 2420 FailureOr<ast::PatternDecl *> 2421 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2422 const ParsedPatternMetadata &metadata, 2423 ast::CompoundStmt *body) { 2424 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2425 metadata.hasBoundedRecursion, body); 2426 } 2427 2428 ast::Type Parser::createUserConstraintRewriteResultType( 2429 ArrayRef<ast::VariableDecl *> results) { 2430 // Single result decls use the type of the single result. 2431 if (results.size() == 1) 2432 return results[0]->getType(); 2433 2434 // Multiple results use a tuple type, with the types and names grabbed from 2435 // the result variable decls. 2436 auto resultTypes = llvm::map_range( 2437 results, [&](const auto *result) { return result->getType(); }); 2438 auto resultNames = llvm::map_range( 2439 results, [&](const auto *result) { return result->getName().getName(); }); 2440 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2441 llvm::to_vector(resultNames)); 2442 } 2443 2444 template <typename T> 2445 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2446 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2447 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2448 ast::CompoundStmt *body) { 2449 if (!body->getChildren().empty()) { 2450 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2451 ast::Expr *resultExpr = retStmt->getResultExpr(); 2452 2453 // Process the result of the decl. If no explicit signature results 2454 // were provided, check for return type inference. Otherwise, check that 2455 // the return expression can be converted to the expected type. 2456 if (results.empty()) 2457 resultType = resultExpr->getType(); 2458 else if (failed(convertExpressionTo(resultExpr, resultType))) 2459 return failure(); 2460 else 2461 retStmt->setResultExpr(resultExpr); 2462 } 2463 } 2464 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2465 } 2466 2467 FailureOr<ast::VariableDecl *> 2468 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2469 ArrayRef<ast::ConstraintRef> constraints) { 2470 // The type of the variable, which is expected to be inferred by either a 2471 // constraint or an initializer expression. 2472 ast::Type type; 2473 if (failed(validateVariableConstraints(constraints, type))) 2474 return failure(); 2475 2476 if (initializer) { 2477 // Update the variable type based on the initializer, or try to convert the 2478 // initializer to the existing type. 2479 if (!type) 2480 type = initializer->getType(); 2481 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2482 type = mergedType; 2483 else if (failed(convertExpressionTo(initializer, type))) 2484 return failure(); 2485 2486 // Otherwise, if there is no initializer check that the type has already 2487 // been resolved from the constraint list. 2488 } else if (!type) { 2489 return emitErrorAndNote( 2490 loc, "unable to infer type for variable `" + name + "`", loc, 2491 "the type of a variable must be inferable from the constraint " 2492 "list or the initializer"); 2493 } 2494 2495 // Constraint types cannot be used when defining variables. 2496 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2497 return emitError( 2498 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2499 } 2500 2501 // Try to define a variable with the given name. 2502 FailureOr<ast::VariableDecl *> varDecl = 2503 defineVariableDecl(name, loc, type, initializer, constraints); 2504 if (failed(varDecl)) 2505 return failure(); 2506 2507 return *varDecl; 2508 } 2509 2510 FailureOr<ast::VariableDecl *> 2511 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2512 const ast::ConstraintRef &constraint) { 2513 // Constraint arguments may apply more complex constraints via the arguments. 2514 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2515 ast::Type argType; 2516 if (failed(validateVariableConstraint(constraint, argType, 2517 allowNonCoreConstraints))) 2518 return failure(); 2519 return defineVariableDecl(name, loc, argType, constraint); 2520 } 2521 2522 LogicalResult 2523 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2524 ast::Type &inferredType, 2525 bool allowNonCoreConstraints) { 2526 for (const ast::ConstraintRef &ref : constraints) 2527 if (failed(validateVariableConstraint(ref, inferredType, 2528 allowNonCoreConstraints))) 2529 return failure(); 2530 return success(); 2531 } 2532 2533 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2534 ast::Type &inferredType, 2535 bool allowNonCoreConstraints) { 2536 ast::Type constraintType; 2537 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2538 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2539 if (failed(validateTypeConstraintExpr(typeExpr))) 2540 return failure(); 2541 } 2542 constraintType = ast::AttributeType::get(ctx); 2543 } else if (const auto *cst = 2544 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2545 constraintType = ast::OperationType::get( 2546 ctx, cst->getName(), lookupODSOperation(cst->getName())); 2547 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2548 constraintType = typeTy; 2549 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2550 constraintType = typeRangeTy; 2551 } else if (const auto *cst = 2552 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2553 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2554 if (failed(validateTypeConstraintExpr(typeExpr))) 2555 return failure(); 2556 } 2557 constraintType = valueTy; 2558 } else if (const auto *cst = 2559 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2560 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2561 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2562 return failure(); 2563 } 2564 constraintType = valueRangeTy; 2565 } else if (const auto *cst = 2566 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2567 if (!allowNonCoreConstraints) { 2568 return emitError(ref.referenceLoc, 2569 "`Rewrite` arguments and results are only permitted to " 2570 "use core constraints, such as `Attr`, `Op`, `Type`, " 2571 "`TypeRange`, `Value`, `ValueRange`"); 2572 } 2573 2574 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2575 if (inputs.size() != 1) { 2576 return emitErrorAndNote(ref.referenceLoc, 2577 "`Constraint`s applied via a variable constraint " 2578 "list must take a single input, but got " + 2579 Twine(inputs.size()), 2580 cst->getLoc(), 2581 "see definition of constraint here"); 2582 } 2583 constraintType = inputs.front()->getType(); 2584 } else { 2585 llvm_unreachable("unknown constraint type"); 2586 } 2587 2588 // Check that the constraint type is compatible with the current inferred 2589 // type. 2590 if (!inferredType) { 2591 inferredType = constraintType; 2592 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2593 inferredType = mergedTy; 2594 } else { 2595 return emitError(ref.referenceLoc, 2596 llvm::formatv("constraint type `{0}` is incompatible " 2597 "with the previously inferred type `{1}`", 2598 constraintType, inferredType)); 2599 } 2600 return success(); 2601 } 2602 2603 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2604 ast::Type typeExprType = typeExpr->getType(); 2605 if (typeExprType != typeTy) { 2606 return emitError(typeExpr->getLoc(), 2607 "expected expression of `Type` in type constraint"); 2608 } 2609 return success(); 2610 } 2611 2612 LogicalResult 2613 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2614 ast::Type typeExprType = typeExpr->getType(); 2615 if (typeExprType != typeRangeTy) { 2616 return emitError(typeExpr->getLoc(), 2617 "expected expression of `TypeRange` in type constraint"); 2618 } 2619 return success(); 2620 } 2621 2622 //===----------------------------------------------------------------------===// 2623 // Exprs 2624 2625 FailureOr<ast::CallExpr *> 2626 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2627 MutableArrayRef<ast::Expr *> arguments) { 2628 ast::Type parentType = parentExpr->getType(); 2629 2630 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2631 if (!callableDecl) { 2632 return emitError(loc, 2633 llvm::formatv("expected a reference to a callable " 2634 "`Constraint` or `Rewrite`, but got: `{0}`", 2635 parentType)); 2636 } 2637 if (parserContext == ParserContext::Rewrite) { 2638 if (isa<ast::UserConstraintDecl>(callableDecl)) 2639 return emitError( 2640 loc, "unable to invoke `Constraint` within a rewrite section"); 2641 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2642 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2643 } 2644 2645 // Verify the arguments of the call. 2646 /// Handle size mismatch. 2647 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2648 if (callArgs.size() != arguments.size()) { 2649 return emitErrorAndNote( 2650 loc, 2651 llvm::formatv("invalid number of arguments for {0} call; expected " 2652 "{1}, but got {2}", 2653 callableDecl->getCallableType(), callArgs.size(), 2654 arguments.size()), 2655 callableDecl->getLoc(), 2656 llvm::formatv("see the definition of {0} here", 2657 callableDecl->getName()->getName())); 2658 } 2659 2660 /// Handle argument type mismatch. 2661 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2662 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2663 callableDecl->getName()->getName()), 2664 callableDecl->getLoc()); 2665 }; 2666 for (auto it : llvm::zip(callArgs, arguments)) { 2667 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2668 attachDiagFn))) 2669 return failure(); 2670 } 2671 2672 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2673 callableDecl->getResultType()); 2674 } 2675 2676 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2677 ast::Decl *decl) { 2678 // Check the type of decl being referenced. 2679 ast::Type declType; 2680 if (isa<ast::ConstraintDecl>(decl)) 2681 declType = ast::ConstraintType::get(ctx); 2682 else if (isa<ast::UserRewriteDecl>(decl)) 2683 declType = ast::RewriteType::get(ctx); 2684 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2685 declType = varDecl->getType(); 2686 else 2687 return emitError(loc, "invalid reference to `" + 2688 decl->getName()->getName() + "`"); 2689 2690 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2691 } 2692 2693 FailureOr<ast::DeclRefExpr *> 2694 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2695 ArrayRef<ast::ConstraintRef> constraints) { 2696 FailureOr<ast::VariableDecl *> decl = 2697 defineVariableDecl(name, loc, type, constraints); 2698 if (failed(decl)) 2699 return failure(); 2700 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2701 } 2702 2703 FailureOr<ast::MemberAccessExpr *> 2704 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2705 SMRange loc) { 2706 // Validate the member name for the given parent expression. 2707 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2708 if (failed(memberType)) 2709 return failure(); 2710 2711 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2712 } 2713 2714 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2715 StringRef name, SMRange loc) { 2716 ast::Type parentType = parentExpr->getType(); 2717 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2718 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2719 return valueRangeTy; 2720 2721 // Verify member access based on the operation type. 2722 if (const ods::Operation *odsOp = opType.getODSOperation()) { 2723 auto results = odsOp->getResults(); 2724 2725 // Handle indexed results. 2726 unsigned index = 0; 2727 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2728 index < results.size()) { 2729 return results[index].isVariadic() ? valueRangeTy : valueTy; 2730 } 2731 2732 // Handle named results. 2733 const auto *it = llvm::find_if(results, [&](const auto &result) { 2734 return result.getName() == name; 2735 }); 2736 if (it != results.end()) 2737 return it->isVariadic() ? valueRangeTy : valueTy; 2738 } else if (llvm::isDigit(name[0])) { 2739 // Allow unchecked numeric indexing of the results of unregistered 2740 // operations. It returns a single value. 2741 return valueTy; 2742 } 2743 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2744 // Handle indexed results. 2745 unsigned index = 0; 2746 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2747 index < tupleType.size()) { 2748 return tupleType.getElementTypes()[index]; 2749 } 2750 2751 // Handle named results. 2752 auto elementNames = tupleType.getElementNames(); 2753 const auto *it = llvm::find(elementNames, name); 2754 if (it != elementNames.end()) 2755 return tupleType.getElementTypes()[it - elementNames.begin()]; 2756 } 2757 return emitError( 2758 loc, 2759 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2760 name, parentType)); 2761 } 2762 2763 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2764 SMRange loc, const ast::OpNameDecl *name, 2765 OpResultTypeContext resultTypeContext, 2766 MutableArrayRef<ast::Expr *> operands, 2767 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2768 MutableArrayRef<ast::Expr *> results) { 2769 Optional<StringRef> opNameRef = name->getName(); 2770 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2771 2772 // Verify the inputs operands. 2773 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2774 return failure(); 2775 2776 // Verify the attribute list. 2777 for (ast::NamedAttributeDecl *attr : attributes) { 2778 // Check for an attribute type, or a type awaiting resolution. 2779 ast::Type attrType = attr->getValue()->getType(); 2780 if (!attrType.isa<ast::AttributeType>()) { 2781 return emitError( 2782 attr->getValue()->getLoc(), 2783 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2784 } 2785 } 2786 2787 assert( 2788 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && 2789 "unexpected inferrence when results were explicitly specified"); 2790 2791 // If we aren't relying on type inferrence, or explicit results were provided, 2792 // validate them. 2793 if (resultTypeContext == OpResultTypeContext::Explicit) { 2794 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2795 return failure(); 2796 2797 // Validate the use of interface based type inferrence for this operation. 2798 } else if (resultTypeContext == OpResultTypeContext::Interface) { 2799 assert(opNameRef && 2800 "expected valid operation name when inferring operation results"); 2801 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp); 2802 } 2803 2804 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results, 2805 attributes); 2806 } 2807 2808 LogicalResult 2809 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2810 const ods::Operation *odsOp, 2811 MutableArrayRef<ast::Expr *> operands) { 2812 return validateOperationOperandsOrResults( 2813 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2814 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2815 valueRangeTy); 2816 } 2817 2818 LogicalResult 2819 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2820 const ods::Operation *odsOp, 2821 MutableArrayRef<ast::Expr *> results) { 2822 return validateOperationOperandsOrResults( 2823 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2824 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2825 } 2826 2827 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, 2828 const ods::Operation *odsOp) { 2829 // If the operation might not have inferrence support, emit a warning to the 2830 // user. We don't emit an error because the interface might be added to the 2831 // operation at runtime. It's rare, but it could still happen. We emit a 2832 // warning here instead. 2833 2834 // Handle inferrence warnings for unknown operations. 2835 if (!odsOp) { 2836 ctx.getDiagEngine().emitWarning( 2837 loc, llvm::formatv( 2838 "operation result types are marked to be inferred, but " 2839 "`{0}` is unknown. Ensure that `{0}` supports zero " 2840 "results or implements `InferTypeOpInterface`. Include " 2841 "the ODS definition of this operation to remove this warning.", 2842 opName)); 2843 return; 2844 } 2845 2846 // Handle inferrence warnings for known operations that expected at least one 2847 // result, but don't have inference support. An elided results list can mean 2848 // "zero-results", and we don't want to warn when that is the expected 2849 // behavior. 2850 bool requiresInferrence = 2851 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) { 2852 return !result.isVariableLength(); 2853 }); 2854 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { 2855 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( 2856 loc, 2857 llvm::formatv("operation result types are marked to be inferred, but " 2858 "`{0}` does not provide an implementation of " 2859 "`InferTypeOpInterface`. Ensure that `{0}` attaches " 2860 "`InferTypeOpInterface` at runtime, or add support to " 2861 "the ODS definition to remove this warning.", 2862 opName)); 2863 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName), 2864 odsOp->getLoc()); 2865 return; 2866 } 2867 } 2868 2869 LogicalResult Parser::validateOperationOperandsOrResults( 2870 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2871 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2872 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2873 ast::Type rangeTy) { 2874 // All operation types accept a single range parameter. 2875 if (values.size() == 1) { 2876 if (failed(convertExpressionTo(values[0], rangeTy))) 2877 return failure(); 2878 return success(); 2879 } 2880 2881 /// If the operation has ODS information, we can more accurately verify the 2882 /// values. 2883 if (odsOpLoc) { 2884 if (odsValues.size() != values.size()) { 2885 return emitErrorAndNote( 2886 loc, 2887 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2888 "{2}, but got {3}", 2889 groupName, *name, odsValues.size(), values.size()), 2890 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2891 } 2892 auto diagFn = [&](ast::Diagnostic &diag) { 2893 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2894 *odsOpLoc); 2895 }; 2896 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2897 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2898 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2899 return failure(); 2900 } 2901 return success(); 2902 } 2903 2904 // Otherwise, accept the value groups as they have been defined and just 2905 // ensure they are one of the expected types. 2906 for (ast::Expr *&valueExpr : values) { 2907 ast::Type valueExprType = valueExpr->getType(); 2908 2909 // Check if this is one of the expected types. 2910 if (valueExprType == rangeTy || valueExprType == singleTy) 2911 continue; 2912 2913 // If the operand is an Operation, allow converting to a Value or 2914 // ValueRange. This situations arises quite often with nested operation 2915 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2916 if (singleTy == valueTy) { 2917 if (valueExprType.isa<ast::OperationType>()) { 2918 valueExpr = convertOpToValue(valueExpr); 2919 continue; 2920 } 2921 } 2922 2923 return emitError( 2924 valueExpr->getLoc(), 2925 llvm::formatv( 2926 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2927 singleTy, rangeTy, valueExprType)); 2928 } 2929 return success(); 2930 } 2931 2932 FailureOr<ast::TupleExpr *> 2933 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2934 ArrayRef<StringRef> elementNames) { 2935 for (const ast::Expr *element : elements) { 2936 ast::Type eleTy = element->getType(); 2937 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2938 return emitError( 2939 element->getLoc(), 2940 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2941 } 2942 } 2943 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2944 } 2945 2946 //===----------------------------------------------------------------------===// 2947 // Stmts 2948 2949 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2950 ast::Expr *rootOp) { 2951 // Check that root is an Operation. 2952 ast::Type rootType = rootOp->getType(); 2953 if (!rootType.isa<ast::OperationType>()) 2954 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2955 2956 return ast::EraseStmt::create(ctx, loc, rootOp); 2957 } 2958 2959 FailureOr<ast::ReplaceStmt *> 2960 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2961 MutableArrayRef<ast::Expr *> replValues) { 2962 // Check that root is an Operation. 2963 ast::Type rootType = rootOp->getType(); 2964 if (!rootType.isa<ast::OperationType>()) { 2965 return emitError( 2966 rootOp->getLoc(), 2967 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2968 } 2969 2970 // If there are multiple replacement values, we implicitly convert any Op 2971 // expressions to the value form. 2972 bool shouldConvertOpToValues = replValues.size() > 1; 2973 for (ast::Expr *&replExpr : replValues) { 2974 ast::Type replType = replExpr->getType(); 2975 2976 // Check that replExpr is an Operation, Value, or ValueRange. 2977 if (replType.isa<ast::OperationType>()) { 2978 if (shouldConvertOpToValues) 2979 replExpr = convertOpToValue(replExpr); 2980 continue; 2981 } 2982 2983 if (replType != valueTy && replType != valueRangeTy) { 2984 return emitError(replExpr->getLoc(), 2985 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2986 "expression, but got `{0}`", 2987 replType)); 2988 } 2989 } 2990 2991 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2992 } 2993 2994 FailureOr<ast::RewriteStmt *> 2995 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2996 ast::CompoundStmt *rewriteBody) { 2997 // Check that root is an Operation. 2998 ast::Type rootType = rootOp->getType(); 2999 if (!rootType.isa<ast::OperationType>()) { 3000 return emitError( 3001 rootOp->getLoc(), 3002 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 3003 } 3004 3005 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 3006 } 3007 3008 //===----------------------------------------------------------------------===// 3009 // Code Completion 3010 //===----------------------------------------------------------------------===// 3011 3012 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 3013 ast::Type parentType = parentExpr->getType(); 3014 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 3015 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 3016 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 3017 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 3018 return failure(); 3019 } 3020 3021 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 3022 if (opName) 3023 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 3024 return failure(); 3025 } 3026 3027 LogicalResult 3028 Parser::codeCompleteConstraintName(ast::Type inferredType, 3029 bool allowNonCoreConstraints, 3030 bool allowInlineTypeConstraints) { 3031 codeCompleteContext->codeCompleteConstraintName( 3032 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 3033 curDeclScope); 3034 return failure(); 3035 } 3036 3037 LogicalResult Parser::codeCompleteDialectName() { 3038 codeCompleteContext->codeCompleteDialectName(); 3039 return failure(); 3040 } 3041 3042 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 3043 codeCompleteContext->codeCompleteOperationName(dialectName); 3044 return failure(); 3045 } 3046 3047 LogicalResult Parser::codeCompletePatternMetadata() { 3048 codeCompleteContext->codeCompletePatternMetadata(); 3049 return failure(); 3050 } 3051 3052 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 3053 codeCompleteContext->codeCompleteIncludeFilename(curPath); 3054 return failure(); 3055 } 3056 3057 void Parser::codeCompleteCallSignature(ast::Node *parent, 3058 unsigned currentNumArgs) { 3059 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 3060 if (!callableDecl) 3061 return; 3062 3063 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 3064 } 3065 3066 void Parser::codeCompleteOperationOperandsSignature( 3067 Optional<StringRef> opName, unsigned currentNumOperands) { 3068 codeCompleteContext->codeCompleteOperationOperandsSignature( 3069 opName, currentNumOperands); 3070 } 3071 3072 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 3073 unsigned currentNumResults) { 3074 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 3075 currentNumResults); 3076 } 3077 3078 //===----------------------------------------------------------------------===// 3079 // Parser 3080 //===----------------------------------------------------------------------===// 3081 3082 FailureOr<ast::Module *> 3083 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 3084 CodeCompleteContext *codeCompleteContext) { 3085 Parser parser(ctx, sourceMgr, codeCompleteContext); 3086 return parser.parseModule(); 3087 } 3088