1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 
typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 /// The current specification context of an operations result type. This 80 /// indicates how the result types of an operation may be inferred. 81 enum class OpResultTypeContext { 82 /// The result types of the operation are not known to be inferred. 83 Explicit, 84 /// The result types of the operation are inferred from the root input of a 85 /// `replace` statement. 86 Replacement, 87 /// The result types of the operation are inferred by using the 88 /// `InferTypeOpInterface` interface provided by the operation. 
89 Interface, 90 }; 91 92 //===--------------------------------------------------------------------===// 93 // Parsing 94 //===--------------------------------------------------------------------===// 95 96 /// Push a new decl scope onto the lexer. 97 ast::DeclScope *pushDeclScope() { 98 ast::DeclScope *newScope = 99 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 100 return (curDeclScope = newScope); 101 } 102 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 103 104 /// Pop the last decl scope from the lexer. 105 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 106 107 /// Parse the body of an AST module. 108 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 109 110 /// Try to convert the given expression to `type`. Returns failure and emits 111 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 112 /// invoked to attach notes to the emitted error diagnostic. On success, 113 /// `expr` is updated to the expression used to convert to `type`. 114 LogicalResult convertExpressionTo( 115 ast::Expr *&expr, ast::Type type, 116 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 117 118 /// Given an operation expression, convert it to a Value or ValueRange 119 /// typed expression. 120 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 121 122 /// Lookup ODS information for the given operation, returns nullptr if no 123 /// information is found. 124 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 125 return opName ? 
ctx.getODSContext().lookupOperation(*opName) : nullptr; 126 } 127 128 //===--------------------------------------------------------------------===// 129 // Directives 130 131 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 132 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 133 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 134 SmallVectorImpl<ast::Decl *> &decls); 135 136 /// Process the records of a parsed tablegen include file. 137 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 138 SmallVectorImpl<ast::Decl *> &decls); 139 140 /// Create a user defined native constraint for a constraint imported from 141 /// ODS. 142 template <typename ConstraintT> 143 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 144 StringRef codeBlock, SMRange loc, 145 ast::Type type); 146 template <typename ConstraintT> 147 ast::Decl * 148 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 149 SMRange loc, ast::Type type); 150 151 //===--------------------------------------------------------------------===// 152 // Decls 153 154 /// This structure contains the set of pattern metadata that may be parsed. 155 struct ParsedPatternMetadata { 156 Optional<uint16_t> benefit; 157 bool hasBoundedRecursion = false; 158 }; 159 160 FailureOr<ast::Decl *> parseTopLevelDecl(); 161 FailureOr<ast::NamedAttributeDecl *> 162 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 163 164 /// Parse an argument variable as part of the signature of a 165 /// UserConstraintDecl or UserRewriteDecl. 166 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 167 168 /// Parse a result variable as part of the signature of a UserConstraintDecl 169 /// or UserRewriteDecl. 170 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 171 172 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 173 /// defined in a non-global context. 
174 FailureOr<ast::UserConstraintDecl *> 175 parseUserConstraintDecl(bool isInline = false); 176 177 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 178 /// non-global context, such as within a Pattern/Constraint/etc. 179 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 180 181 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 182 /// PDLL constructs. 183 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 184 const ast::Name &name, bool isInline, 185 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 186 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 187 188 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 189 /// defined in a non-global context. 190 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 191 192 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 193 /// non-global context, such as within a Pattern/Rewrite/etc. 194 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 195 196 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 197 /// PDLL constructs. 198 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 199 const ast::Name &name, bool isInline, 200 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 201 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 202 203 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 204 /// effectively the same syntax, and only differ on slight semantics (given 205 /// the different parsing contexts). 206 template <typename T, typename ParseUserPDLLDeclFnT> 207 FailureOr<T *> parseUserConstraintOrRewriteDecl( 208 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 209 StringRef anonymousNamePrefix, bool isInline); 210 211 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 
212 /// These decls have effectively the same syntax. 213 template <typename T> 214 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 215 const ast::Name &name, bool isInline, 216 ArrayRef<ast::VariableDecl *> arguments, 217 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 218 219 /// Parse the functional signature (i.e. the arguments and results) of a 220 /// UserConstraintDecl or UserRewriteDecl. 221 LogicalResult parseUserConstraintOrRewriteSignature( 222 SmallVectorImpl<ast::VariableDecl *> &arguments, 223 SmallVectorImpl<ast::VariableDecl *> &results, 224 ast::DeclScope *&argumentScope, ast::Type &resultType); 225 226 /// Validate the return (which if present is specified by bodyIt) of a 227 /// UserConstraintDecl or UserRewriteDecl. 228 LogicalResult validateUserConstraintOrRewriteReturn( 229 StringRef declType, ast::CompoundStmt *body, 230 ArrayRef<ast::Stmt *>::iterator bodyIt, 231 ArrayRef<ast::Stmt *>::iterator bodyE, 232 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 233 234 FailureOr<ast::CompoundStmt *> 235 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 236 bool expectTerminalSemicolon = true); 237 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 238 FailureOr<ast::Decl *> parsePatternDecl(); 239 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 240 241 /// Check to see if a decl has already been defined with the given name, if 242 /// one has emit and error and return failure. Returns success otherwise. 243 LogicalResult checkDefineNamedDecl(const ast::Name &name); 244 245 /// Try to define a variable decl with the given components, returns the 246 /// variable on success. 
247 FailureOr<ast::VariableDecl *> 248 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 249 ast::Expr *initExpr, 250 ArrayRef<ast::ConstraintRef> constraints); 251 FailureOr<ast::VariableDecl *> 252 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 253 ArrayRef<ast::ConstraintRef> constraints); 254 255 /// Parse the constraint reference list for a variable decl. 256 LogicalResult parseVariableDeclConstraintList( 257 SmallVectorImpl<ast::ConstraintRef> &constraints); 258 259 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 260 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 261 262 /// Try to parse a single reference to a constraint. `typeConstraint` is the 263 /// location of a previously parsed type constraint for the entity that will 264 /// be constrained by the parsed constraint. `existingConstraints` are any 265 /// existing constraints that have already been parsed for the same entity 266 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 267 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 268 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 269 /// constraints) may be used with the variable. 270 FailureOr<ast::ConstraintRef> 271 parseConstraint(Optional<SMRange> &typeConstraint, 272 ArrayRef<ast::ConstraintRef> existingConstraints, 273 bool allowInlineTypeConstraints, 274 bool allowNonCoreConstraints); 275 276 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 277 /// argument or result variable. The constraints for these variables do not 278 /// allow inline type constraints, and only permit a single constraint. 279 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 280 281 //===--------------------------------------------------------------------===// 282 // Exprs 283 284 FailureOr<ast::Expr *> parseExpr(); 285 286 /// Identifier expressions. 
287 FailureOr<ast::Expr *> parseAttributeExpr(); 288 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 289 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 290 FailureOr<ast::Expr *> parseIdentifierExpr(); 291 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 292 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 293 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 294 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 295 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 296 FailureOr<ast::Expr *> 297 parseOperationExpr(OpResultTypeContext inputResultTypeContext = 298 OpResultTypeContext::Explicit); 299 FailureOr<ast::Expr *> parseTupleExpr(); 300 FailureOr<ast::Expr *> parseTypeExpr(); 301 FailureOr<ast::Expr *> parseUnderscoreExpr(); 302 303 //===--------------------------------------------------------------------===// 304 // Stmts 305 306 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 307 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 308 FailureOr<ast::EraseStmt *> parseEraseStmt(); 309 FailureOr<ast::LetStmt *> parseLetStmt(); 310 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 311 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 312 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 313 314 //===--------------------------------------------------------------------===// 315 // Creation+Analysis 316 //===--------------------------------------------------------------------===// 317 318 //===--------------------------------------------------------------------===// 319 // Decls 320 321 /// Try to extract a callable from the given AST node. Returns nullptr on 322 /// failure. 323 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 324 325 /// Try to create a pattern decl with the given components, returning the 326 /// Pattern on success. 
327 FailureOr<ast::PatternDecl *> 328 createPatternDecl(SMRange loc, const ast::Name *name, 329 const ParsedPatternMetadata &metadata, 330 ast::CompoundStmt *body); 331 332 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 333 /// of results, defined as part of the signature. 334 ast::Type 335 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 336 337 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 338 template <typename T> 339 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 340 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 341 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 342 ast::CompoundStmt *body); 343 344 /// Try to create a variable decl with the given components, returning the 345 /// Variable on success. 346 FailureOr<ast::VariableDecl *> 347 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 348 ArrayRef<ast::ConstraintRef> constraints); 349 350 /// Create a variable for an argument or result defined as part of the 351 /// signature of a UserConstraintDecl/UserRewriteDecl. 352 FailureOr<ast::VariableDecl *> 353 createArgOrResultVariableDecl(StringRef name, SMRange loc, 354 const ast::ConstraintRef &constraint); 355 356 /// Validate the constraints used to constraint a variable decl. 357 /// `inferredType` is the type of the variable inferred by the constraints 358 /// within the list, and is updated to the most refined type as determined by 359 /// the constraints. Returns success if the constraint list is valid, failure 360 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 361 /// defined constraints) may be used with the variable. 362 LogicalResult 363 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 364 ast::Type &inferredType, 365 bool allowNonCoreConstraints = true); 366 /// Validate a single reference to a constraint. 
`inferredType` contains the 367 /// currently inferred variabled type and is refined within the type defined 368 /// by the constraint. Returns success if the constraint is valid, failure 369 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 370 /// defined constraints) may be used with the variable. 371 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 372 ast::Type &inferredType, 373 bool allowNonCoreConstraints = true); 374 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 375 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 376 377 //===--------------------------------------------------------------------===// 378 // Exprs 379 380 FailureOr<ast::CallExpr *> 381 createCallExpr(SMRange loc, ast::Expr *parentExpr, 382 MutableArrayRef<ast::Expr *> arguments); 383 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 384 FailureOr<ast::DeclRefExpr *> 385 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 386 ArrayRef<ast::ConstraintRef> constraints); 387 FailureOr<ast::MemberAccessExpr *> 388 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 389 390 /// Validate the member access `name` into the given parent expression. On 391 /// success, this also returns the type of the member accessed. 
392 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 393 StringRef name, SMRange loc); 394 FailureOr<ast::OperationExpr *> 395 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 396 OpResultTypeContext resultTypeContext, 397 MutableArrayRef<ast::Expr *> operands, 398 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 399 MutableArrayRef<ast::Expr *> results); 400 LogicalResult 401 validateOperationOperands(SMRange loc, Optional<StringRef> name, 402 const ods::Operation *odsOp, 403 MutableArrayRef<ast::Expr *> operands); 404 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 405 const ods::Operation *odsOp, 406 MutableArrayRef<ast::Expr *> results); 407 void checkOperationResultTypeInferrence(SMRange loc, StringRef name, 408 const ods::Operation *odsOp); 409 LogicalResult validateOperationOperandsOrResults( 410 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 411 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 412 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 413 ast::Type rangeTy); 414 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 415 ArrayRef<ast::Expr *> elements, 416 ArrayRef<StringRef> elementNames); 417 418 //===--------------------------------------------------------------------===// 419 // Stmts 420 421 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 422 FailureOr<ast::ReplaceStmt *> 423 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 424 MutableArrayRef<ast::Expr *> replValues); 425 FailureOr<ast::RewriteStmt *> 426 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 427 ast::CompoundStmt *rewriteBody); 428 429 //===--------------------------------------------------------------------===// 430 // Code Completion 431 //===--------------------------------------------------------------------===// 432 433 /// The set of various code completion methods. 
Every completion method 434 /// returns `failure` to stop the parsing process after providing completion 435 /// results. 436 437 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 438 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 439 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 440 bool allowNonCoreConstraints, 441 bool allowInlineTypeConstraints); 442 LogicalResult codeCompleteDialectName(); 443 LogicalResult codeCompleteOperationName(StringRef dialectName); 444 LogicalResult codeCompletePatternMetadata(); 445 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 446 447 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 448 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 449 unsigned currentNumOperands); 450 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 451 unsigned currentNumResults); 452 453 //===--------------------------------------------------------------------===// 454 // Lexer Utilities 455 //===--------------------------------------------------------------------===// 456 457 /// If the current token has the specified kind, consume it and return true. 458 /// If not, return false. 459 bool consumeIf(Token::Kind kind) { 460 if (curToken.isNot(kind)) 461 return false; 462 consumeToken(kind); 463 return true; 464 } 465 466 /// Advance the current lexer onto the next token. 467 void consumeToken() { 468 assert(curToken.isNot(Token::eof, Token::error) && 469 "shouldn't advance past EOF or errors"); 470 curToken = lexer.lexToken(); 471 } 472 473 /// Advance the current lexer onto the next token, asserting what the expected 474 /// current token is. This is preferred to the above method because it leads 475 /// to more self-documenting code with better checking. 
476 void consumeToken(Token::Kind kind) { 477 assert(curToken.is(kind) && "consumed an unexpected token"); 478 consumeToken(); 479 } 480 481 /// Reset the lexer to the location at the given position. 482 void resetToken(SMRange tokLoc) { 483 lexer.resetPointer(tokLoc.Start.getPointer()); 484 curToken = lexer.lexToken(); 485 } 486 487 /// Consume the specified token if present and return success. On failure, 488 /// output a diagnostic and return failure. 489 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 490 if (curToken.getKind() != kind) 491 return emitError(curToken.getLoc(), msg); 492 consumeToken(); 493 return success(); 494 } 495 LogicalResult emitError(SMRange loc, const Twine &msg) { 496 lexer.emitError(loc, msg); 497 return failure(); 498 } 499 LogicalResult emitError(const Twine &msg) { 500 return emitError(curToken.getLoc(), msg); 501 } 502 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 503 const Twine ¬e) { 504 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 505 return failure(); 506 } 507 508 //===--------------------------------------------------------------------===// 509 // Fields 510 //===--------------------------------------------------------------------===// 511 512 /// The owning AST context. 513 ast::Context &ctx; 514 515 /// The lexer of this parser. 516 Lexer lexer; 517 518 /// The current token within the lexer. 519 Token curToken; 520 521 /// The most recently defined decl scope. 522 ast::DeclScope *curDeclScope = nullptr; 523 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 524 525 /// The current context of the parser. 526 ParserContext parserContext = ParserContext::Global; 527 528 /// Cached types to simplify verification and expression creation. 529 ast::Type valueTy, valueRangeTy; 530 ast::Type typeTy, typeRangeTy; 531 ast::Type attrTy; 532 533 /// A counter used when naming anonymous constraints and rewrites. 
534 unsigned anonymousDeclNameCounter = 0; 535 536 /// The optional code completion context. 537 CodeCompleteContext *codeCompleteContext; 538 }; 539 } // namespace 540 541 FailureOr<ast::Module *> Parser::parseModule() { 542 SMLoc moduleLoc = curToken.getStartLoc(); 543 pushDeclScope(); 544 545 // Parse the top-level decls of the module. 546 SmallVector<ast::Decl *> decls; 547 if (failed(parseModuleBody(decls))) 548 return popDeclScope(), failure(); 549 550 popDeclScope(); 551 return ast::Module::create(ctx, moduleLoc, decls); 552 } 553 554 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 555 while (curToken.isNot(Token::eof)) { 556 if (curToken.is(Token::directive)) { 557 if (failed(parseDirective(decls))) 558 return failure(); 559 continue; 560 } 561 562 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 563 if (failed(decl)) 564 return failure(); 565 decls.push_back(*decl); 566 } 567 return success(); 568 } 569 570 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 571 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 572 valueRangeTy); 573 } 574 575 LogicalResult Parser::convertExpressionTo( 576 ast::Expr *&expr, ast::Type type, 577 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 578 ast::Type exprType = expr->getType(); 579 if (exprType == type) 580 return success(); 581 582 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 583 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 584 expr->getLoc(), llvm::formatv("unable to convert expression of type " 585 "`{0}` to the expected type of " 586 "`{1}`", 587 exprType, type)); 588 if (noteAttachFn) 589 noteAttachFn(*diag); 590 return diag; 591 }; 592 593 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 594 // Two operation types are compatible if they have the same name, or if the 595 // expected type is more general. 
596 if (auto opType = type.dyn_cast<ast::OperationType>()) { 597 if (opType.getName()) 598 return emitConvertError(); 599 return success(); 600 } 601 602 // An operation can always convert to a ValueRange. 603 if (type == valueRangeTy) { 604 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 605 valueRangeTy); 606 return success(); 607 } 608 609 // Allow conversion to a single value by constraining the result range. 610 if (type == valueTy) { 611 // If the operation is registered, we can verify if it can ever have a 612 // single result. 613 Optional<StringRef> opName = exprOpType.getName(); 614 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 615 if (odsOp->getResults().empty()) { 616 return emitConvertError()->attachNote( 617 llvm::formatv("see the definition of `{0}`, which was defined " 618 "with zero results", 619 odsOp->getName()), 620 odsOp->getLoc()); 621 } 622 623 unsigned numSingleResults = llvm::count_if( 624 odsOp->getResults(), [](const ods::OperandOrResult &result) { 625 return result.getVariableLengthKind() == 626 ods::VariableLengthKind::Single; 627 }); 628 if (numSingleResults > 1) { 629 return emitConvertError()->attachNote( 630 llvm::formatv("see the definition of `{0}`, which was defined " 631 "with at least {1} results", 632 odsOp->getName(), numSingleResults), 633 odsOp->getLoc()); 634 } 635 } 636 637 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 638 valueTy); 639 return success(); 640 } 641 return emitConvertError(); 642 } 643 644 // FIXME: Decide how to allow/support converting a single result to multiple, 645 // and multiple to a single result. For now, we just allow Single->Range, 646 // but this isn't something really supported in the PDL dialect. We should 647 // figure out some way to support both. 
648 if ((exprType == valueTy || exprType == valueRangeTy) && 649 (type == valueTy || type == valueRangeTy)) 650 return success(); 651 if ((exprType == typeTy || exprType == typeRangeTy) && 652 (type == typeTy || type == typeRangeTy)) 653 return success(); 654 655 // Handle tuple types. 656 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 657 auto tupleType = type.dyn_cast<ast::TupleType>(); 658 if (!tupleType || tupleType.size() != exprTupleType.size()) 659 return emitConvertError(); 660 661 // Build a new tuple expression using each of the elements of the current 662 // tuple. 663 SmallVector<ast::Expr *> newExprs; 664 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 665 newExprs.push_back(ast::MemberAccessExpr::create( 666 ctx, expr->getLoc(), expr, llvm::to_string(i), 667 exprTupleType.getElementTypes()[i])); 668 669 auto diagFn = [&](ast::Diagnostic &diag) { 670 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 671 i, exprTupleType)); 672 if (noteAttachFn) 673 noteAttachFn(diag); 674 }; 675 if (failed(convertExpressionTo(newExprs.back(), 676 tupleType.getElementTypes()[i], diagFn))) 677 return failure(); 678 } 679 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 680 tupleType.getElementNames()); 681 return success(); 682 } 683 684 return emitConvertError(); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // Directives 689 690 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 691 StringRef directive = curToken.getSpelling(); 692 if (directive == "#include") 693 return parseInclude(decls); 694 695 return emitError("unknown directive `" + directive + "`"); 696 } 697 698 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 699 SMRange loc = curToken.getLoc(); 700 consumeToken(Token::directive); 701 702 // Handle code completion of the include file path. 
703 if (curToken.is(Token::code_complete_string)) 704 return codeCompleteIncludeFilename(curToken.getStringValue()); 705 706 // Parse the file being included. 707 if (!curToken.isString()) 708 return emitError(loc, 709 "expected string file name after `include` directive"); 710 SMRange fileLoc = curToken.getLoc(); 711 std::string filenameStr = curToken.getStringValue(); 712 StringRef filename = filenameStr; 713 consumeToken(); 714 715 // Check the type of include. If ending with `.pdll`, this is another pdl file 716 // to be parsed along with the current module. 717 if (filename.endswith(".pdll")) { 718 if (failed(lexer.pushInclude(filename, fileLoc))) 719 return emitError(fileLoc, 720 "unable to open include file `" + filename + "`"); 721 722 // If we added the include successfully, parse it into the current module. 723 // Make sure to update to the next token after we finish parsing the nested 724 // file. 725 curToken = lexer.lexToken(); 726 LogicalResult result = parseModuleBody(decls); 727 curToken = lexer.lexToken(); 728 return result; 729 } 730 731 // Otherwise, this must be a `.td` include. 732 if (filename.endswith(".td")) 733 return parseTdInclude(filename, fileLoc, decls); 734 735 return emitError(fileLoc, 736 "expected include filename to end with `.pdll` or `.td`"); 737 } 738 739 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 740 SmallVectorImpl<ast::Decl *> &decls) { 741 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 742 743 // Use the source manager to open the file, but don't yet add it. 744 std::string includedFile; 745 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 746 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 747 if (!includeBuffer) 748 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 749 750 // Setup the source manager for parsing the tablegen file. 
751 llvm::SourceMgr tdSrcMgr; 752 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); 753 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); 754 755 // This class provides a context argument for the llvm::SourceMgr diagnostic 756 // handler. 757 struct DiagHandlerContext { 758 Parser &parser; 759 StringRef filename; 760 llvm::SMRange loc; 761 } handlerContext{*this, filename, fileLoc}; 762 763 // Set the diagnostic handler for the tablegen source manager. 764 tdSrcMgr.setDiagHandler( 765 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 766 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 767 (void)ctx->parser.emitError( 768 ctx->loc, 769 llvm::formatv("error while processing include file `{0}`: {1}", 770 ctx->filename, diag.getMessage())); 771 }, 772 &handlerContext); 773 774 // Parse the tablegen file. 775 llvm::RecordKeeper tdRecords; 776 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords)) 777 return failure(); 778 779 // Process the parsed records. 780 processTdIncludeRecords(tdRecords, decls); 781 782 // After we are done processing, move all of the tablegen source buffers to 783 // the main parser source mgr. This allows for directly using source locations 784 // from the .td files without needing to remap them. 785 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End); 786 return success(); 787 } 788 789 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 790 SmallVectorImpl<ast::Decl *> &decls) { 791 // Return the length kind of the given value. 792 auto getLengthKind = [](const auto &value) { 793 if (value.isOptional()) 794 return ods::VariableLengthKind::Optional; 795 return value.isVariadic() ? ods::VariableLengthKind::Variadic 796 : ods::VariableLengthKind::Single; 797 }; 798 799 // Insert a type constraint into the ODS context. 
800 ods::Context &odsContext = ctx.getODSContext(); 801 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 802 -> const ods::TypeConstraint & { 803 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 804 cst.constraint.getSummary(), 805 cst.constraint.getCPPClassName()); 806 }; 807 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 808 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 809 }; 810 811 // Process the parsed tablegen records to build ODS information. 812 /// Operations. 813 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 814 tblgen::Operator op(def); 815 816 // Check to see if this operation is known to support type inferrence. 817 bool supportsResultTypeInferrence = 818 op.getTrait("::mlir::InferTypeOpInterface::Trait"); 819 820 bool inserted = false; 821 ods::Operation *odsOp = nullptr; 822 std::tie(odsOp, inserted) = odsContext.insertOperation( 823 op.getOperationName(), op.getSummary(), op.getDescription(), 824 supportsResultTypeInferrence, op.getLoc().front()); 825 826 // Ignore operations that have already been added. 827 if (!inserted) 828 continue; 829 830 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 831 odsOp->appendAttribute( 832 attr.name, attr.attr.isOptional(), 833 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 834 attr.attr.getSummary(), 835 attr.attr.getStorageType())); 836 } 837 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 838 odsOp->appendOperand(operand.name, getLengthKind(operand), 839 addTypeConstraint(operand)); 840 } 841 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 842 odsOp->appendResult(result.name, getLengthKind(result), 843 addTypeConstraint(result)); 844 } 845 } 846 /// Attr constraints. 
847 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 848 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 849 decls.push_back( 850 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 851 tblgen::AttrConstraint(def), 852 convertLocToRange(def->getLoc().front()), attrTy)); 853 } 854 } 855 /// Type constraints. 856 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 857 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 858 decls.push_back( 859 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 860 tblgen::TypeConstraint(def), 861 convertLocToRange(def->getLoc().front()), typeTy)); 862 } 863 } 864 /// Interfaces. 865 ast::Type opTy = ast::OperationType::get(ctx); 866 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 867 StringRef name = def->getName(); 868 if (def->isAnonymous() || curDeclScope->lookup(name) || 869 def->isSubClassOf("DeclareInterfaceMethods")) 870 continue; 871 SMRange loc = convertLocToRange(def->getLoc().front()); 872 873 StringRef className = def->getValueAsString("cppClassName"); 874 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 875 std::string codeBlock = 876 llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));", 877 cppNamespace, className) 878 .str(); 879 880 if (def->isSubClassOf("OpInterface")) { 881 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 882 name, codeBlock, loc, opTy)); 883 } else if (def->isSubClassOf("AttrInterface")) { 884 decls.push_back( 885 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 886 name, codeBlock, loc, attrTy)); 887 } else if (def->isSubClassOf("TypeInterface")) { 888 decls.push_back( 889 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 890 name, codeBlock, loc, typeTy)); 891 } 892 } 893 } 894 895 template <typename ConstraintT> 896 ast::Decl * 897 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef 
codeBlock, 898 SMRange loc, ast::Type type) { 899 // Build the single input parameter. 900 ast::DeclScope *argScope = pushDeclScope(); 901 auto *paramVar = ast::VariableDecl::create( 902 ctx, ast::Name::create(ctx, "self", loc), type, 903 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 904 argScope->add(paramVar); 905 popDeclScope(); 906 907 // Build the native constraint. 908 auto *constraintDecl = ast::UserConstraintDecl::createNative( 909 ctx, ast::Name::create(ctx, name, loc), paramVar, 910 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 911 curDeclScope->add(constraintDecl); 912 return constraintDecl; 913 } 914 915 template <typename ConstraintT> 916 ast::Decl * 917 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 918 SMRange loc, ast::Type type) { 919 // Format the condition template. 920 tblgen::FmtContext fmtContext; 921 fmtContext.withSelf("self"); 922 std::string codeBlock = tblgen::tgfmt( 923 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 924 &fmtContext); 925 926 return createODSNativePDLLConstraintDecl<ConstraintT>( 927 constraint.getUniqueDefName(), codeBlock, loc, type); 928 } 929 930 //===----------------------------------------------------------------------===// 931 // Decls 932 933 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 934 FailureOr<ast::Decl *> decl; 935 switch (curToken.getKind()) { 936 case Token::kw_Constraint: 937 decl = parseUserConstraintDecl(); 938 break; 939 case Token::kw_Pattern: 940 decl = parsePatternDecl(); 941 break; 942 case Token::kw_Rewrite: 943 decl = parseUserRewriteDecl(); 944 break; 945 default: 946 return emitError("expected top-level declaration, such as a `Pattern`"); 947 } 948 if (failed(decl)) 949 return failure(); 950 951 // If the decl has a name, add it to the current scope. 
952 if (const ast::Name *name = (*decl)->getName()) { 953 if (failed(checkDefineNamedDecl(*name))) 954 return failure(); 955 curDeclScope->add(*decl); 956 } 957 return decl; 958 } 959 960 FailureOr<ast::NamedAttributeDecl *> 961 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 962 // Check for name code completion. 963 if (curToken.is(Token::code_complete)) 964 return codeCompleteAttributeName(parentOpName); 965 966 std::string attrNameStr; 967 if (curToken.isString()) 968 attrNameStr = curToken.getStringValue(); 969 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 970 attrNameStr = curToken.getSpelling().str(); 971 else 972 return emitError("expected identifier or string attribute name"); 973 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 974 consumeToken(); 975 976 // Check for a value of the attribute. 977 ast::Expr *attrValue = nullptr; 978 if (consumeIf(Token::equal)) { 979 FailureOr<ast::Expr *> attrExpr = parseExpr(); 980 if (failed(attrExpr)) 981 return failure(); 982 attrValue = *attrExpr; 983 } else { 984 // If there isn't a concrete value, create an expression representing a 985 // UnitAttr. 986 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 987 } 988 989 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 990 } 991 992 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 993 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 994 bool expectTerminalSemicolon) { 995 consumeToken(Token::equal_arrow); 996 997 // Parse the single statement of the lambda body. 
998 SMLoc bodyStartLoc = curToken.getStartLoc(); 999 pushDeclScope(); 1000 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 1001 bool failedToParse = 1002 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 1003 popDeclScope(); 1004 if (failedToParse) 1005 return failure(); 1006 1007 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 1008 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 1009 } 1010 1011 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 1012 // Ensure that the argument is named. 1013 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 1014 return emitError("expected identifier argument name"); 1015 1016 // Parse the argument similarly to a normal variable. 1017 StringRef name = curToken.getSpelling(); 1018 SMRange nameLoc = curToken.getLoc(); 1019 consumeToken(); 1020 1021 if (failed( 1022 parseToken(Token::colon, "expected `:` before argument constraint"))) 1023 return failure(); 1024 1025 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1026 if (failed(cst)) 1027 return failure(); 1028 1029 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1030 } 1031 1032 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1033 // Check to see if this result is named. 1034 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1035 // Check to see if this name actually refers to a Constraint. 1036 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1037 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1038 // If yes, and this is a Rewrite, give a nice error message as non-Core 1039 // constraints are not supported on Rewrite results. 
1040 if (parserContext == ParserContext::Rewrite) { 1041 return emitError( 1042 "`Rewrite` results are only permitted to use core constraints, " 1043 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1044 } 1045 1046 // Otherwise, parse this as an unnamed result variable. 1047 } else { 1048 // If it wasn't a constraint, parse the result similarly to a variable. If 1049 // there is already an existing decl, we will emit an error when defining 1050 // this variable later. 1051 StringRef name = curToken.getSpelling(); 1052 SMRange nameLoc = curToken.getLoc(); 1053 consumeToken(); 1054 1055 if (failed(parseToken(Token::colon, 1056 "expected `:` before result constraint"))) 1057 return failure(); 1058 1059 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1060 if (failed(cst)) 1061 return failure(); 1062 1063 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1064 } 1065 } 1066 1067 // If it isn't named, we parse the constraint directly and create an unnamed 1068 // result variable. 1069 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1070 if (failed(cst)) 1071 return failure(); 1072 1073 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1074 } 1075 1076 FailureOr<ast::UserConstraintDecl *> 1077 Parser::parseUserConstraintDecl(bool isInline) { 1078 // Constraints and rewrites have very similar formats, dispatch to a shared 1079 // interface for parsing. 
1080 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1081 [&](auto &&...args) { 1082 return this->parseUserPDLLConstraintDecl(args...); 1083 }, 1084 ParserContext::Constraint, "constraint", isInline); 1085 } 1086 1087 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1088 FailureOr<ast::UserConstraintDecl *> decl = 1089 parseUserConstraintDecl(/*isInline=*/true); 1090 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1091 return failure(); 1092 1093 curDeclScope->add(*decl); 1094 return decl; 1095 } 1096 1097 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1098 const ast::Name &name, bool isInline, 1099 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1100 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1101 // Push the argument scope back onto the list, so that the body can 1102 // reference arguments. 1103 pushDeclScope(argumentScope); 1104 1105 // Parse the body of the constraint. The body is either defined as a compound 1106 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1107 ast::CompoundStmt *body; 1108 if (curToken.is(Token::equal_arrow)) { 1109 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1110 [&](ast::Stmt *&stmt) -> LogicalResult { 1111 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1112 if (!stmtExpr) { 1113 return emitError(stmt->getLoc(), 1114 "expected `Constraint` lambda body to contain a " 1115 "single expression"); 1116 } 1117 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1118 return success(); 1119 }, 1120 /*expectTerminalSemicolon=*/!isInline); 1121 if (failed(bodyResult)) 1122 return failure(); 1123 body = *bodyResult; 1124 } else { 1125 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1126 if (failed(bodyResult)) 1127 return failure(); 1128 body = *bodyResult; 1129 1130 // Verify the structure of the body. 
1131 auto bodyIt = body->begin(), bodyE = body->end(); 1132 for (; bodyIt != bodyE; ++bodyIt) 1133 if (isa<ast::ReturnStmt>(*bodyIt)) 1134 break; 1135 if (failed(validateUserConstraintOrRewriteReturn( 1136 "Constraint", body, bodyIt, bodyE, results, resultType))) 1137 return failure(); 1138 } 1139 popDeclScope(); 1140 1141 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1142 name, arguments, results, resultType, body); 1143 } 1144 1145 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1146 // Constraints and rewrites have very similar formats, dispatch to a shared 1147 // interface for parsing. 1148 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1149 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1150 ParserContext::Rewrite, "rewrite", isInline); 1151 } 1152 1153 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1154 FailureOr<ast::UserRewriteDecl *> decl = 1155 parseUserRewriteDecl(/*isInline=*/true); 1156 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1157 return failure(); 1158 1159 curDeclScope->add(*decl); 1160 return decl; 1161 } 1162 1163 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1164 const ast::Name &name, bool isInline, 1165 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1166 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1167 // Push the argument scope back onto the list, so that the body can 1168 // reference arguments. 
1169 curDeclScope = argumentScope; 1170 ast::CompoundStmt *body; 1171 if (curToken.is(Token::equal_arrow)) { 1172 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1173 [&](ast::Stmt *&statement) -> LogicalResult { 1174 if (isa<ast::OpRewriteStmt>(statement)) 1175 return success(); 1176 1177 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1178 if (!statementExpr) { 1179 return emitError( 1180 statement->getLoc(), 1181 "expected `Rewrite` lambda body to contain a single expression " 1182 "or an operation rewrite statement; such as `erase`, " 1183 "`replace`, or `rewrite`"); 1184 } 1185 statement = 1186 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1187 return success(); 1188 }, 1189 /*expectTerminalSemicolon=*/!isInline); 1190 if (failed(bodyResult)) 1191 return failure(); 1192 body = *bodyResult; 1193 } else { 1194 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1195 if (failed(bodyResult)) 1196 return failure(); 1197 body = *bodyResult; 1198 } 1199 popDeclScope(); 1200 1201 // Verify the structure of the body. 1202 auto bodyIt = body->begin(), bodyE = body->end(); 1203 for (; bodyIt != bodyE; ++bodyIt) 1204 if (isa<ast::ReturnStmt>(*bodyIt)) 1205 break; 1206 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1207 bodyE, results, resultType))) 1208 return failure(); 1209 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1210 name, arguments, results, resultType, body); 1211 } 1212 1213 template <typename T, typename ParseUserPDLLDeclFnT> 1214 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1215 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1216 StringRef anonymousNamePrefix, bool isInline) { 1217 SMRange loc = curToken.getLoc(); 1218 consumeToken(); 1219 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1220 1221 // Parse the name of the decl. 
1222 const ast::Name *name = nullptr; 1223 if (curToken.isNot(Token::identifier)) { 1224 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1225 // in C++, so being unnamed is fine. 1226 if (!isInline) 1227 return emitError("expected identifier name"); 1228 1229 // Create a unique anonymous name to use, as the name for this decl is not 1230 // important. 1231 std::string anonName = 1232 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1233 anonymousDeclNameCounter++) 1234 .str(); 1235 name = &ast::Name::create(ctx, anonName, loc); 1236 } else { 1237 // If a name was provided, we can use it directly. 1238 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1239 consumeToken(Token::identifier); 1240 } 1241 1242 // Parse the functional signature of the decl. 1243 SmallVector<ast::VariableDecl *> arguments, results; 1244 ast::DeclScope *argumentScope; 1245 ast::Type resultType; 1246 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1247 argumentScope, resultType))) 1248 return failure(); 1249 1250 // Check to see which type of constraint this is. If the constraint contains a 1251 // compound body, this is a PDLL decl. 1252 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1253 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1254 resultType); 1255 1256 // Otherwise, this is a native decl. 1257 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1258 results, resultType); 1259 } 1260 1261 template <typename T> 1262 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1263 const ast::Name &name, bool isInline, 1264 ArrayRef<ast::VariableDecl *> arguments, 1265 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1266 // If followed by a string, the native code body has also been specified. 
1267 std::string codeStrStorage; 1268 Optional<StringRef> optCodeStr; 1269 if (curToken.isString()) { 1270 codeStrStorage = curToken.getStringValue(); 1271 optCodeStr = codeStrStorage; 1272 consumeToken(); 1273 } else if (isInline) { 1274 return emitError(name.getLoc(), 1275 "external declarations must be declared in global scope"); 1276 } else if (curToken.is(Token::error)) { 1277 return failure(); 1278 } 1279 if (failed(parseToken(Token::semicolon, 1280 "expected `;` after native declaration"))) 1281 return failure(); 1282 // TODO: PDL should be able to support constraint results in certain 1283 // situations, we should revise this. 1284 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1285 return emitError( 1286 "native Constraints currently do not support returning results"); 1287 } 1288 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1289 } 1290 1291 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1292 SmallVectorImpl<ast::VariableDecl *> &arguments, 1293 SmallVectorImpl<ast::VariableDecl *> &results, 1294 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1295 // Parse the argument list of the decl. 1296 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1297 return failure(); 1298 1299 argumentScope = pushDeclScope(); 1300 if (curToken.isNot(Token::r_paren)) { 1301 do { 1302 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1303 if (failed(argument)) 1304 return failure(); 1305 arguments.emplace_back(*argument); 1306 } while (consumeIf(Token::comma)); 1307 } 1308 popDeclScope(); 1309 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1310 return failure(); 1311 1312 // Parse the results of the decl. 
1313 pushDeclScope(); 1314 if (consumeIf(Token::arrow)) { 1315 auto parseResultFn = [&]() -> LogicalResult { 1316 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1317 if (failed(result)) 1318 return failure(); 1319 results.emplace_back(*result); 1320 return success(); 1321 }; 1322 1323 // Check for a list of results. 1324 if (consumeIf(Token::l_paren)) { 1325 do { 1326 if (failed(parseResultFn())) 1327 return failure(); 1328 } while (consumeIf(Token::comma)); 1329 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1330 return failure(); 1331 1332 // Otherwise, there is only one result. 1333 } else if (failed(parseResultFn())) { 1334 return failure(); 1335 } 1336 } 1337 popDeclScope(); 1338 1339 // Compute the result type of the decl. 1340 resultType = createUserConstraintRewriteResultType(results); 1341 1342 // Verify that results are only named if there are more than one. 1343 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1344 return emitError( 1345 results.front()->getLoc(), 1346 "cannot create a single-element tuple with an element label"); 1347 } 1348 return success(); 1349 } 1350 1351 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1352 StringRef declType, ast::CompoundStmt *body, 1353 ArrayRef<ast::Stmt *>::iterator bodyIt, 1354 ArrayRef<ast::Stmt *>::iterator bodyE, 1355 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1356 // Handle if a `return` was provided. 1357 if (bodyIt != bodyE) { 1358 // Emit an error if we have trailing statements after the return. 1359 if (std::next(bodyIt) != bodyE) { 1360 return emitError( 1361 (*std::next(bodyIt))->getLoc(), 1362 llvm::formatv("`return` terminated the `{0}` body, but found " 1363 "trailing statements afterwards", 1364 declType)); 1365 } 1366 1367 // Otherwise if a return wasn't provided, check that no results are 1368 // expected. 
1369 } else if (!results.empty()) { 1370 return emitError( 1371 {body->getLoc().End, body->getLoc().End}, 1372 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1373 declType, resultType)); 1374 } 1375 return success(); 1376 } 1377 1378 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1379 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1380 if (isa<ast::OpRewriteStmt>(statement)) 1381 return success(); 1382 return emitError( 1383 statement->getLoc(), 1384 "expected Pattern lambda body to contain a single operation " 1385 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1386 }); 1387 } 1388 1389 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1390 SMRange loc = curToken.getLoc(); 1391 consumeToken(Token::kw_Pattern); 1392 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1393 ParserContext::PatternMatch); 1394 1395 // Check for an optional identifier for the pattern name. 1396 const ast::Name *name = nullptr; 1397 if (curToken.is(Token::identifier)) { 1398 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1399 consumeToken(Token::identifier); 1400 } 1401 1402 // Parse any pattern metadata. 1403 ParsedPatternMetadata metadata; 1404 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1405 return failure(); 1406 1407 // Parse the pattern body. 1408 ast::CompoundStmt *body; 1409 1410 // Handle a lambda body. 1411 if (curToken.is(Token::equal_arrow)) { 1412 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1413 if (failed(bodyResult)) 1414 return failure(); 1415 body = *bodyResult; 1416 } else { 1417 if (curToken.isNot(Token::l_brace)) 1418 return emitError("expected `{` or `=>` to start pattern body"); 1419 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1420 if (failed(bodyResult)) 1421 return failure(); 1422 body = *bodyResult; 1423 1424 // Verify the body of the pattern. 
1425 auto bodyIt = body->begin(), bodyE = body->end(); 1426 for (; bodyIt != bodyE; ++bodyIt) { 1427 if (isa<ast::ReturnStmt>(*bodyIt)) { 1428 return emitError((*bodyIt)->getLoc(), 1429 "`return` statements are only permitted within a " 1430 "`Constraint` or `Rewrite` body"); 1431 } 1432 // Break when we've found the rewrite statement. 1433 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1434 break; 1435 } 1436 if (bodyIt == bodyE) { 1437 return emitError(loc, 1438 "expected Pattern body to terminate with an operation " 1439 "rewrite statement, such as `erase`"); 1440 } 1441 if (std::next(bodyIt) != bodyE) { 1442 return emitError((*std::next(bodyIt))->getLoc(), 1443 "Pattern body was terminated by an operation " 1444 "rewrite statement, but found trailing statements"); 1445 } 1446 } 1447 1448 return createPatternDecl(loc, name, metadata, body); 1449 } 1450 1451 LogicalResult 1452 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1453 Optional<SMRange> benefitLoc; 1454 Optional<SMRange> hasBoundedRecursionLoc; 1455 1456 do { 1457 // Handle metadata code completion. 
1458 if (curToken.is(Token::code_complete)) 1459 return codeCompletePatternMetadata(); 1460 1461 if (curToken.isNot(Token::identifier)) 1462 return emitError("expected pattern metadata identifier"); 1463 StringRef metadataStr = curToken.getSpelling(); 1464 SMRange metadataLoc = curToken.getLoc(); 1465 consumeToken(Token::identifier); 1466 1467 // Parse the benefit metadata: benefit(<integer-value>) 1468 if (metadataStr == "benefit") { 1469 if (benefitLoc) { 1470 return emitErrorAndNote(metadataLoc, 1471 "pattern benefit has already been specified", 1472 *benefitLoc, "see previous definition here"); 1473 } 1474 if (failed(parseToken(Token::l_paren, 1475 "expected `(` before pattern benefit"))) 1476 return failure(); 1477 1478 uint16_t benefitValue = 0; 1479 if (curToken.isNot(Token::integer)) 1480 return emitError("expected integral pattern benefit"); 1481 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1482 return emitError( 1483 "expected pattern benefit to fit within a 16-bit integer"); 1484 consumeToken(Token::integer); 1485 1486 metadata.benefit = benefitValue; 1487 benefitLoc = metadataLoc; 1488 1489 if (failed( 1490 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1491 return failure(); 1492 continue; 1493 } 1494 1495 // Parse the bounded recursion metadata: recursion 1496 if (metadataStr == "recursion") { 1497 if (hasBoundedRecursionLoc) { 1498 return emitErrorAndNote( 1499 metadataLoc, 1500 "pattern recursion metadata has already been specified", 1501 *hasBoundedRecursionLoc, "see previous definition here"); 1502 } 1503 metadata.hasBoundedRecursion = true; 1504 hasBoundedRecursionLoc = metadataLoc; 1505 continue; 1506 } 1507 1508 return emitError(metadataLoc, "unknown pattern metadata"); 1509 } while (consumeIf(Token::comma)); 1510 1511 return success(); 1512 } 1513 1514 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1515 consumeToken(Token::less); 1516 1517 FailureOr<ast::Expr *> typeExpr = parseExpr(); 
1518 if (failed(typeExpr) || 1519 failed(parseToken(Token::greater, 1520 "expected `>` after variable type constraint"))) 1521 return failure(); 1522 return typeExpr; 1523 } 1524 1525 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1526 assert(curDeclScope && "defining decl outside of a decl scope"); 1527 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1528 return emitErrorAndNote( 1529 name.getLoc(), "`" + name.getName() + "` has already been defined", 1530 lastDecl->getName()->getLoc(), "see previous definition here"); 1531 } 1532 return success(); 1533 } 1534 1535 FailureOr<ast::VariableDecl *> 1536 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1537 ast::Expr *initExpr, 1538 ArrayRef<ast::ConstraintRef> constraints) { 1539 assert(curDeclScope && "defining variable outside of decl scope"); 1540 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1541 1542 // If the name of the variable indicates a special variable, we don't add it 1543 // to the scope. This variable is local to the definition point. 
1544 if (name.empty() || name == "_") { 1545 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1546 constraints); 1547 } 1548 if (failed(checkDefineNamedDecl(nameDecl))) 1549 return failure(); 1550 1551 auto *varDecl = 1552 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1553 curDeclScope->add(varDecl); 1554 return varDecl; 1555 } 1556 1557 FailureOr<ast::VariableDecl *> 1558 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1559 ArrayRef<ast::ConstraintRef> constraints) { 1560 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1561 constraints); 1562 } 1563 1564 LogicalResult Parser::parseVariableDeclConstraintList( 1565 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1566 Optional<SMRange> typeConstraint; 1567 auto parseSingleConstraint = [&] { 1568 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1569 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1570 /*allowNonCoreConstraints=*/true); 1571 if (failed(constraint)) 1572 return failure(); 1573 constraints.push_back(*constraint); 1574 return success(); 1575 }; 1576 1577 // Check to see if this is a single constraint, or a list. 
// (continued) Tail of Parser::parseVariableDeclConstraintList.
  // Without a leading `[`, exactly one constraint is parsed; otherwise a
  // comma-separated list terminated by `]`.
  if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference at the current token. Accepted forms:
///  * a core constraint keyword: `Attr[<type>]`, `Op[<name>]`, `Type`,
///    `TypeRange`, `Value[<type>]`, `ValueRange[<type>]`;
///  * an inline `Constraint` declaration;
///  * the name of a constraint decl found by a lookup in the current scope.
///
/// On a code-completion token, the type implied by `existingConstraints` is
/// inferred and constraint-name completion is triggered.
///
/// Parameters:
///  * `typeConstraint` records the location of an inline `<type>` constraint
///    already parsed for the same variable; a second one is an error.
///  * `allowInlineTypeConstraints` controls whether the `<type>` suffix on
///    `Attr`/`Value`/`ValueRange` is accepted.
///  * `allowNonCoreConstraints` is forwarded to `validateVariableConstraints`
///    and `codeCompleteConstraintName`.
///
/// NOTE(review): the identifier case below does not consult
/// `allowNonCoreConstraints`; presumably non-core references are rejected
/// elsewhere (e.g. when the argument/result variable is created) -- confirm.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse an inline `<type>` constraint into `typeExpr`, enforcing that such
  // constraints are allowed here and that the variable has no prior one.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Remember where the type was constrained for later duplicate diagnostics.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  // All constraint refs are located at the token that starts them.
  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint. It is registered in the current scope by
    // parseInlineUserConstraintDecl.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a decl that is not a constraint.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

/// Parse the constraint of a `Constraint`/`Rewrite` argument or result.
/// Inline `<type>` constraints are never allowed here. Non-core constraints are
/// allowed only while parsing a `Constraint` declaration.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  // Constraint arguments may apply more complex constraints via the arguments.
  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;

  Optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
                         /*allowInlineTypeConstraints=*/false,
                         allowNonCoreConstraints);
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression. `_` is handled by parseUnderscoreExpr. Otherwise a
/// primary expression is parsed, selected by the current token, and then any
/// number of postfix member accesses (`.`) and calls (`(`) are applied
/// left-to-right.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
  FailureOr<ast::Expr *> lhsExpr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    lhsExpr = parseAttributeExpr();
    break;
  case Token::kw_Constraint:
    lhsExpr = parseInlineConstraintLambdaExpr();
    break;
  case Token::identifier:
    lhsExpr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    lhsExpr = parseOperationExpr();
    break;
  case Token::kw_Rewrite:
    lhsExpr = parseInlineRewriteLambdaExpr();
    break;
  case Token::kw_type:
    lhsExpr = parseTypeExpr();
    break;
  case Token::l_paren:
    lhsExpr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }
  if (failed(lhsExpr))
    return failure();

  // Check for an operator expression. Each postfix operator wraps the
  // expression built so far; the loop ends at the first other token.
  while (true) {
    switch (curToken.getKind()) {
    case Token::dot:
      lhsExpr = parseMemberAccessExpr(*lhsExpr);
      break;
    case Token::l_paren:
      lhsExpr = parseCallExpr(*lhsExpr);
      break;
    default:
      return lhsExpr;
    }
    if (failed(lhsExpr))
      return failure();
  }
}

FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
  // identifier.
1781 if (!consumeIf(Token::less)) { 1782 resetToken(loc); 1783 return parseIdentifierExpr(); 1784 } 1785 1786 if (!curToken.isString()) 1787 return emitError("expected string literal containing MLIR attribute"); 1788 std::string attrExpr = curToken.getStringValue(); 1789 consumeToken(); 1790 1791 loc.End = curToken.getEndLoc(); 1792 if (failed( 1793 parseToken(Token::greater, "expected `>` after attribute literal"))) 1794 return failure(); 1795 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1796 } 1797 1798 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1799 consumeToken(Token::l_paren); 1800 1801 // Parse the arguments of the call. 1802 SmallVector<ast::Expr *> arguments; 1803 if (curToken.isNot(Token::r_paren)) { 1804 do { 1805 // Handle code completion for the call arguments. 1806 if (curToken.is(Token::code_complete)) { 1807 codeCompleteCallSignature(parentExpr, arguments.size()); 1808 return failure(); 1809 } 1810 1811 FailureOr<ast::Expr *> argument = parseExpr(); 1812 if (failed(argument)) 1813 return failure(); 1814 arguments.push_back(*argument); 1815 } while (consumeIf(Token::comma)); 1816 } 1817 1818 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1819 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1820 return failure(); 1821 1822 return createCallExpr(loc, parentExpr, arguments); 1823 } 1824 1825 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1826 ast::Decl *decl = curDeclScope->lookup(name); 1827 if (!decl) 1828 return emitError(loc, "undefined reference to `" + name + "`"); 1829 1830 return createDeclRefExpr(loc, decl); 1831 } 1832 1833 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1834 StringRef name = curToken.getSpelling(); 1835 SMRange nameLoc = curToken.getLoc(); 1836 consumeToken(); 1837 1838 // Check to see if this is a decl ref expression that defines a variable 1839 // inline. 
1840 if (consumeIf(Token::colon)) { 1841 SmallVector<ast::ConstraintRef> constraints; 1842 if (failed(parseVariableDeclConstraintList(constraints))) 1843 return failure(); 1844 ast::Type type; 1845 if (failed(validateVariableConstraints(constraints, type))) 1846 return failure(); 1847 return createInlineVariableExpr(type, name, nameLoc, constraints); 1848 } 1849 1850 return parseDeclRefExpr(name, nameLoc); 1851 } 1852 1853 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1854 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1855 if (failed(decl)) 1856 return failure(); 1857 1858 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1859 ast::ConstraintType::get(ctx)); 1860 } 1861 1862 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1863 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1864 if (failed(decl)) 1865 return failure(); 1866 1867 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1868 ast::RewriteType::get(ctx)); 1869 } 1870 1871 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1872 SMRange dotLoc = curToken.getLoc(); 1873 consumeToken(Token::dot); 1874 1875 // Check for code completion of the member name. 1876 if (curToken.is(Token::code_complete)) 1877 return codeCompleteMemberAccess(parentExpr); 1878 1879 // Parse the member name. 
1880 Token memberNameTok = curToken; 1881 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1882 !memberNameTok.isKeyword()) 1883 return emitError(dotLoc, "expected identifier or numeric member name"); 1884 StringRef memberName = memberNameTok.getSpelling(); 1885 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1886 consumeToken(); 1887 1888 return createMemberAccessExpr(parentExpr, memberName, loc); 1889 } 1890 1891 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1892 SMRange loc = curToken.getLoc(); 1893 1894 // Check for code completion for the dialect name. 1895 if (curToken.is(Token::code_complete)) 1896 return codeCompleteDialectName(); 1897 1898 // Handle the case of an no operation name. 1899 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1900 if (allowEmptyName) 1901 return ast::OpNameDecl::create(ctx, SMRange()); 1902 return emitError("expected dialect namespace"); 1903 } 1904 StringRef name = curToken.getSpelling(); 1905 consumeToken(); 1906 1907 // Otherwise, this is a literal operation name. 1908 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1909 return failure(); 1910 1911 // Check for code completion for the operation name. 
1912 if (curToken.is(Token::code_complete)) 1913 return codeCompleteOperationName(name); 1914 1915 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1916 return emitError("expected operation name after dialect namespace"); 1917 1918 name = StringRef(name.data(), name.size() + 1); 1919 do { 1920 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1921 loc.End = curToken.getEndLoc(); 1922 consumeToken(); 1923 } while (curToken.isAny(Token::identifier, Token::dot) || 1924 curToken.isKeyword()); 1925 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1926 } 1927 1928 FailureOr<ast::OpNameDecl *> 1929 Parser::parseWrappedOperationName(bool allowEmptyName) { 1930 if (!consumeIf(Token::less)) 1931 return ast::OpNameDecl::create(ctx, SMRange()); 1932 1933 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1934 if (failed(opNameDecl)) 1935 return failure(); 1936 1937 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1938 return failure(); 1939 return opNameDecl; 1940 } 1941 1942 FailureOr<ast::Expr *> 1943 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { 1944 SMRange loc = curToken.getLoc(); 1945 consumeToken(Token::kw_op); 1946 1947 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1948 // identifier. 1949 if (curToken.isNot(Token::less)) { 1950 resetToken(loc); 1951 return parseIdentifierExpr(); 1952 } 1953 1954 // Parse the operation name. The name may be elided, in which case the 1955 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1956 // `Operation*`). Operation names within a rewrite context must be named. 
1957 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1958 FailureOr<ast::OpNameDecl *> opNameDecl = 1959 parseWrappedOperationName(allowEmptyName); 1960 if (failed(opNameDecl)) 1961 return failure(); 1962 Optional<StringRef> opName = (*opNameDecl)->getName(); 1963 1964 // Functor used to create an implicit range variable, used for implicit "all" 1965 // operand or results variables. 1966 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1967 FailureOr<ast::VariableDecl *> rangeVar = 1968 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1969 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1970 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1971 }; 1972 1973 // Check for the optional list of operands. 1974 SmallVector<ast::Expr *> operands; 1975 if (!consumeIf(Token::l_paren)) { 1976 // If the operand list isn't specified and we are in a match context, define 1977 // an inplace unconstrained operand range corresponding to all of the 1978 // operands of the operation. This avoids treating zero operands the same 1979 // way as "unconstrained operands". 1980 if (parserContext != ParserContext::Rewrite) { 1981 operands.push_back(createImplicitRangeVar( 1982 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1983 } 1984 } else if (!consumeIf(Token::r_paren)) { 1985 // Check for operand signature code completion. 1986 if (curToken.is(Token::code_complete)) { 1987 codeCompleteOperationOperandsSignature(opName, operands.size()); 1988 return failure(); 1989 } 1990 1991 // If the operand list was specified and non-empty, parse the operands. 
1992 do { 1993 FailureOr<ast::Expr *> operand = parseExpr(); 1994 if (failed(operand)) 1995 return failure(); 1996 operands.push_back(*operand); 1997 } while (consumeIf(Token::comma)); 1998 1999 if (failed(parseToken(Token::r_paren, 2000 "expected `)` after operation operand list"))) 2001 return failure(); 2002 } 2003 2004 // Check for the optional list of attributes. 2005 SmallVector<ast::NamedAttributeDecl *> attributes; 2006 if (consumeIf(Token::l_brace)) { 2007 do { 2008 FailureOr<ast::NamedAttributeDecl *> decl = 2009 parseNamedAttributeDecl(opName); 2010 if (failed(decl)) 2011 return failure(); 2012 attributes.emplace_back(*decl); 2013 } while (consumeIf(Token::comma)); 2014 2015 if (failed(parseToken(Token::r_brace, 2016 "expected `}` after operation attribute list"))) 2017 return failure(); 2018 } 2019 2020 // Handle the result types of the operation. 2021 SmallVector<ast::Expr *> resultTypes; 2022 OpResultTypeContext resultTypeContext = inputResultTypeContext; 2023 2024 // Check for an explicit list of result types. 2025 if (consumeIf(Token::arrow)) { 2026 if (failed(parseToken(Token::l_paren, 2027 "expected `(` before operation result type list"))) 2028 return failure(); 2029 2030 // If result types are provided, initially assume that the operation does 2031 // not rely on type inferrence. We don't assert that it isn't, because we 2032 // may be inferring the value of some type/type range variables, but given 2033 // that these variables may be defined in calls we can't always discern when 2034 // this is the case. 2035 resultTypeContext = OpResultTypeContext::Explicit; 2036 2037 // Handle the case of an empty result list. 2038 if (!consumeIf(Token::r_paren)) { 2039 do { 2040 // Check for result signature code completion. 
2041 if (curToken.is(Token::code_complete)) { 2042 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2043 return failure(); 2044 } 2045 2046 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2047 if (failed(resultTypeExpr)) 2048 return failure(); 2049 resultTypes.push_back(*resultTypeExpr); 2050 } while (consumeIf(Token::comma)); 2051 2052 if (failed(parseToken(Token::r_paren, 2053 "expected `)` after operation result type list"))) 2054 return failure(); 2055 } 2056 } else if (parserContext != ParserContext::Rewrite) { 2057 // If the result list isn't specified and we are in a match context, define 2058 // an inplace unconstrained result range corresponding to all of the results 2059 // of the operation. This avoids treating zero results the same way as 2060 // "unconstrained results". 2061 resultTypes.push_back(createImplicitRangeVar( 2062 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2063 } else if (resultTypeContext == OpResultTypeContext::Explicit) { 2064 // If the result list isn't specified and we are in a rewrite, try to infer 2065 // them at runtime instead. 2066 resultTypeContext = OpResultTypeContext::Interface; 2067 } 2068 2069 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, 2070 attributes, resultTypes); 2071 } 2072 2073 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2074 SMRange loc = curToken.getLoc(); 2075 consumeToken(Token::l_paren); 2076 2077 DenseMap<StringRef, SMRange> usedNames; 2078 SmallVector<StringRef> elementNames; 2079 SmallVector<ast::Expr *> elements; 2080 if (curToken.isNot(Token::r_paren)) { 2081 do { 2082 // Check for the optional element name assignment before the value. 2083 StringRef elementName; 2084 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2085 Token elementNameTok = curToken; 2086 consumeToken(); 2087 2088 // The element name is only present if followed by an `=`. 
2089 if (consumeIf(Token::equal)) { 2090 elementName = elementNameTok.getSpelling(); 2091 2092 // Check to see if this name is already used. 2093 auto elementNameIt = 2094 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2095 if (!elementNameIt.second) { 2096 return emitErrorAndNote( 2097 elementNameTok.getLoc(), 2098 llvm::formatv("duplicate tuple element label `{0}`", 2099 elementName), 2100 elementNameIt.first->getSecond(), 2101 "see previous label use here"); 2102 } 2103 } else { 2104 // Otherwise, we treat this as part of an expression so reset the 2105 // lexer. 2106 resetToken(elementNameTok.getLoc()); 2107 } 2108 } 2109 elementNames.push_back(elementName); 2110 2111 // Parse the tuple element value. 2112 FailureOr<ast::Expr *> element = parseExpr(); 2113 if (failed(element)) 2114 return failure(); 2115 elements.push_back(*element); 2116 } while (consumeIf(Token::comma)); 2117 } 2118 loc.End = curToken.getEndLoc(); 2119 if (failed( 2120 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2121 return failure(); 2122 return createTupleExpr(loc, elements, elementNames); 2123 } 2124 2125 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2126 SMRange loc = curToken.getLoc(); 2127 consumeToken(Token::kw_type); 2128 2129 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2130 // identifier. 
2131 if (!consumeIf(Token::less)) { 2132 resetToken(loc); 2133 return parseIdentifierExpr(); 2134 } 2135 2136 if (!curToken.isString()) 2137 return emitError("expected string literal containing MLIR type"); 2138 std::string attrExpr = curToken.getStringValue(); 2139 consumeToken(); 2140 2141 loc.End = curToken.getEndLoc(); 2142 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2143 return failure(); 2144 return ast::TypeExpr::create(ctx, loc, attrExpr); 2145 } 2146 2147 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2148 StringRef name = curToken.getSpelling(); 2149 SMRange nameLoc = curToken.getLoc(); 2150 consumeToken(Token::underscore); 2151 2152 // Underscore expressions require a constraint list. 2153 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2154 return failure(); 2155 2156 // Parse the constraints for the expression. 2157 SmallVector<ast::ConstraintRef> constraints; 2158 if (failed(parseVariableDeclConstraintList(constraints))) 2159 return failure(); 2160 2161 ast::Type type; 2162 if (failed(validateVariableConstraints(constraints, type))) 2163 return failure(); 2164 return createInlineVariableExpr(type, name, nameLoc, constraints); 2165 } 2166 2167 //===----------------------------------------------------------------------===// 2168 // Stmts 2169 2170 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2171 FailureOr<ast::Stmt *> stmt; 2172 switch (curToken.getKind()) { 2173 case Token::kw_erase: 2174 stmt = parseEraseStmt(); 2175 break; 2176 case Token::kw_let: 2177 stmt = parseLetStmt(); 2178 break; 2179 case Token::kw_replace: 2180 stmt = parseReplaceStmt(); 2181 break; 2182 case Token::kw_return: 2183 stmt = parseReturnStmt(); 2184 break; 2185 case Token::kw_rewrite: 2186 stmt = parseRewriteStmt(); 2187 break; 2188 default: 2189 stmt = parseExpr(); 2190 break; 2191 } 2192 if (failed(stmt) || 2193 (expectTerminalSemicolon && 2194 failed(parseToken(Token::semicolon, 
"expected `;` after statement")))) 2195 return failure(); 2196 return stmt; 2197 } 2198 2199 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2200 SMLoc startLoc = curToken.getStartLoc(); 2201 consumeToken(Token::l_brace); 2202 2203 // Push a new block scope and parse any nested statements. 2204 pushDeclScope(); 2205 SmallVector<ast::Stmt *> statements; 2206 while (curToken.isNot(Token::r_brace)) { 2207 FailureOr<ast::Stmt *> statement = parseStmt(); 2208 if (failed(statement)) 2209 return popDeclScope(), failure(); 2210 statements.push_back(*statement); 2211 } 2212 popDeclScope(); 2213 2214 // Consume the end brace. 2215 SMRange location(startLoc, curToken.getEndLoc()); 2216 consumeToken(Token::r_brace); 2217 2218 return ast::CompoundStmt::create(ctx, location, statements); 2219 } 2220 2221 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2222 if (parserContext == ParserContext::Constraint) 2223 return emitError("`erase` cannot be used within a Constraint"); 2224 SMRange loc = curToken.getLoc(); 2225 consumeToken(Token::kw_erase); 2226 2227 // Parse the root operation expression. 2228 FailureOr<ast::Expr *> rootOp = parseExpr(); 2229 if (failed(rootOp)) 2230 return failure(); 2231 2232 return createEraseStmt(loc, *rootOp); 2233 } 2234 2235 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2236 SMRange loc = curToken.getLoc(); 2237 consumeToken(Token::kw_let); 2238 2239 // Parse the name of the new variable. 2240 SMRange varLoc = curToken.getLoc(); 2241 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2242 // `_` is a reserved variable name. 2243 if (curToken.is(Token::underscore)) { 2244 return emitError(varLoc, 2245 "`_` may only be used to define \"inline\" variables"); 2246 } 2247 return emitError(varLoc, 2248 "expected identifier after `let` to name a new variable"); 2249 } 2250 StringRef varName = curToken.getSpelling(); 2251 consumeToken(); 2252 2253 // Parse the optional set of constraints. 
2254 SmallVector<ast::ConstraintRef> constraints; 2255 if (consumeIf(Token::colon) && 2256 failed(parseVariableDeclConstraintList(constraints))) 2257 return failure(); 2258 2259 // Parse the optional initializer expression. 2260 ast::Expr *initializer = nullptr; 2261 if (consumeIf(Token::equal)) { 2262 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2263 if (failed(initOrFailure)) 2264 return failure(); 2265 initializer = *initOrFailure; 2266 2267 // Check that the constraints are compatible with having an initializer, 2268 // e.g. type constraints cannot be used with initializers. 2269 for (ast::ConstraintRef constraint : constraints) { 2270 LogicalResult result = 2271 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2272 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2273 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2274 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2275 return this->emitError( 2276 constraint.referenceLoc, 2277 "type constraints are not permitted on variables with " 2278 "initializers"); 2279 } 2280 return success(); 2281 }) 2282 .Default(success()); 2283 if (failed(result)) 2284 return failure(); 2285 } 2286 } 2287 2288 FailureOr<ast::VariableDecl *> varDecl = 2289 createVariableDecl(varName, varLoc, initializer, constraints); 2290 if (failed(varDecl)) 2291 return failure(); 2292 return ast::LetStmt::create(ctx, loc, *varDecl); 2293 } 2294 2295 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2296 if (parserContext == ParserContext::Constraint) 2297 return emitError("`replace` cannot be used within a Constraint"); 2298 SMRange loc = curToken.getLoc(); 2299 consumeToken(Token::kw_replace); 2300 2301 // Parse the root operation expression. 
2302 FailureOr<ast::Expr *> rootOp = parseExpr(); 2303 if (failed(rootOp)) 2304 return failure(); 2305 2306 if (failed( 2307 parseToken(Token::kw_with, "expected `with` after root operation"))) 2308 return failure(); 2309 2310 // The replacement portion of this statement is within a rewrite context. 2311 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2312 ParserContext::Rewrite); 2313 2314 // Parse the replacement values. 2315 SmallVector<ast::Expr *> replValues; 2316 if (consumeIf(Token::l_paren)) { 2317 if (consumeIf(Token::r_paren)) { 2318 return emitError( 2319 loc, "expected at least one replacement value, consider using " 2320 "`erase` if no replacement values are desired"); 2321 } 2322 2323 do { 2324 FailureOr<ast::Expr *> replExpr = parseExpr(); 2325 if (failed(replExpr)) 2326 return failure(); 2327 replValues.emplace_back(*replExpr); 2328 } while (consumeIf(Token::comma)); 2329 2330 if (failed(parseToken(Token::r_paren, 2331 "expected `)` after replacement values"))) 2332 return failure(); 2333 } else { 2334 // Handle replacement with an operation uniquely, as the replacement 2335 // operation supports type inferrence from the root operation. 2336 FailureOr<ast::Expr *> replExpr; 2337 if (curToken.is(Token::kw_op)) 2338 replExpr = parseOperationExpr(OpResultTypeContext::Replacement); 2339 else 2340 replExpr = parseExpr(); 2341 if (failed(replExpr)) 2342 return failure(); 2343 replValues.emplace_back(*replExpr); 2344 } 2345 2346 return createReplaceStmt(loc, *rootOp, replValues); 2347 } 2348 2349 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2350 SMRange loc = curToken.getLoc(); 2351 consumeToken(Token::kw_return); 2352 2353 // Parse the result value. 
2354 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2355 if (failed(resultExpr)) 2356 return failure(); 2357 2358 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2359 } 2360 2361 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2362 if (parserContext == ParserContext::Constraint) 2363 return emitError("`rewrite` cannot be used within a Constraint"); 2364 SMRange loc = curToken.getLoc(); 2365 consumeToken(Token::kw_rewrite); 2366 2367 // Parse the root operation. 2368 FailureOr<ast::Expr *> rootOp = parseExpr(); 2369 if (failed(rootOp)) 2370 return failure(); 2371 2372 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2373 return failure(); 2374 2375 if (curToken.isNot(Token::l_brace)) 2376 return emitError("expected `{` to start rewrite body"); 2377 2378 // The rewrite body of this statement is within a rewrite context. 2379 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2380 ParserContext::Rewrite); 2381 2382 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2383 if (failed(rewriteBody)) 2384 return failure(); 2385 2386 // Verify the rewrite body. 2387 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2388 if (isa<ast::ReturnStmt>(stmt)) { 2389 return emitError(stmt->getLoc(), 2390 "`return` statements are only permitted within a " 2391 "`Constraint` or `Rewrite` body"); 2392 } 2393 } 2394 2395 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2396 } 2397 2398 //===----------------------------------------------------------------------===// 2399 // Creation+Analysis 2400 //===----------------------------------------------------------------------===// 2401 2402 //===----------------------------------------------------------------------===// 2403 // Decls 2404 2405 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2406 // Unwrap reference expressions. 
2407 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2408 node = init->getDecl(); 2409 return dyn_cast<ast::CallableDecl>(node); 2410 } 2411 2412 FailureOr<ast::PatternDecl *> 2413 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2414 const ParsedPatternMetadata &metadata, 2415 ast::CompoundStmt *body) { 2416 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2417 metadata.hasBoundedRecursion, body); 2418 } 2419 2420 ast::Type Parser::createUserConstraintRewriteResultType( 2421 ArrayRef<ast::VariableDecl *> results) { 2422 // Single result decls use the type of the single result. 2423 if (results.size() == 1) 2424 return results[0]->getType(); 2425 2426 // Multiple results use a tuple type, with the types and names grabbed from 2427 // the result variable decls. 2428 auto resultTypes = llvm::map_range( 2429 results, [&](const auto *result) { return result->getType(); }); 2430 auto resultNames = llvm::map_range( 2431 results, [&](const auto *result) { return result->getName().getName(); }); 2432 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2433 llvm::to_vector(resultNames)); 2434 } 2435 2436 template <typename T> 2437 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2438 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2439 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2440 ast::CompoundStmt *body) { 2441 if (!body->getChildren().empty()) { 2442 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2443 ast::Expr *resultExpr = retStmt->getResultExpr(); 2444 2445 // Process the result of the decl. If no explicit signature results 2446 // were provided, check for return type inference. Otherwise, check that 2447 // the return expression can be converted to the expected type. 
2448 if (results.empty()) 2449 resultType = resultExpr->getType(); 2450 else if (failed(convertExpressionTo(resultExpr, resultType))) 2451 return failure(); 2452 else 2453 retStmt->setResultExpr(resultExpr); 2454 } 2455 } 2456 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2457 } 2458 2459 FailureOr<ast::VariableDecl *> 2460 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2461 ArrayRef<ast::ConstraintRef> constraints) { 2462 // The type of the variable, which is expected to be inferred by either a 2463 // constraint or an initializer expression. 2464 ast::Type type; 2465 if (failed(validateVariableConstraints(constraints, type))) 2466 return failure(); 2467 2468 if (initializer) { 2469 // Update the variable type based on the initializer, or try to convert the 2470 // initializer to the existing type. 2471 if (!type) 2472 type = initializer->getType(); 2473 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2474 type = mergedType; 2475 else if (failed(convertExpressionTo(initializer, type))) 2476 return failure(); 2477 2478 // Otherwise, if there is no initializer check that the type has already 2479 // been resolved from the constraint list. 2480 } else if (!type) { 2481 return emitErrorAndNote( 2482 loc, "unable to infer type for variable `" + name + "`", loc, 2483 "the type of a variable must be inferable from the constraint " 2484 "list or the initializer"); 2485 } 2486 2487 // Constraint types cannot be used when defining variables. 2488 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2489 return emitError( 2490 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2491 } 2492 2493 // Try to define a variable with the given name. 
2494 FailureOr<ast::VariableDecl *> varDecl = 2495 defineVariableDecl(name, loc, type, initializer, constraints); 2496 if (failed(varDecl)) 2497 return failure(); 2498 2499 return *varDecl; 2500 } 2501 2502 FailureOr<ast::VariableDecl *> 2503 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2504 const ast::ConstraintRef &constraint) { 2505 // Constraint arguments may apply more complex constraints via the arguments. 2506 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2507 ast::Type argType; 2508 if (failed(validateVariableConstraint(constraint, argType, 2509 allowNonCoreConstraints))) 2510 return failure(); 2511 return defineVariableDecl(name, loc, argType, constraint); 2512 } 2513 2514 LogicalResult 2515 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2516 ast::Type &inferredType, 2517 bool allowNonCoreConstraints) { 2518 for (const ast::ConstraintRef &ref : constraints) 2519 if (failed(validateVariableConstraint(ref, inferredType, 2520 allowNonCoreConstraints))) 2521 return failure(); 2522 return success(); 2523 } 2524 2525 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2526 ast::Type &inferredType, 2527 bool allowNonCoreConstraints) { 2528 ast::Type constraintType; 2529 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2530 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2531 if (failed(validateTypeConstraintExpr(typeExpr))) 2532 return failure(); 2533 } 2534 constraintType = ast::AttributeType::get(ctx); 2535 } else if (const auto *cst = 2536 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2537 constraintType = ast::OperationType::get(ctx, cst->getName()); 2538 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2539 constraintType = typeTy; 2540 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2541 constraintType = typeRangeTy; 2542 } else if (const auto *cst = 2543 
dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2544 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2545 if (failed(validateTypeConstraintExpr(typeExpr))) 2546 return failure(); 2547 } 2548 constraintType = valueTy; 2549 } else if (const auto *cst = 2550 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2551 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2552 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2553 return failure(); 2554 } 2555 constraintType = valueRangeTy; 2556 } else if (const auto *cst = 2557 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2558 if (!allowNonCoreConstraints) { 2559 return emitError(ref.referenceLoc, 2560 "`Rewrite` arguments and results are only permitted to " 2561 "use core constraints, such as `Attr`, `Op`, `Type`, " 2562 "`TypeRange`, `Value`, `ValueRange`"); 2563 } 2564 2565 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2566 if (inputs.size() != 1) { 2567 return emitErrorAndNote(ref.referenceLoc, 2568 "`Constraint`s applied via a variable constraint " 2569 "list must take a single input, but got " + 2570 Twine(inputs.size()), 2571 cst->getLoc(), 2572 "see definition of constraint here"); 2573 } 2574 constraintType = inputs.front()->getType(); 2575 } else { 2576 llvm_unreachable("unknown constraint type"); 2577 } 2578 2579 // Check that the constraint type is compatible with the current inferred 2580 // type. 
2581 if (!inferredType) { 2582 inferredType = constraintType; 2583 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2584 inferredType = mergedTy; 2585 } else { 2586 return emitError(ref.referenceLoc, 2587 llvm::formatv("constraint type `{0}` is incompatible " 2588 "with the previously inferred type `{1}`", 2589 constraintType, inferredType)); 2590 } 2591 return success(); 2592 } 2593 2594 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2595 ast::Type typeExprType = typeExpr->getType(); 2596 if (typeExprType != typeTy) { 2597 return emitError(typeExpr->getLoc(), 2598 "expected expression of `Type` in type constraint"); 2599 } 2600 return success(); 2601 } 2602 2603 LogicalResult 2604 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2605 ast::Type typeExprType = typeExpr->getType(); 2606 if (typeExprType != typeRangeTy) { 2607 return emitError(typeExpr->getLoc(), 2608 "expected expression of `TypeRange` in type constraint"); 2609 } 2610 return success(); 2611 } 2612 2613 //===----------------------------------------------------------------------===// 2614 // Exprs 2615 2616 FailureOr<ast::CallExpr *> 2617 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2618 MutableArrayRef<ast::Expr *> arguments) { 2619 ast::Type parentType = parentExpr->getType(); 2620 2621 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2622 if (!callableDecl) { 2623 return emitError(loc, 2624 llvm::formatv("expected a reference to a callable " 2625 "`Constraint` or `Rewrite`, but got: `{0}`", 2626 parentType)); 2627 } 2628 if (parserContext == ParserContext::Rewrite) { 2629 if (isa<ast::UserConstraintDecl>(callableDecl)) 2630 return emitError( 2631 loc, "unable to invoke `Constraint` within a rewrite section"); 2632 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2633 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2634 } 2635 2636 // Verify the arguments 
of the call. 2637 /// Handle size mismatch. 2638 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2639 if (callArgs.size() != arguments.size()) { 2640 return emitErrorAndNote( 2641 loc, 2642 llvm::formatv("invalid number of arguments for {0} call; expected " 2643 "{1}, but got {2}", 2644 callableDecl->getCallableType(), callArgs.size(), 2645 arguments.size()), 2646 callableDecl->getLoc(), 2647 llvm::formatv("see the definition of {0} here", 2648 callableDecl->getName()->getName())); 2649 } 2650 2651 /// Handle argument type mismatch. 2652 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2653 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2654 callableDecl->getName()->getName()), 2655 callableDecl->getLoc()); 2656 }; 2657 for (auto it : llvm::zip(callArgs, arguments)) { 2658 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2659 attachDiagFn))) 2660 return failure(); 2661 } 2662 2663 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2664 callableDecl->getResultType()); 2665 } 2666 2667 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2668 ast::Decl *decl) { 2669 // Check the type of decl being referenced. 
2670 ast::Type declType; 2671 if (isa<ast::ConstraintDecl>(decl)) 2672 declType = ast::ConstraintType::get(ctx); 2673 else if (isa<ast::UserRewriteDecl>(decl)) 2674 declType = ast::RewriteType::get(ctx); 2675 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2676 declType = varDecl->getType(); 2677 else 2678 return emitError(loc, "invalid reference to `" + 2679 decl->getName()->getName() + "`"); 2680 2681 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2682 } 2683 2684 FailureOr<ast::DeclRefExpr *> 2685 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2686 ArrayRef<ast::ConstraintRef> constraints) { 2687 FailureOr<ast::VariableDecl *> decl = 2688 defineVariableDecl(name, loc, type, constraints); 2689 if (failed(decl)) 2690 return failure(); 2691 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2692 } 2693 2694 FailureOr<ast::MemberAccessExpr *> 2695 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2696 SMRange loc) { 2697 // Validate the member name for the given parent expression. 2698 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2699 if (failed(memberType)) 2700 return failure(); 2701 2702 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2703 } 2704 2705 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2706 StringRef name, SMRange loc) { 2707 ast::Type parentType = parentExpr->getType(); 2708 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2709 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2710 return valueRangeTy; 2711 2712 // Verify member access based on the operation type. 2713 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2714 auto results = odsOp->getResults(); 2715 2716 // Handle indexed results. 
2717 unsigned index = 0; 2718 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2719 index < results.size()) { 2720 return results[index].isVariadic() ? valueRangeTy : valueTy; 2721 } 2722 2723 // Handle named results. 2724 const auto *it = llvm::find_if(results, [&](const auto &result) { 2725 return result.getName() == name; 2726 }); 2727 if (it != results.end()) 2728 return it->isVariadic() ? valueRangeTy : valueTy; 2729 } else if (llvm::isDigit(name[0])) { 2730 // Allow unchecked numeric indexing of the results of unregistered 2731 // operations. It returns a single value. 2732 return valueTy; 2733 } 2734 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2735 // Handle indexed results. 2736 unsigned index = 0; 2737 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2738 index < tupleType.size()) { 2739 return tupleType.getElementTypes()[index]; 2740 } 2741 2742 // Handle named results. 2743 auto elementNames = tupleType.getElementNames(); 2744 const auto *it = llvm::find(elementNames, name); 2745 if (it != elementNames.end()) 2746 return tupleType.getElementTypes()[it - elementNames.begin()]; 2747 } 2748 return emitError( 2749 loc, 2750 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2751 name, parentType)); 2752 } 2753 2754 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2755 SMRange loc, const ast::OpNameDecl *name, 2756 OpResultTypeContext resultTypeContext, 2757 MutableArrayRef<ast::Expr *> operands, 2758 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2759 MutableArrayRef<ast::Expr *> results) { 2760 Optional<StringRef> opNameRef = name->getName(); 2761 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2762 2763 // Verify the inputs operands. 2764 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2765 return failure(); 2766 2767 // Verify the attribute list. 
2768 for (ast::NamedAttributeDecl *attr : attributes) { 2769 // Check for an attribute type, or a type awaiting resolution. 2770 ast::Type attrType = attr->getValue()->getType(); 2771 if (!attrType.isa<ast::AttributeType>()) { 2772 return emitError( 2773 attr->getValue()->getLoc(), 2774 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2775 } 2776 } 2777 2778 assert( 2779 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && 2780 "unexpected inferrence when results were explicitly specified"); 2781 2782 // If we aren't relying on type inferrence, or explicit results were provided, 2783 // validate them. 2784 if (resultTypeContext == OpResultTypeContext::Explicit) { 2785 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2786 return failure(); 2787 2788 // Validate the use of interface based type inferrence for this operation. 2789 } else if (resultTypeContext == OpResultTypeContext::Interface) { 2790 assert(opNameRef && 2791 "expected valid operation name when inferring operation results"); 2792 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp); 2793 } 2794 2795 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2796 attributes); 2797 } 2798 2799 LogicalResult 2800 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2801 const ods::Operation *odsOp, 2802 MutableArrayRef<ast::Expr *> operands) { 2803 return validateOperationOperandsOrResults( 2804 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2805 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2806 valueRangeTy); 2807 } 2808 2809 LogicalResult 2810 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2811 const ods::Operation *odsOp, 2812 MutableArrayRef<ast::Expr *> results) { 2813 return validateOperationOperandsOrResults( 2814 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2815 results, odsOp ? 
odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2816 } 2817 2818 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, 2819 const ods::Operation *odsOp) { 2820 // If the operation might not have inferrence support, emit a warning to the 2821 // user. We don't emit an error because the interface might be added to the 2822 // operation at runtime. It's rare, but it could still happen. We emit a 2823 // warning here instead. 2824 2825 // Handle inferrence warnings for unknown operations. 2826 if (!odsOp) { 2827 ctx.getDiagEngine().emitWarning( 2828 loc, llvm::formatv( 2829 "operation result types are marked to be inferred, but " 2830 "`{0}` is unknown. Ensure that `{0}` supports zero " 2831 "results or implements `InferTypeOpInterface`. Include " 2832 "the ODS definition of this operation to remove this warning.", 2833 opName)); 2834 return; 2835 } 2836 2837 // Handle inferrence warnings for known operations that expected at least one 2838 // result, but don't have inference support. An elided results list can mean 2839 // "zero-results", and we don't want to warn when that is the expected 2840 // behavior. 2841 bool requiresInferrence = 2842 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) { 2843 return !result.isVariableLength(); 2844 }); 2845 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { 2846 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( 2847 loc, 2848 llvm::formatv("operation result types are marked to be inferred, but " 2849 "`{0}` does not provide an implementation of " 2850 "`InferTypeOpInterface`. 
Ensure that `{0}` attaches " 2851 "`InferTypeOpInterface` at runtime, or add support to " 2852 "the ODS definition to remove this warning.", 2853 opName)); 2854 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName), 2855 odsOp->getLoc()); 2856 return; 2857 } 2858 } 2859 2860 LogicalResult Parser::validateOperationOperandsOrResults( 2861 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2862 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2863 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2864 ast::Type rangeTy) { 2865 // All operation types accept a single range parameter. 2866 if (values.size() == 1) { 2867 if (failed(convertExpressionTo(values[0], rangeTy))) 2868 return failure(); 2869 return success(); 2870 } 2871 2872 /// If the operation has ODS information, we can more accurately verify the 2873 /// values. 2874 if (odsOpLoc) { 2875 if (odsValues.size() != values.size()) { 2876 return emitErrorAndNote( 2877 loc, 2878 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2879 "{2}, but got {3}", 2880 groupName, *name, odsValues.size(), values.size()), 2881 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2882 } 2883 auto diagFn = [&](ast::Diagnostic &diag) { 2884 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2885 *odsOpLoc); 2886 }; 2887 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2888 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2889 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2890 return failure(); 2891 } 2892 return success(); 2893 } 2894 2895 // Otherwise, accept the value groups as they have been defined and just 2896 // ensure they are one of the expected types. 2897 for (ast::Expr *&valueExpr : values) { 2898 ast::Type valueExprType = valueExpr->getType(); 2899 2900 // Check if this is one of the expected types. 
2901 if (valueExprType == rangeTy || valueExprType == singleTy) 2902 continue; 2903 2904 // If the operand is an Operation, allow converting to a Value or 2905 // ValueRange. This situations arises quite often with nested operation 2906 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2907 if (singleTy == valueTy) { 2908 if (valueExprType.isa<ast::OperationType>()) { 2909 valueExpr = convertOpToValue(valueExpr); 2910 continue; 2911 } 2912 } 2913 2914 return emitError( 2915 valueExpr->getLoc(), 2916 llvm::formatv( 2917 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2918 singleTy, rangeTy, valueExprType)); 2919 } 2920 return success(); 2921 } 2922 2923 FailureOr<ast::TupleExpr *> 2924 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2925 ArrayRef<StringRef> elementNames) { 2926 for (const ast::Expr *element : elements) { 2927 ast::Type eleTy = element->getType(); 2928 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2929 return emitError( 2930 element->getLoc(), 2931 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2932 } 2933 } 2934 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2935 } 2936 2937 //===----------------------------------------------------------------------===// 2938 // Stmts 2939 2940 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2941 ast::Expr *rootOp) { 2942 // Check that root is an Operation. 2943 ast::Type rootType = rootOp->getType(); 2944 if (!rootType.isa<ast::OperationType>()) 2945 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2946 2947 return ast::EraseStmt::create(ctx, loc, rootOp); 2948 } 2949 2950 FailureOr<ast::ReplaceStmt *> 2951 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2952 MutableArrayRef<ast::Expr *> replValues) { 2953 // Check that root is an Operation. 
2954 ast::Type rootType = rootOp->getType(); 2955 if (!rootType.isa<ast::OperationType>()) { 2956 return emitError( 2957 rootOp->getLoc(), 2958 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2959 } 2960 2961 // If there are multiple replacement values, we implicitly convert any Op 2962 // expressions to the value form. 2963 bool shouldConvertOpToValues = replValues.size() > 1; 2964 for (ast::Expr *&replExpr : replValues) { 2965 ast::Type replType = replExpr->getType(); 2966 2967 // Check that replExpr is an Operation, Value, or ValueRange. 2968 if (replType.isa<ast::OperationType>()) { 2969 if (shouldConvertOpToValues) 2970 replExpr = convertOpToValue(replExpr); 2971 continue; 2972 } 2973 2974 if (replType != valueTy && replType != valueRangeTy) { 2975 return emitError(replExpr->getLoc(), 2976 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2977 "expression, but got `{0}`", 2978 replType)); 2979 } 2980 } 2981 2982 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2983 } 2984 2985 FailureOr<ast::RewriteStmt *> 2986 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2987 ast::CompoundStmt *rewriteBody) { 2988 // Check that root is an Operation. 
2989 ast::Type rootType = rootOp->getType(); 2990 if (!rootType.isa<ast::OperationType>()) { 2991 return emitError( 2992 rootOp->getLoc(), 2993 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2994 } 2995 2996 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2997 } 2998 2999 //===----------------------------------------------------------------------===// 3000 // Code Completion 3001 //===----------------------------------------------------------------------===// 3002 3003 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 3004 ast::Type parentType = parentExpr->getType(); 3005 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 3006 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 3007 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 3008 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 3009 return failure(); 3010 } 3011 3012 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 3013 if (opName) 3014 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 3015 return failure(); 3016 } 3017 3018 LogicalResult 3019 Parser::codeCompleteConstraintName(ast::Type inferredType, 3020 bool allowNonCoreConstraints, 3021 bool allowInlineTypeConstraints) { 3022 codeCompleteContext->codeCompleteConstraintName( 3023 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 3024 curDeclScope); 3025 return failure(); 3026 } 3027 3028 LogicalResult Parser::codeCompleteDialectName() { 3029 codeCompleteContext->codeCompleteDialectName(); 3030 return failure(); 3031 } 3032 3033 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 3034 codeCompleteContext->codeCompleteOperationName(dialectName); 3035 return failure(); 3036 } 3037 3038 LogicalResult Parser::codeCompletePatternMetadata() { 3039 codeCompleteContext->codeCompletePatternMetadata(); 3040 return failure(); 3041 } 3042 3043 
LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 3044 codeCompleteContext->codeCompleteIncludeFilename(curPath); 3045 return failure(); 3046 } 3047 3048 void Parser::codeCompleteCallSignature(ast::Node *parent, 3049 unsigned currentNumArgs) { 3050 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 3051 if (!callableDecl) 3052 return; 3053 3054 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 3055 } 3056 3057 void Parser::codeCompleteOperationOperandsSignature( 3058 Optional<StringRef> opName, unsigned currentNumOperands) { 3059 codeCompleteContext->codeCompleteOperationOperandsSignature( 3060 opName, currentNumOperands); 3061 } 3062 3063 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 3064 unsigned currentNumResults) { 3065 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 3066 currentNumResults); 3067 } 3068 3069 //===----------------------------------------------------------------------===// 3070 // Parser 3071 //===----------------------------------------------------------------------===// 3072 3073 FailureOr<ast::Module *> 3074 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 3075 CodeCompleteContext *codeCompleteContext) { 3076 Parser parser(ctx, sourceMgr, codeCompleteContext); 3077 return parser.parseModule(); 3078 } 3079