1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. `typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. `inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 428 429 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 430 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 431 unsigned currentNumOperands); 432 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 433 unsigned currentNumResults); 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 /// If the current token has the specified kind, consume it and return true. 440 /// If not, return false. 441 bool consumeIf(Token::Kind kind) { 442 if (curToken.isNot(kind)) 443 return false; 444 consumeToken(kind); 445 return true; 446 } 447 448 /// Advance the current lexer onto the next token. 449 void consumeToken() { 450 assert(curToken.isNot(Token::eof, Token::error) && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 /// Advance the current lexer onto the next token, asserting what the expected 456 /// current token is. This is preferred to the above method because it leads 457 /// to more self-documenting code with better checking. 458 void consumeToken(Token::Kind kind) { 459 assert(curToken.is(kind) && "consumed an unexpected token"); 460 consumeToken(); 461 } 462 463 /// Reset the lexer to the location at the given position. 464 void resetToken(SMRange tokLoc) { 465 lexer.resetPointer(tokLoc.Start.getPointer()); 466 curToken = lexer.lexToken(); 467 } 468 469 /// Consume the specified token if present and return success. On failure, 470 /// output a diagnostic and return failure. 471 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 472 if (curToken.getKind() != kind) 473 return emitError(curToken.getLoc(), msg); 474 consumeToken(); 475 return success(); 476 } 477 LogicalResult emitError(SMRange loc, const Twine &msg) { 478 lexer.emitError(loc, msg); 479 return failure(); 480 } 481 LogicalResult emitError(const Twine &msg) { 482 return emitError(curToken.getLoc(), msg); 483 } 484 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 485 const Twine ¬e) { 486 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 487 return failure(); 488 } 489 490 //===--------------------------------------------------------------------===// 491 // Fields 492 //===--------------------------------------------------------------------===// 493 494 /// The owning AST context. 495 ast::Context &ctx; 496 497 /// The lexer of this parser. 498 Lexer lexer; 499 500 /// The current token within the lexer. 501 Token curToken; 502 503 /// The most recently defined decl scope. 504 ast::DeclScope *curDeclScope = nullptr; 505 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 506 507 /// The current context of the parser. 508 ParserContext parserContext = ParserContext::Global; 509 510 /// Cached types to simplify verification and expression creation. 511 ast::Type valueTy, valueRangeTy; 512 ast::Type typeTy, typeRangeTy; 513 ast::Type attrTy; 514 515 /// A counter used when naming anonymous constraints and rewrites. 516 unsigned anonymousDeclNameCounter = 0; 517 518 /// The optional code completion context. 519 CodeCompleteContext *codeCompleteContext; 520 }; 521 } // namespace 522 523 FailureOr<ast::Module *> Parser::parseModule() { 524 SMLoc moduleLoc = curToken.getStartLoc(); 525 pushDeclScope(); 526 527 // Parse the top-level decls of the module. 528 SmallVector<ast::Decl *> decls; 529 if (failed(parseModuleBody(decls))) 530 return popDeclScope(), failure(); 531 532 popDeclScope(); 533 return ast::Module::create(ctx, moduleLoc, decls); 534 } 535 536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 537 while (curToken.isNot(Token::eof)) { 538 if (curToken.is(Token::directive)) { 539 if (failed(parseDirective(decls))) 540 return failure(); 541 continue; 542 } 543 544 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 545 if (failed(decl)) 546 return failure(); 547 decls.push_back(*decl); 548 } 549 return success(); 550 } 551 552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 553 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 554 valueRangeTy); 555 } 556 557 LogicalResult Parser::convertExpressionTo( 558 ast::Expr *&expr, ast::Type type, 559 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 560 ast::Type exprType = expr->getType(); 561 if (exprType == type) 562 return success(); 563 564 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 565 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 566 expr->getLoc(), llvm::formatv("unable to convert expression of type " 567 "`{0}` to the expected type of " 568 "`{1}`", 569 exprType, type)); 570 if (noteAttachFn) 571 noteAttachFn(*diag); 572 return diag; 573 }; 574 575 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 576 // Two operation types are compatible if they have the same name, or if the 577 // expected type is more general. 578 if (auto opType = type.dyn_cast<ast::OperationType>()) { 579 if (opType.getName()) 580 return emitConvertError(); 581 return success(); 582 } 583 584 // An operation can always convert to a ValueRange. 585 if (type == valueRangeTy) { 586 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 587 valueRangeTy); 588 return success(); 589 } 590 591 // Allow conversion to a single value by constraining the result range. 592 if (type == valueTy) { 593 // If the operation is registered, we can verify if it can ever have a 594 // single result. 595 Optional<StringRef> opName = exprOpType.getName(); 596 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 597 if (odsOp->getResults().empty()) { 598 return emitConvertError()->attachNote( 599 llvm::formatv("see the definition of `{0}`, which was defined " 600 "with zero results", 601 odsOp->getName()), 602 odsOp->getLoc()); 603 } 604 605 unsigned numSingleResults = llvm::count_if( 606 odsOp->getResults(), [](const ods::OperandOrResult &result) { 607 return result.getVariableLengthKind() == 608 ods::VariableLengthKind::Single; 609 }); 610 if (numSingleResults > 1) { 611 return emitConvertError()->attachNote( 612 llvm::formatv("see the definition of `{0}`, which was defined " 613 "with at least {1} results", 614 odsOp->getName(), numSingleResults), 615 odsOp->getLoc()); 616 } 617 } 618 619 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 620 valueTy); 621 return success(); 622 } 623 return emitConvertError(); 624 } 625 626 // FIXME: Decide how to allow/support converting a single result to multiple, 627 // and multiple to a single result. For now, we just allow Single->Range, 628 // but this isn't something really supported in the PDL dialect. We should 629 // figure out some way to support both. 630 if ((exprType == valueTy || exprType == valueRangeTy) && 631 (type == valueTy || type == valueRangeTy)) 632 return success(); 633 if ((exprType == typeTy || exprType == typeRangeTy) && 634 (type == typeTy || type == typeRangeTy)) 635 return success(); 636 637 // Handle tuple types. 638 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 639 auto tupleType = type.dyn_cast<ast::TupleType>(); 640 if (!tupleType || tupleType.size() != exprTupleType.size()) 641 return emitConvertError(); 642 643 // Build a new tuple expression using each of the elements of the current 644 // tuple. 645 SmallVector<ast::Expr *> newExprs; 646 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 647 newExprs.push_back(ast::MemberAccessExpr::create( 648 ctx, expr->getLoc(), expr, llvm::to_string(i), 649 exprTupleType.getElementTypes()[i])); 650 651 auto diagFn = [&](ast::Diagnostic &diag) { 652 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 653 i, exprTupleType)); 654 if (noteAttachFn) 655 noteAttachFn(diag); 656 }; 657 if (failed(convertExpressionTo(newExprs.back(), 658 tupleType.getElementTypes()[i], diagFn))) 659 return failure(); 660 } 661 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 662 tupleType.getElementNames()); 663 return success(); 664 } 665 666 return emitConvertError(); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Directives 671 672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 673 StringRef directive = curToken.getSpelling(); 674 if (directive == "#include") 675 return parseInclude(decls); 676 677 return emitError("unknown directive `" + directive + "`"); 678 } 679 680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 681 SMRange loc = curToken.getLoc(); 682 consumeToken(Token::directive); 683 684 // Handle code completion of the include file path. 685 if (curToken.is(Token::code_complete_string)) 686 return codeCompleteIncludeFilename(curToken.getStringValue()); 687 688 // Parse the file being included. 689 if (!curToken.isString()) 690 return emitError(loc, 691 "expected string file name after `include` directive"); 692 SMRange fileLoc = curToken.getLoc(); 693 std::string filenameStr = curToken.getStringValue(); 694 StringRef filename = filenameStr; 695 consumeToken(); 696 697 // Check the type of include. If ending with `.pdll`, this is another pdl file 698 // to be parsed along with the current module. 699 if (filename.endswith(".pdll")) { 700 if (failed(lexer.pushInclude(filename, fileLoc))) 701 return emitError(fileLoc, 702 "unable to open include file `" + filename + "`"); 703 704 // If we added the include successfully, parse it into the current module. 705 // Make sure to update to the next token after we finish parsing the nested 706 // file. 707 curToken = lexer.lexToken(); 708 LogicalResult result = parseModuleBody(decls); 709 curToken = lexer.lexToken(); 710 return result; 711 } 712 713 // Otherwise, this must be a `.td` include. 714 if (filename.endswith(".td")) 715 return parseTdInclude(filename, fileLoc, decls); 716 717 return emitError(fileLoc, 718 "expected include filename to end with `.pdll` or `.td`"); 719 } 720 721 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 722 SmallVectorImpl<ast::Decl *> &decls) { 723 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 724 725 // This class provides a context argument for the llvm::SourceMgr diagnostic 726 // handler. 727 struct DiagHandlerContext { 728 Parser &parser; 729 StringRef filename; 730 llvm::SMRange loc; 731 } handlerContext{*this, filename, fileLoc}; 732 733 // Set the diagnostic handler for the tablegen source manager. 734 llvm::SrcMgr.setDiagHandler( 735 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 736 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 737 (void)ctx->parser.emitError( 738 ctx->loc, 739 llvm::formatv("error while processing include file `{0}`: {1}", 740 ctx->filename, diag.getMessage())); 741 }, 742 &handlerContext); 743 744 // Use the source manager to open the file, but don't yet add it. 745 std::string includedFile; 746 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 747 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 748 if (!includeBuffer) 749 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 750 751 auto processFn = [&](llvm::RecordKeeper &records) { 752 processTdIncludeRecords(records, decls); 753 754 // After we are done processing, move all of the tablegen source buffers to 755 // the main parser source mgr. This allows for directly using source 756 // locations from the .td files without needing to remap them. 757 parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr, fileLoc.End); 758 return false; 759 }; 760 if (llvm::TableGenParseFile(std::move(*includeBuffer), 761 parserSrcMgr.getIncludeDirs(), processFn)) 762 return failure(); 763 764 return success(); 765 } 766 767 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 768 SmallVectorImpl<ast::Decl *> &decls) { 769 // Return the length kind of the given value. 770 auto getLengthKind = [](const auto &value) { 771 if (value.isOptional()) 772 return ods::VariableLengthKind::Optional; 773 return value.isVariadic() ? ods::VariableLengthKind::Variadic 774 : ods::VariableLengthKind::Single; 775 }; 776 777 // Insert a type constraint into the ODS context. 778 ods::Context &odsContext = ctx.getODSContext(); 779 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 780 -> const ods::TypeConstraint & { 781 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 782 cst.constraint.getSummary(), 783 cst.constraint.getCPPClassName()); 784 }; 785 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 786 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 787 }; 788 789 // Process the parsed tablegen records to build ODS information. 790 /// Operations. 791 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 792 tblgen::Operator op(def); 793 794 bool inserted = false; 795 ods::Operation *odsOp = nullptr; 796 std::tie(odsOp, inserted) = 797 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 798 op.getDescription(), op.getLoc().front()); 799 800 // Ignore operations that have already been added. 801 if (!inserted) 802 continue; 803 804 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 805 odsOp->appendAttribute( 806 attr.name, attr.attr.isOptional(), 807 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 808 attr.attr.getSummary(), 809 attr.attr.getStorageType())); 810 } 811 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 812 odsOp->appendOperand(operand.name, getLengthKind(operand), 813 addTypeConstraint(operand)); 814 } 815 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 816 odsOp->appendResult(result.name, getLengthKind(result), 817 addTypeConstraint(result)); 818 } 819 } 820 /// Attr constraints. 821 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 822 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 823 decls.push_back( 824 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 825 tblgen::AttrConstraint(def), 826 convertLocToRange(def->getLoc().front()), attrTy)); 827 } 828 } 829 /// Type constraints. 830 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 831 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 832 decls.push_back( 833 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 834 tblgen::TypeConstraint(def), 835 convertLocToRange(def->getLoc().front()), typeTy)); 836 } 837 } 838 /// Interfaces. 839 ast::Type opTy = ast::OperationType::get(ctx); 840 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 841 StringRef name = def->getName(); 842 if (def->isAnonymous() || curDeclScope->lookup(name) || 843 def->isSubClassOf("DeclareInterfaceMethods")) 844 continue; 845 SMRange loc = convertLocToRange(def->getLoc().front()); 846 847 StringRef className = def->getValueAsString("cppClassName"); 848 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 849 std::string codeBlock = 850 llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));", 851 cppNamespace, className) 852 .str(); 853 854 if (def->isSubClassOf("OpInterface")) { 855 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 856 name, codeBlock, loc, opTy)); 857 } else if (def->isSubClassOf("AttrInterface")) { 858 decls.push_back( 859 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 860 name, codeBlock, loc, attrTy)); 861 } else if (def->isSubClassOf("TypeInterface")) { 862 decls.push_back( 863 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 864 name, codeBlock, loc, typeTy)); 865 } 866 } 867 } 868 869 template <typename ConstraintT> 870 ast::Decl * 871 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 872 SMRange loc, ast::Type type) { 873 // Build the single input parameter. 874 ast::DeclScope *argScope = pushDeclScope(); 875 auto *paramVar = ast::VariableDecl::create( 876 ctx, ast::Name::create(ctx, "self", loc), type, 877 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 878 argScope->add(paramVar); 879 popDeclScope(); 880 881 // Build the native constraint. 882 auto *constraintDecl = ast::UserConstraintDecl::createNative( 883 ctx, ast::Name::create(ctx, name, loc), paramVar, 884 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 885 curDeclScope->add(constraintDecl); 886 return constraintDecl; 887 } 888 889 template <typename ConstraintT> 890 ast::Decl * 891 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 892 SMRange loc, ast::Type type) { 893 // Format the condition template. 894 tblgen::FmtContext fmtContext; 895 fmtContext.withSelf("self"); 896 std::string codeBlock = tblgen::tgfmt( 897 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 898 &fmtContext); 899 900 return createODSNativePDLLConstraintDecl<ConstraintT>( 901 constraint.getUniqueDefName(), codeBlock, loc, type); 902 } 903 904 //===----------------------------------------------------------------------===// 905 // Decls 906 907 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 908 FailureOr<ast::Decl *> decl; 909 switch (curToken.getKind()) { 910 case Token::kw_Constraint: 911 decl = parseUserConstraintDecl(); 912 break; 913 case Token::kw_Pattern: 914 decl = parsePatternDecl(); 915 break; 916 case Token::kw_Rewrite: 917 decl = parseUserRewriteDecl(); 918 break; 919 default: 920 return emitError("expected top-level declaration, such as a `Pattern`"); 921 } 922 if (failed(decl)) 923 return failure(); 924 925 // If the decl has a name, add it to the current scope. 926 if (const ast::Name *name = (*decl)->getName()) { 927 if (failed(checkDefineNamedDecl(*name))) 928 return failure(); 929 curDeclScope->add(*decl); 930 } 931 return decl; 932 } 933 934 FailureOr<ast::NamedAttributeDecl *> 935 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 936 // Check for name code completion. 937 if (curToken.is(Token::code_complete)) 938 return codeCompleteAttributeName(parentOpName); 939 940 std::string attrNameStr; 941 if (curToken.isString()) 942 attrNameStr = curToken.getStringValue(); 943 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 944 attrNameStr = curToken.getSpelling().str(); 945 else 946 return emitError("expected identifier or string attribute name"); 947 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 948 consumeToken(); 949 950 // Check for a value of the attribute. 951 ast::Expr *attrValue = nullptr; 952 if (consumeIf(Token::equal)) { 953 FailureOr<ast::Expr *> attrExpr = parseExpr(); 954 if (failed(attrExpr)) 955 return failure(); 956 attrValue = *attrExpr; 957 } else { 958 // If there isn't a concrete value, create an expression representing a 959 // UnitAttr. 960 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 961 } 962 963 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 964 } 965 966 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 967 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 968 bool expectTerminalSemicolon) { 969 consumeToken(Token::equal_arrow); 970 971 // Parse the single statement of the lambda body. 972 SMLoc bodyStartLoc = curToken.getStartLoc(); 973 pushDeclScope(); 974 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 975 bool failedToParse = 976 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 977 popDeclScope(); 978 if (failedToParse) 979 return failure(); 980 981 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 982 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 983 } 984 985 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 986 // Ensure that the argument is named. 987 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 988 return emitError("expected identifier argument name"); 989 990 // Parse the argument similarly to a normal variable. 991 StringRef name = curToken.getSpelling(); 992 SMRange nameLoc = curToken.getLoc(); 993 consumeToken(); 994 995 if (failed( 996 parseToken(Token::colon, "expected `:` before argument constraint"))) 997 return failure(); 998 999 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1000 if (failed(cst)) 1001 return failure(); 1002 1003 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1004 } 1005 1006 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1007 // Check to see if this result is named. 1008 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1009 // Check to see if this name actually refers to a Constraint. 1010 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1011 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1012 // If yes, and this is a Rewrite, give a nice error message as non-Core 1013 // constraints are not supported on Rewrite results. 1014 if (parserContext == ParserContext::Rewrite) { 1015 return emitError( 1016 "`Rewrite` results are only permitted to use core constraints, " 1017 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1018 } 1019 1020 // Otherwise, parse this as an unnamed result variable. 1021 } else { 1022 // If it wasn't a constraint, parse the result similarly to a variable. If 1023 // there is already an existing decl, we will emit an error when defining 1024 // this variable later. 1025 StringRef name = curToken.getSpelling(); 1026 SMRange nameLoc = curToken.getLoc(); 1027 consumeToken(); 1028 1029 if (failed(parseToken(Token::colon, 1030 "expected `:` before result constraint"))) 1031 return failure(); 1032 1033 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1034 if (failed(cst)) 1035 return failure(); 1036 1037 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1038 } 1039 } 1040 1041 // If it isn't named, we parse the constraint directly and create an unnamed 1042 // result variable. 1043 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1044 if (failed(cst)) 1045 return failure(); 1046 1047 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1048 } 1049 1050 FailureOr<ast::UserConstraintDecl *> 1051 Parser::parseUserConstraintDecl(bool isInline) { 1052 // Constraints and rewrites have very similar formats, dispatch to a shared 1053 // interface for parsing. 1054 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1055 [&](auto &&...args) { 1056 return this->parseUserPDLLConstraintDecl(args...); 1057 }, 1058 ParserContext::Constraint, "constraint", isInline); 1059 } 1060 1061 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1062 FailureOr<ast::UserConstraintDecl *> decl = 1063 parseUserConstraintDecl(/*isInline=*/true); 1064 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1065 return failure(); 1066 1067 curDeclScope->add(*decl); 1068 return decl; 1069 } 1070 1071 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1072 const ast::Name &name, bool isInline, 1073 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1074 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1075 // Push the argument scope back onto the list, so that the body can 1076 // reference arguments. 1077 pushDeclScope(argumentScope); 1078 1079 // Parse the body of the constraint. The body is either defined as a compound 1080 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1081 ast::CompoundStmt *body; 1082 if (curToken.is(Token::equal_arrow)) { 1083 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1084 [&](ast::Stmt *&stmt) -> LogicalResult { 1085 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1086 if (!stmtExpr) { 1087 return emitError(stmt->getLoc(), 1088 "expected `Constraint` lambda body to contain a " 1089 "single expression"); 1090 } 1091 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1092 return success(); 1093 }, 1094 /*expectTerminalSemicolon=*/!isInline); 1095 if (failed(bodyResult)) 1096 return failure(); 1097 body = *bodyResult; 1098 } else { 1099 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1100 if (failed(bodyResult)) 1101 return failure(); 1102 body = *bodyResult; 1103 1104 // Verify the structure of the body. 1105 auto bodyIt = body->begin(), bodyE = body->end(); 1106 for (; bodyIt != bodyE; ++bodyIt) 1107 if (isa<ast::ReturnStmt>(*bodyIt)) 1108 break; 1109 if (failed(validateUserConstraintOrRewriteReturn( 1110 "Constraint", body, bodyIt, bodyE, results, resultType))) 1111 return failure(); 1112 } 1113 popDeclScope(); 1114 1115 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1116 name, arguments, results, resultType, body); 1117 } 1118 1119 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1120 // Constraints and rewrites have very similar formats, dispatch to a shared 1121 // interface for parsing. 1122 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1123 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1124 ParserContext::Rewrite, "rewrite", isInline); 1125 } 1126 1127 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1128 FailureOr<ast::UserRewriteDecl *> decl = 1129 parseUserRewriteDecl(/*isInline=*/true); 1130 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1131 return failure(); 1132 1133 curDeclScope->add(*decl); 1134 return decl; 1135 } 1136 1137 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1138 const ast::Name &name, bool isInline, 1139 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1140 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1141 // Push the argument scope back onto the list, so that the body can 1142 // reference arguments. 1143 curDeclScope = argumentScope; 1144 ast::CompoundStmt *body; 1145 if (curToken.is(Token::equal_arrow)) { 1146 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1147 [&](ast::Stmt *&statement) -> LogicalResult { 1148 if (isa<ast::OpRewriteStmt>(statement)) 1149 return success(); 1150 1151 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1152 if (!statementExpr) { 1153 return emitError( 1154 statement->getLoc(), 1155 "expected `Rewrite` lambda body to contain a single expression " 1156 "or an operation rewrite statement; such as `erase`, " 1157 "`replace`, or `rewrite`"); 1158 } 1159 statement = 1160 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1161 return success(); 1162 }, 1163 /*expectTerminalSemicolon=*/!isInline); 1164 if (failed(bodyResult)) 1165 return failure(); 1166 body = *bodyResult; 1167 } else { 1168 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1169 if (failed(bodyResult)) 1170 return failure(); 1171 body = *bodyResult; 1172 } 1173 popDeclScope(); 1174 1175 // Verify the structure of the body. 1176 auto bodyIt = body->begin(), bodyE = body->end(); 1177 for (; bodyIt != bodyE; ++bodyIt) 1178 if (isa<ast::ReturnStmt>(*bodyIt)) 1179 break; 1180 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1181 bodyE, results, resultType))) 1182 return failure(); 1183 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1184 name, arguments, results, resultType, body); 1185 } 1186 1187 template <typename T, typename ParseUserPDLLDeclFnT> 1188 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1189 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1190 StringRef anonymousNamePrefix, bool isInline) { 1191 SMRange loc = curToken.getLoc(); 1192 consumeToken(); 1193 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1194 1195 // Parse the name of the decl. 1196 const ast::Name *name = nullptr; 1197 if (curToken.isNot(Token::identifier)) { 1198 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1199 // in C++, so being unnamed is fine. 1200 if (!isInline) 1201 return emitError("expected identifier name"); 1202 1203 // Create a unique anonymous name to use, as the name for this decl is not 1204 // important. 1205 std::string anonName = 1206 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1207 anonymousDeclNameCounter++) 1208 .str(); 1209 name = &ast::Name::create(ctx, anonName, loc); 1210 } else { 1211 // If a name was provided, we can use it directly. 1212 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1213 consumeToken(Token::identifier); 1214 } 1215 1216 // Parse the functional signature of the decl. 1217 SmallVector<ast::VariableDecl *> arguments, results; 1218 ast::DeclScope *argumentScope; 1219 ast::Type resultType; 1220 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1221 argumentScope, resultType))) 1222 return failure(); 1223 1224 // Check to see which type of constraint this is. If the constraint contains a 1225 // compound body, this is a PDLL decl. 1226 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1227 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1228 resultType); 1229 1230 // Otherwise, this is a native decl. 1231 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1232 results, resultType); 1233 } 1234 1235 template <typename T> 1236 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1237 const ast::Name &name, bool isInline, 1238 ArrayRef<ast::VariableDecl *> arguments, 1239 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1240 // If followed by a string, the native code body has also been specified. 1241 std::string codeStrStorage; 1242 Optional<StringRef> optCodeStr; 1243 if (curToken.isString()) { 1244 codeStrStorage = curToken.getStringValue(); 1245 optCodeStr = codeStrStorage; 1246 consumeToken(); 1247 } else if (isInline) { 1248 return emitError(name.getLoc(), 1249 "external declarations must be declared in global scope"); 1250 } else if (curToken.is(Token::error)) { 1251 return failure(); 1252 } 1253 if (failed(parseToken(Token::semicolon, 1254 "expected `;` after native declaration"))) 1255 return failure(); 1256 // TODO: PDL should be able to support constraint results in certain 1257 // situations, we should revise this. 1258 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1259 return emitError( 1260 "native Constraints currently do not support returning results"); 1261 } 1262 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1263 } 1264 1265 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1266 SmallVectorImpl<ast::VariableDecl *> &arguments, 1267 SmallVectorImpl<ast::VariableDecl *> &results, 1268 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1269 // Parse the argument list of the decl. 1270 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1271 return failure(); 1272 1273 argumentScope = pushDeclScope(); 1274 if (curToken.isNot(Token::r_paren)) { 1275 do { 1276 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1277 if (failed(argument)) 1278 return failure(); 1279 arguments.emplace_back(*argument); 1280 } while (consumeIf(Token::comma)); 1281 } 1282 popDeclScope(); 1283 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1284 return failure(); 1285 1286 // Parse the results of the decl. 1287 pushDeclScope(); 1288 if (consumeIf(Token::arrow)) { 1289 auto parseResultFn = [&]() -> LogicalResult { 1290 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1291 if (failed(result)) 1292 return failure(); 1293 results.emplace_back(*result); 1294 return success(); 1295 }; 1296 1297 // Check for a list of results. 1298 if (consumeIf(Token::l_paren)) { 1299 do { 1300 if (failed(parseResultFn())) 1301 return failure(); 1302 } while (consumeIf(Token::comma)); 1303 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1304 return failure(); 1305 1306 // Otherwise, there is only one result. 1307 } else if (failed(parseResultFn())) { 1308 return failure(); 1309 } 1310 } 1311 popDeclScope(); 1312 1313 // Compute the result type of the decl. 1314 resultType = createUserConstraintRewriteResultType(results); 1315 1316 // Verify that results are only named if there are more than one. 1317 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1318 return emitError( 1319 results.front()->getLoc(), 1320 "cannot create a single-element tuple with an element label"); 1321 } 1322 return success(); 1323 } 1324 1325 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1326 StringRef declType, ast::CompoundStmt *body, 1327 ArrayRef<ast::Stmt *>::iterator bodyIt, 1328 ArrayRef<ast::Stmt *>::iterator bodyE, 1329 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1330 // Handle if a `return` was provided. 1331 if (bodyIt != bodyE) { 1332 // Emit an error if we have trailing statements after the return. 1333 if (std::next(bodyIt) != bodyE) { 1334 return emitError( 1335 (*std::next(bodyIt))->getLoc(), 1336 llvm::formatv("`return` terminated the `{0}` body, but found " 1337 "trailing statements afterwards", 1338 declType)); 1339 } 1340 1341 // Otherwise if a return wasn't provided, check that no results are 1342 // expected. 1343 } else if (!results.empty()) { 1344 return emitError( 1345 {body->getLoc().End, body->getLoc().End}, 1346 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1347 declType, resultType)); 1348 } 1349 return success(); 1350 } 1351 1352 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1353 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1354 if (isa<ast::OpRewriteStmt>(statement)) 1355 return success(); 1356 return emitError( 1357 statement->getLoc(), 1358 "expected Pattern lambda body to contain a single operation " 1359 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1360 }); 1361 } 1362 1363 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1364 SMRange loc = curToken.getLoc(); 1365 consumeToken(Token::kw_Pattern); 1366 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1367 ParserContext::PatternMatch); 1368 1369 // Check for an optional identifier for the pattern name. 1370 const ast::Name *name = nullptr; 1371 if (curToken.is(Token::identifier)) { 1372 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1373 consumeToken(Token::identifier); 1374 } 1375 1376 // Parse any pattern metadata. 1377 ParsedPatternMetadata metadata; 1378 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1379 return failure(); 1380 1381 // Parse the pattern body. 1382 ast::CompoundStmt *body; 1383 1384 // Handle a lambda body. 1385 if (curToken.is(Token::equal_arrow)) { 1386 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1387 if (failed(bodyResult)) 1388 return failure(); 1389 body = *bodyResult; 1390 } else { 1391 if (curToken.isNot(Token::l_brace)) 1392 return emitError("expected `{` or `=>` to start pattern body"); 1393 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1394 if (failed(bodyResult)) 1395 return failure(); 1396 body = *bodyResult; 1397 1398 // Verify the body of the pattern. 1399 auto bodyIt = body->begin(), bodyE = body->end(); 1400 for (; bodyIt != bodyE; ++bodyIt) { 1401 if (isa<ast::ReturnStmt>(*bodyIt)) { 1402 return emitError((*bodyIt)->getLoc(), 1403 "`return` statements are only permitted within a " 1404 "`Constraint` or `Rewrite` body"); 1405 } 1406 // Break when we've found the rewrite statement. 1407 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1408 break; 1409 } 1410 if (bodyIt == bodyE) { 1411 return emitError(loc, 1412 "expected Pattern body to terminate with an operation " 1413 "rewrite statement, such as `erase`"); 1414 } 1415 if (std::next(bodyIt) != bodyE) { 1416 return emitError((*std::next(bodyIt))->getLoc(), 1417 "Pattern body was terminated by an operation " 1418 "rewrite statement, but found trailing statements"); 1419 } 1420 } 1421 1422 return createPatternDecl(loc, name, metadata, body); 1423 } 1424 1425 LogicalResult 1426 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1427 Optional<SMRange> benefitLoc; 1428 Optional<SMRange> hasBoundedRecursionLoc; 1429 1430 do { 1431 // Handle metadata code completion. 1432 if (curToken.is(Token::code_complete)) 1433 return codeCompletePatternMetadata(); 1434 1435 if (curToken.isNot(Token::identifier)) 1436 return emitError("expected pattern metadata identifier"); 1437 StringRef metadataStr = curToken.getSpelling(); 1438 SMRange metadataLoc = curToken.getLoc(); 1439 consumeToken(Token::identifier); 1440 1441 // Parse the benefit metadata: benefit(<integer-value>) 1442 if (metadataStr == "benefit") { 1443 if (benefitLoc) { 1444 return emitErrorAndNote(metadataLoc, 1445 "pattern benefit has already been specified", 1446 *benefitLoc, "see previous definition here"); 1447 } 1448 if (failed(parseToken(Token::l_paren, 1449 "expected `(` before pattern benefit"))) 1450 return failure(); 1451 1452 uint16_t benefitValue = 0; 1453 if (curToken.isNot(Token::integer)) 1454 return emitError("expected integral pattern benefit"); 1455 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1456 return emitError( 1457 "expected pattern benefit to fit within a 16-bit integer"); 1458 consumeToken(Token::integer); 1459 1460 metadata.benefit = benefitValue; 1461 benefitLoc = metadataLoc; 1462 1463 if (failed( 1464 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1465 return failure(); 1466 continue; 1467 } 1468 1469 // Parse the bounded recursion metadata: recursion 1470 if (metadataStr == "recursion") { 1471 if (hasBoundedRecursionLoc) { 1472 return emitErrorAndNote( 1473 metadataLoc, 1474 "pattern recursion metadata has already been specified", 1475 *hasBoundedRecursionLoc, "see previous definition here"); 1476 } 1477 metadata.hasBoundedRecursion = true; 1478 hasBoundedRecursionLoc = metadataLoc; 1479 continue; 1480 } 1481 1482 return emitError(metadataLoc, "unknown pattern metadata"); 1483 } while (consumeIf(Token::comma)); 1484 1485 return success(); 1486 } 1487 1488 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1489 consumeToken(Token::less); 1490 1491 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1492 if (failed(typeExpr) || 1493 failed(parseToken(Token::greater, 1494 "expected `>` after variable type constraint"))) 1495 return failure(); 1496 return typeExpr; 1497 } 1498 1499 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1500 assert(curDeclScope && "defining decl outside of a decl scope"); 1501 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1502 return emitErrorAndNote( 1503 name.getLoc(), "`" + name.getName() + "` has already been defined", 1504 lastDecl->getName()->getLoc(), "see previous definition here"); 1505 } 1506 return success(); 1507 } 1508 1509 FailureOr<ast::VariableDecl *> 1510 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1511 ast::Expr *initExpr, 1512 ArrayRef<ast::ConstraintRef> constraints) { 1513 assert(curDeclScope && "defining variable outside of decl scope"); 1514 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1515 1516 // If the name of the variable indicates a special variable, we don't add it 1517 // to the scope. This variable is local to the definition point. 1518 if (name.empty() || name == "_") { 1519 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1520 constraints); 1521 } 1522 if (failed(checkDefineNamedDecl(nameDecl))) 1523 return failure(); 1524 1525 auto *varDecl = 1526 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1527 curDeclScope->add(varDecl); 1528 return varDecl; 1529 } 1530 1531 FailureOr<ast::VariableDecl *> 1532 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1533 ArrayRef<ast::ConstraintRef> constraints) { 1534 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1535 constraints); 1536 } 1537 1538 LogicalResult Parser::parseVariableDeclConstraintList( 1539 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1540 Optional<SMRange> typeConstraint; 1541 auto parseSingleConstraint = [&] { 1542 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1543 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1544 /*allowNonCoreConstraints=*/true); 1545 if (failed(constraint)) 1546 return failure(); 1547 constraints.push_back(*constraint); 1548 return success(); 1549 }; 1550 1551 // Check to see if this is a single constraint, or a list. 1552 if (!consumeIf(Token::l_square)) 1553 return parseSingleConstraint(); 1554 1555 do { 1556 if (failed(parseSingleConstraint())) 1557 return failure(); 1558 } while (consumeIf(Token::comma)); 1559 return parseToken(Token::r_square, "expected `]` after constraint list"); 1560 } 1561 1562 FailureOr<ast::ConstraintRef> 1563 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1564 ArrayRef<ast::ConstraintRef> existingConstraints, 1565 bool allowInlineTypeConstraints, 1566 bool allowNonCoreConstraints) { 1567 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1568 if (!allowInlineTypeConstraints) { 1569 return emitError( 1570 curToken.getLoc(), 1571 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1572 "permitted on arguments or results"); 1573 } 1574 if (typeConstraint) 1575 return emitErrorAndNote( 1576 curToken.getLoc(), 1577 "the type of this variable has already been constrained", 1578 *typeConstraint, "see previous constraint location here"); 1579 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1580 if (failed(constraintExpr)) 1581 return failure(); 1582 typeExpr = *constraintExpr; 1583 typeConstraint = typeExpr->getLoc(); 1584 return success(); 1585 }; 1586 1587 SMRange loc = curToken.getLoc(); 1588 switch (curToken.getKind()) { 1589 case Token::kw_Attr: { 1590 consumeToken(Token::kw_Attr); 1591 1592 // Check for a type constraint. 1593 ast::Expr *typeExpr = nullptr; 1594 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1595 return failure(); 1596 return ast::ConstraintRef( 1597 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1598 } 1599 case Token::kw_Op: { 1600 consumeToken(Token::kw_Op); 1601 1602 // Parse an optional operation name. If the name isn't provided, this refers 1603 // to "any" operation. 1604 FailureOr<ast::OpNameDecl *> opName = 1605 parseWrappedOperationName(/*allowEmptyName=*/true); 1606 if (failed(opName)) 1607 return failure(); 1608 1609 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1610 loc); 1611 } 1612 case Token::kw_Type: 1613 consumeToken(Token::kw_Type); 1614 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1615 case Token::kw_TypeRange: 1616 consumeToken(Token::kw_TypeRange); 1617 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1618 loc); 1619 case Token::kw_Value: { 1620 consumeToken(Token::kw_Value); 1621 1622 // Check for a type constraint. 1623 ast::Expr *typeExpr = nullptr; 1624 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1625 return failure(); 1626 1627 return ast::ConstraintRef( 1628 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1629 } 1630 case Token::kw_ValueRange: { 1631 consumeToken(Token::kw_ValueRange); 1632 1633 // Check for a type constraint. 1634 ast::Expr *typeExpr = nullptr; 1635 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1636 return failure(); 1637 1638 return ast::ConstraintRef( 1639 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1640 } 1641 1642 case Token::kw_Constraint: { 1643 // Handle an inline constraint. 1644 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1645 if (failed(decl)) 1646 return failure(); 1647 return ast::ConstraintRef(*decl, loc); 1648 } 1649 case Token::identifier: { 1650 StringRef constraintName = curToken.getSpelling(); 1651 consumeToken(Token::identifier); 1652 1653 // Lookup the referenced constraint. 1654 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1655 if (!cstDecl) { 1656 return emitError(loc, "unknown reference to constraint `" + 1657 constraintName + "`"); 1658 } 1659 1660 // Handle a reference to a proper constraint. 1661 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1662 return ast::ConstraintRef(cst, loc); 1663 1664 return emitErrorAndNote( 1665 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1666 "see the definition of `" + constraintName + "` here"); 1667 } 1668 // Handle single entity constraint code completion. 1669 case Token::code_complete: { 1670 // Try to infer the current type for use by code completion. 1671 ast::Type inferredType; 1672 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1673 allowNonCoreConstraints))) 1674 return failure(); 1675 1676 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1677 allowInlineTypeConstraints); 1678 } 1679 default: 1680 break; 1681 } 1682 return emitError(loc, "expected identifier constraint"); 1683 } 1684 1685 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1686 // Constraint arguments may apply more complex constraints via the arguments. 1687 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1688 1689 Optional<SMRange> typeConstraint; 1690 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1691 /*allowInlineTypeConstraints=*/false, 1692 allowNonCoreConstraints); 1693 } 1694 1695 //===----------------------------------------------------------------------===// 1696 // Exprs 1697 1698 FailureOr<ast::Expr *> Parser::parseExpr() { 1699 if (curToken.is(Token::underscore)) 1700 return parseUnderscoreExpr(); 1701 1702 // Parse the LHS expression. 1703 FailureOr<ast::Expr *> lhsExpr; 1704 switch (curToken.getKind()) { 1705 case Token::kw_attr: 1706 lhsExpr = parseAttributeExpr(); 1707 break; 1708 case Token::kw_Constraint: 1709 lhsExpr = parseInlineConstraintLambdaExpr(); 1710 break; 1711 case Token::identifier: 1712 lhsExpr = parseIdentifierExpr(); 1713 break; 1714 case Token::kw_op: 1715 lhsExpr = parseOperationExpr(); 1716 break; 1717 case Token::kw_Rewrite: 1718 lhsExpr = parseInlineRewriteLambdaExpr(); 1719 break; 1720 case Token::kw_type: 1721 lhsExpr = parseTypeExpr(); 1722 break; 1723 case Token::l_paren: 1724 lhsExpr = parseTupleExpr(); 1725 break; 1726 default: 1727 return emitError("expected expression"); 1728 } 1729 if (failed(lhsExpr)) 1730 return failure(); 1731 1732 // Check for an operator expression. 1733 while (true) { 1734 switch (curToken.getKind()) { 1735 case Token::dot: 1736 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1737 break; 1738 case Token::l_paren: 1739 lhsExpr = parseCallExpr(*lhsExpr); 1740 break; 1741 default: 1742 return lhsExpr; 1743 } 1744 if (failed(lhsExpr)) 1745 return failure(); 1746 } 1747 } 1748 1749 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1750 SMRange loc = curToken.getLoc(); 1751 consumeToken(Token::kw_attr); 1752 1753 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1754 // identifier. 1755 if (!consumeIf(Token::less)) { 1756 resetToken(loc); 1757 return parseIdentifierExpr(); 1758 } 1759 1760 if (!curToken.isString()) 1761 return emitError("expected string literal containing MLIR attribute"); 1762 std::string attrExpr = curToken.getStringValue(); 1763 consumeToken(); 1764 1765 if (failed( 1766 parseToken(Token::greater, "expected `>` after attribute literal"))) 1767 return failure(); 1768 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1769 } 1770 1771 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1772 SMRange loc = curToken.getLoc(); 1773 consumeToken(Token::l_paren); 1774 1775 // Parse the arguments of the call. 1776 SmallVector<ast::Expr *> arguments; 1777 if (curToken.isNot(Token::r_paren)) { 1778 do { 1779 // Handle code completion for the call arguments. 1780 if (curToken.is(Token::code_complete)) { 1781 codeCompleteCallSignature(parentExpr, arguments.size()); 1782 return failure(); 1783 } 1784 1785 FailureOr<ast::Expr *> argument = parseExpr(); 1786 if (failed(argument)) 1787 return failure(); 1788 arguments.push_back(*argument); 1789 } while (consumeIf(Token::comma)); 1790 } 1791 loc.End = curToken.getEndLoc(); 1792 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1793 return failure(); 1794 1795 return createCallExpr(loc, parentExpr, arguments); 1796 } 1797 1798 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1799 ast::Decl *decl = curDeclScope->lookup(name); 1800 if (!decl) 1801 return emitError(loc, "undefined reference to `" + name + "`"); 1802 1803 return createDeclRefExpr(loc, decl); 1804 } 1805 1806 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1807 StringRef name = curToken.getSpelling(); 1808 SMRange nameLoc = curToken.getLoc(); 1809 consumeToken(); 1810 1811 // Check to see if this is a decl ref expression that defines a variable 1812 // inline. 1813 if (consumeIf(Token::colon)) { 1814 SmallVector<ast::ConstraintRef> constraints; 1815 if (failed(parseVariableDeclConstraintList(constraints))) 1816 return failure(); 1817 ast::Type type; 1818 if (failed(validateVariableConstraints(constraints, type))) 1819 return failure(); 1820 return createInlineVariableExpr(type, name, nameLoc, constraints); 1821 } 1822 1823 return parseDeclRefExpr(name, nameLoc); 1824 } 1825 1826 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1827 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1828 if (failed(decl)) 1829 return failure(); 1830 1831 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1832 ast::ConstraintType::get(ctx)); 1833 } 1834 1835 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1836 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1837 if (failed(decl)) 1838 return failure(); 1839 1840 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1841 ast::RewriteType::get(ctx)); 1842 } 1843 1844 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1845 SMRange loc = curToken.getLoc(); 1846 consumeToken(Token::dot); 1847 1848 // Check for code completion of the member name. 1849 if (curToken.is(Token::code_complete)) 1850 return codeCompleteMemberAccess(parentExpr); 1851 1852 // Parse the member name. 1853 Token memberNameTok = curToken; 1854 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1855 !memberNameTok.isKeyword()) 1856 return emitError(loc, "expected identifier or numeric member name"); 1857 StringRef memberName = memberNameTok.getSpelling(); 1858 consumeToken(); 1859 1860 return createMemberAccessExpr(parentExpr, memberName, loc); 1861 } 1862 1863 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1864 SMRange loc = curToken.getLoc(); 1865 1866 // Check for code completion for the dialect name. 1867 if (curToken.is(Token::code_complete)) 1868 return codeCompleteDialectName(); 1869 1870 // Handle the case of an no operation name. 1871 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1872 if (allowEmptyName) 1873 return ast::OpNameDecl::create(ctx, SMRange()); 1874 return emitError("expected dialect namespace"); 1875 } 1876 StringRef name = curToken.getSpelling(); 1877 consumeToken(); 1878 1879 // Otherwise, this is a literal operation name. 1880 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1881 return failure(); 1882 1883 // Check for code completion for the operation name. 1884 if (curToken.is(Token::code_complete)) 1885 return codeCompleteOperationName(name); 1886 1887 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1888 return emitError("expected operation name after dialect namespace"); 1889 1890 name = StringRef(name.data(), name.size() + 1); 1891 do { 1892 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1893 loc.End = curToken.getEndLoc(); 1894 consumeToken(); 1895 } while (curToken.isAny(Token::identifier, Token::dot) || 1896 curToken.isKeyword()); 1897 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1898 } 1899 1900 FailureOr<ast::OpNameDecl *> 1901 Parser::parseWrappedOperationName(bool allowEmptyName) { 1902 if (!consumeIf(Token::less)) 1903 return ast::OpNameDecl::create(ctx, SMRange()); 1904 1905 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1906 if (failed(opNameDecl)) 1907 return failure(); 1908 1909 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1910 return failure(); 1911 return opNameDecl; 1912 } 1913 1914 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1915 SMRange loc = curToken.getLoc(); 1916 consumeToken(Token::kw_op); 1917 1918 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1919 // identifier. 1920 if (curToken.isNot(Token::less)) { 1921 resetToken(loc); 1922 return parseIdentifierExpr(); 1923 } 1924 1925 // Parse the operation name. The name may be elided, in which case the 1926 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1927 // `Operation*`). Operation names within a rewrite context must be named. 1928 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1929 FailureOr<ast::OpNameDecl *> opNameDecl = 1930 parseWrappedOperationName(allowEmptyName); 1931 if (failed(opNameDecl)) 1932 return failure(); 1933 Optional<StringRef> opName = (*opNameDecl)->getName(); 1934 1935 // Functor used to create an implicit range variable, used for implicit "all" 1936 // operand or results variables. 1937 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1938 FailureOr<ast::VariableDecl *> rangeVar = 1939 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1940 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1941 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1942 }; 1943 1944 // Check for the optional list of operands. 1945 SmallVector<ast::Expr *> operands; 1946 if (!consumeIf(Token::l_paren)) { 1947 // If the operand list isn't specified and we are in a match context, define 1948 // an inplace unconstrained operand range corresponding to all of the 1949 // operands of the operation. This avoids treating zero operands the same 1950 // way as "unconstrained operands". 1951 if (parserContext != ParserContext::Rewrite) { 1952 operands.push_back(createImplicitRangeVar( 1953 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1954 } 1955 } else if (!consumeIf(Token::r_paren)) { 1956 // Check for operand signature code completion. 1957 if (curToken.is(Token::code_complete)) { 1958 codeCompleteOperationOperandsSignature(opName, operands.size()); 1959 return failure(); 1960 } 1961 1962 // If the operand list was specified and non-empty, parse the operands. 1963 do { 1964 FailureOr<ast::Expr *> operand = parseExpr(); 1965 if (failed(operand)) 1966 return failure(); 1967 operands.push_back(*operand); 1968 } while (consumeIf(Token::comma)); 1969 1970 if (failed(parseToken(Token::r_paren, 1971 "expected `)` after operation operand list"))) 1972 return failure(); 1973 } 1974 1975 // Check for the optional list of attributes. 1976 SmallVector<ast::NamedAttributeDecl *> attributes; 1977 if (consumeIf(Token::l_brace)) { 1978 do { 1979 FailureOr<ast::NamedAttributeDecl *> decl = 1980 parseNamedAttributeDecl(opName); 1981 if (failed(decl)) 1982 return failure(); 1983 attributes.emplace_back(*decl); 1984 } while (consumeIf(Token::comma)); 1985 1986 if (failed(parseToken(Token::r_brace, 1987 "expected `}` after operation attribute list"))) 1988 return failure(); 1989 } 1990 1991 // Check for the optional list of result types. 1992 SmallVector<ast::Expr *> resultTypes; 1993 if (consumeIf(Token::arrow)) { 1994 if (failed(parseToken(Token::l_paren, 1995 "expected `(` before operation result type list"))) 1996 return failure(); 1997 1998 // Handle the case of an empty result list. 1999 if (!consumeIf(Token::r_paren)) { 2000 do { 2001 // Check for result signature code completion. 2002 if (curToken.is(Token::code_complete)) { 2003 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2004 return failure(); 2005 } 2006 2007 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2008 if (failed(resultTypeExpr)) 2009 return failure(); 2010 resultTypes.push_back(*resultTypeExpr); 2011 } while (consumeIf(Token::comma)); 2012 2013 if (failed(parseToken(Token::r_paren, 2014 "expected `)` after operation result type list"))) 2015 return failure(); 2016 } 2017 } else if (parserContext != ParserContext::Rewrite) { 2018 // If the result list isn't specified and we are in a match context, define 2019 // an inplace unconstrained result range corresponding to all of the results 2020 // of the operation. This avoids treating zero results the same way as 2021 // "unconstrained results". 2022 resultTypes.push_back(createImplicitRangeVar( 2023 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2024 } 2025 2026 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2027 resultTypes); 2028 } 2029 2030 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2031 SMRange loc = curToken.getLoc(); 2032 consumeToken(Token::l_paren); 2033 2034 DenseMap<StringRef, SMRange> usedNames; 2035 SmallVector<StringRef> elementNames; 2036 SmallVector<ast::Expr *> elements; 2037 if (curToken.isNot(Token::r_paren)) { 2038 do { 2039 // Check for the optional element name assignment before the value. 2040 StringRef elementName; 2041 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2042 Token elementNameTok = curToken; 2043 consumeToken(); 2044 2045 // The element name is only present if followed by an `=`. 2046 if (consumeIf(Token::equal)) { 2047 elementName = elementNameTok.getSpelling(); 2048 2049 // Check to see if this name is already used. 2050 auto elementNameIt = 2051 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2052 if (!elementNameIt.second) { 2053 return emitErrorAndNote( 2054 elementNameTok.getLoc(), 2055 llvm::formatv("duplicate tuple element label `{0}`", 2056 elementName), 2057 elementNameIt.first->getSecond(), 2058 "see previous label use here"); 2059 } 2060 } else { 2061 // Otherwise, we treat this as part of an expression so reset the 2062 // lexer. 2063 resetToken(elementNameTok.getLoc()); 2064 } 2065 } 2066 elementNames.push_back(elementName); 2067 2068 // Parse the tuple element value. 2069 FailureOr<ast::Expr *> element = parseExpr(); 2070 if (failed(element)) 2071 return failure(); 2072 elements.push_back(*element); 2073 } while (consumeIf(Token::comma)); 2074 } 2075 loc.End = curToken.getEndLoc(); 2076 if (failed( 2077 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2078 return failure(); 2079 return createTupleExpr(loc, elements, elementNames); 2080 } 2081 2082 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2083 SMRange loc = curToken.getLoc(); 2084 consumeToken(Token::kw_type); 2085 2086 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2087 // identifier. 2088 if (!consumeIf(Token::less)) { 2089 resetToken(loc); 2090 return parseIdentifierExpr(); 2091 } 2092 2093 if (!curToken.isString()) 2094 return emitError("expected string literal containing MLIR type"); 2095 std::string attrExpr = curToken.getStringValue(); 2096 consumeToken(); 2097 2098 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2099 return failure(); 2100 return ast::TypeExpr::create(ctx, loc, attrExpr); 2101 } 2102 2103 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2104 StringRef name = curToken.getSpelling(); 2105 SMRange nameLoc = curToken.getLoc(); 2106 consumeToken(Token::underscore); 2107 2108 // Underscore expressions require a constraint list. 2109 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2110 return failure(); 2111 2112 // Parse the constraints for the expression. 2113 SmallVector<ast::ConstraintRef> constraints; 2114 if (failed(parseVariableDeclConstraintList(constraints))) 2115 return failure(); 2116 2117 ast::Type type; 2118 if (failed(validateVariableConstraints(constraints, type))) 2119 return failure(); 2120 return createInlineVariableExpr(type, name, nameLoc, constraints); 2121 } 2122 2123 //===----------------------------------------------------------------------===// 2124 // Stmts 2125 2126 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2127 FailureOr<ast::Stmt *> stmt; 2128 switch (curToken.getKind()) { 2129 case Token::kw_erase: 2130 stmt = parseEraseStmt(); 2131 break; 2132 case Token::kw_let: 2133 stmt = parseLetStmt(); 2134 break; 2135 case Token::kw_replace: 2136 stmt = parseReplaceStmt(); 2137 break; 2138 case Token::kw_return: 2139 stmt = parseReturnStmt(); 2140 break; 2141 case Token::kw_rewrite: 2142 stmt = parseRewriteStmt(); 2143 break; 2144 default: 2145 stmt = parseExpr(); 2146 break; 2147 } 2148 if (failed(stmt) || 2149 (expectTerminalSemicolon && 2150 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2151 return failure(); 2152 return stmt; 2153 } 2154 2155 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2156 SMLoc startLoc = curToken.getStartLoc(); 2157 consumeToken(Token::l_brace); 2158 2159 // Push a new block scope and parse any nested statements. 2160 pushDeclScope(); 2161 SmallVector<ast::Stmt *> statements; 2162 while (curToken.isNot(Token::r_brace)) { 2163 FailureOr<ast::Stmt *> statement = parseStmt(); 2164 if (failed(statement)) 2165 return popDeclScope(), failure(); 2166 statements.push_back(*statement); 2167 } 2168 popDeclScope(); 2169 2170 // Consume the end brace. 2171 SMRange location(startLoc, curToken.getEndLoc()); 2172 consumeToken(Token::r_brace); 2173 2174 return ast::CompoundStmt::create(ctx, location, statements); 2175 } 2176 2177 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2178 if (parserContext == ParserContext::Constraint) 2179 return emitError("`erase` cannot be used within a Constraint"); 2180 SMRange loc = curToken.getLoc(); 2181 consumeToken(Token::kw_erase); 2182 2183 // Parse the root operation expression. 2184 FailureOr<ast::Expr *> rootOp = parseExpr(); 2185 if (failed(rootOp)) 2186 return failure(); 2187 2188 return createEraseStmt(loc, *rootOp); 2189 } 2190 2191 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2192 SMRange loc = curToken.getLoc(); 2193 consumeToken(Token::kw_let); 2194 2195 // Parse the name of the new variable. 2196 SMRange varLoc = curToken.getLoc(); 2197 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2198 // `_` is a reserved variable name. 2199 if (curToken.is(Token::underscore)) { 2200 return emitError(varLoc, 2201 "`_` may only be used to define \"inline\" variables"); 2202 } 2203 return emitError(varLoc, 2204 "expected identifier after `let` to name a new variable"); 2205 } 2206 StringRef varName = curToken.getSpelling(); 2207 consumeToken(); 2208 2209 // Parse the optional set of constraints. 2210 SmallVector<ast::ConstraintRef> constraints; 2211 if (consumeIf(Token::colon) && 2212 failed(parseVariableDeclConstraintList(constraints))) 2213 return failure(); 2214 2215 // Parse the optional initializer expression. 2216 ast::Expr *initializer = nullptr; 2217 if (consumeIf(Token::equal)) { 2218 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2219 if (failed(initOrFailure)) 2220 return failure(); 2221 initializer = *initOrFailure; 2222 2223 // Check that the constraints are compatible with having an initializer, 2224 // e.g. type constraints cannot be used with initializers. 2225 for (ast::ConstraintRef constraint : constraints) { 2226 LogicalResult result = 2227 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2228 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2229 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2230 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2231 return this->emitError( 2232 constraint.referenceLoc, 2233 "type constraints are not permitted on variables with " 2234 "initializers"); 2235 } 2236 return success(); 2237 }) 2238 .Default(success()); 2239 if (failed(result)) 2240 return failure(); 2241 } 2242 } 2243 2244 FailureOr<ast::VariableDecl *> varDecl = 2245 createVariableDecl(varName, varLoc, initializer, constraints); 2246 if (failed(varDecl)) 2247 return failure(); 2248 return ast::LetStmt::create(ctx, loc, *varDecl); 2249 } 2250 2251 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2252 if (parserContext == ParserContext::Constraint) 2253 return emitError("`replace` cannot be used within a Constraint"); 2254 SMRange loc = curToken.getLoc(); 2255 consumeToken(Token::kw_replace); 2256 2257 // Parse the root operation expression. 2258 FailureOr<ast::Expr *> rootOp = parseExpr(); 2259 if (failed(rootOp)) 2260 return failure(); 2261 2262 if (failed( 2263 parseToken(Token::kw_with, "expected `with` after root operation"))) 2264 return failure(); 2265 2266 // The replacement portion of this statement is within a rewrite context. 2267 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2268 ParserContext::Rewrite); 2269 2270 // Parse the replacement values. 2271 SmallVector<ast::Expr *> replValues; 2272 if (consumeIf(Token::l_paren)) { 2273 if (consumeIf(Token::r_paren)) { 2274 return emitError( 2275 loc, "expected at least one replacement value, consider using " 2276 "`erase` if no replacement values are desired"); 2277 } 2278 2279 do { 2280 FailureOr<ast::Expr *> replExpr = parseExpr(); 2281 if (failed(replExpr)) 2282 return failure(); 2283 replValues.emplace_back(*replExpr); 2284 } while (consumeIf(Token::comma)); 2285 2286 if (failed(parseToken(Token::r_paren, 2287 "expected `)` after replacement values"))) 2288 return failure(); 2289 } else { 2290 FailureOr<ast::Expr *> replExpr = parseExpr(); 2291 if (failed(replExpr)) 2292 return failure(); 2293 replValues.emplace_back(*replExpr); 2294 } 2295 2296 return createReplaceStmt(loc, *rootOp, replValues); 2297 } 2298 2299 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2300 SMRange loc = curToken.getLoc(); 2301 consumeToken(Token::kw_return); 2302 2303 // Parse the result value. 2304 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2305 if (failed(resultExpr)) 2306 return failure(); 2307 2308 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2309 } 2310 2311 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2312 if (parserContext == ParserContext::Constraint) 2313 return emitError("`rewrite` cannot be used within a Constraint"); 2314 SMRange loc = curToken.getLoc(); 2315 consumeToken(Token::kw_rewrite); 2316 2317 // Parse the root operation. 2318 FailureOr<ast::Expr *> rootOp = parseExpr(); 2319 if (failed(rootOp)) 2320 return failure(); 2321 2322 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2323 return failure(); 2324 2325 if (curToken.isNot(Token::l_brace)) 2326 return emitError("expected `{` to start rewrite body"); 2327 2328 // The rewrite body of this statement is within a rewrite context. 2329 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2330 ParserContext::Rewrite); 2331 2332 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2333 if (failed(rewriteBody)) 2334 return failure(); 2335 2336 // Verify the rewrite body. 2337 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2338 if (isa<ast::ReturnStmt>(stmt)) { 2339 return emitError(stmt->getLoc(), 2340 "`return` statements are only permitted within a " 2341 "`Constraint` or `Rewrite` body"); 2342 } 2343 } 2344 2345 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2346 } 2347 2348 //===----------------------------------------------------------------------===// 2349 // Creation+Analysis 2350 //===----------------------------------------------------------------------===// 2351 2352 //===----------------------------------------------------------------------===// 2353 // Decls 2354 2355 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2356 // Unwrap reference expressions. 2357 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2358 node = init->getDecl(); 2359 return dyn_cast<ast::CallableDecl>(node); 2360 } 2361 2362 FailureOr<ast::PatternDecl *> 2363 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2364 const ParsedPatternMetadata &metadata, 2365 ast::CompoundStmt *body) { 2366 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2367 metadata.hasBoundedRecursion, body); 2368 } 2369 2370 ast::Type Parser::createUserConstraintRewriteResultType( 2371 ArrayRef<ast::VariableDecl *> results) { 2372 // Single result decls use the type of the single result. 2373 if (results.size() == 1) 2374 return results[0]->getType(); 2375 2376 // Multiple results use a tuple type, with the types and names grabbed from 2377 // the result variable decls. 2378 auto resultTypes = llvm::map_range( 2379 results, [&](const auto *result) { return result->getType(); }); 2380 auto resultNames = llvm::map_range( 2381 results, [&](const auto *result) { return result->getName().getName(); }); 2382 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2383 llvm::to_vector(resultNames)); 2384 } 2385 2386 template <typename T> 2387 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2388 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2389 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2390 ast::CompoundStmt *body) { 2391 if (!body->getChildren().empty()) { 2392 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2393 ast::Expr *resultExpr = retStmt->getResultExpr(); 2394 2395 // Process the result of the decl. If no explicit signature results 2396 // were provided, check for return type inference. Otherwise, check that 2397 // the return expression can be converted to the expected type. 2398 if (results.empty()) 2399 resultType = resultExpr->getType(); 2400 else if (failed(convertExpressionTo(resultExpr, resultType))) 2401 return failure(); 2402 else 2403 retStmt->setResultExpr(resultExpr); 2404 } 2405 } 2406 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2407 } 2408 2409 FailureOr<ast::VariableDecl *> 2410 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2411 ArrayRef<ast::ConstraintRef> constraints) { 2412 // The type of the variable, which is expected to be inferred by either a 2413 // constraint or an initializer expression. 2414 ast::Type type; 2415 if (failed(validateVariableConstraints(constraints, type))) 2416 return failure(); 2417 2418 if (initializer) { 2419 // Update the variable type based on the initializer, or try to convert the 2420 // initializer to the existing type. 2421 if (!type) 2422 type = initializer->getType(); 2423 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2424 type = mergedType; 2425 else if (failed(convertExpressionTo(initializer, type))) 2426 return failure(); 2427 2428 // Otherwise, if there is no initializer check that the type has already 2429 // been resolved from the constraint list. 2430 } else if (!type) { 2431 return emitErrorAndNote( 2432 loc, "unable to infer type for variable `" + name + "`", loc, 2433 "the type of a variable must be inferable from the constraint " 2434 "list or the initializer"); 2435 } 2436 2437 // Constraint types cannot be used when defining variables. 2438 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2439 return emitError( 2440 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2441 } 2442 2443 // Try to define a variable with the given name. 2444 FailureOr<ast::VariableDecl *> varDecl = 2445 defineVariableDecl(name, loc, type, initializer, constraints); 2446 if (failed(varDecl)) 2447 return failure(); 2448 2449 return *varDecl; 2450 } 2451 2452 FailureOr<ast::VariableDecl *> 2453 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2454 const ast::ConstraintRef &constraint) { 2455 // Constraint arguments may apply more complex constraints via the arguments. 2456 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2457 ast::Type argType; 2458 if (failed(validateVariableConstraint(constraint, argType, 2459 allowNonCoreConstraints))) 2460 return failure(); 2461 return defineVariableDecl(name, loc, argType, constraint); 2462 } 2463 2464 LogicalResult 2465 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2466 ast::Type &inferredType, 2467 bool allowNonCoreConstraints) { 2468 for (const ast::ConstraintRef &ref : constraints) 2469 if (failed(validateVariableConstraint(ref, inferredType, 2470 allowNonCoreConstraints))) 2471 return failure(); 2472 return success(); 2473 } 2474 2475 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2476 ast::Type &inferredType, 2477 bool allowNonCoreConstraints) { 2478 ast::Type constraintType; 2479 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2480 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2481 if (failed(validateTypeConstraintExpr(typeExpr))) 2482 return failure(); 2483 } 2484 constraintType = ast::AttributeType::get(ctx); 2485 } else if (const auto *cst = 2486 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2487 constraintType = ast::OperationType::get(ctx, cst->getName()); 2488 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2489 constraintType = typeTy; 2490 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2491 constraintType = typeRangeTy; 2492 } else if (const auto *cst = 2493 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2494 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2495 if (failed(validateTypeConstraintExpr(typeExpr))) 2496 return failure(); 2497 } 2498 constraintType = valueTy; 2499 } else if (const auto *cst = 2500 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2501 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2502 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2503 return failure(); 2504 } 2505 constraintType = valueRangeTy; 2506 } else if (const auto *cst = 2507 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2508 if (!allowNonCoreConstraints) { 2509 return emitError(ref.referenceLoc, 2510 "`Rewrite` arguments and results are only permitted to " 2511 "use core constraints, such as `Attr`, `Op`, `Type`, " 2512 "`TypeRange`, `Value`, `ValueRange`"); 2513 } 2514 2515 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2516 if (inputs.size() != 1) { 2517 return emitErrorAndNote(ref.referenceLoc, 2518 "`Constraint`s applied via a variable constraint " 2519 "list must take a single input, but got " + 2520 Twine(inputs.size()), 2521 cst->getLoc(), 2522 "see definition of constraint here"); 2523 } 2524 constraintType = inputs.front()->getType(); 2525 } else { 2526 llvm_unreachable("unknown constraint type"); 2527 } 2528 2529 // Check that the constraint type is compatible with the current inferred 2530 // type. 2531 if (!inferredType) { 2532 inferredType = constraintType; 2533 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2534 inferredType = mergedTy; 2535 } else { 2536 return emitError(ref.referenceLoc, 2537 llvm::formatv("constraint type `{0}` is incompatible " 2538 "with the previously inferred type `{1}`", 2539 constraintType, inferredType)); 2540 } 2541 return success(); 2542 } 2543 2544 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2545 ast::Type typeExprType = typeExpr->getType(); 2546 if (typeExprType != typeTy) { 2547 return emitError(typeExpr->getLoc(), 2548 "expected expression of `Type` in type constraint"); 2549 } 2550 return success(); 2551 } 2552 2553 LogicalResult 2554 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2555 ast::Type typeExprType = typeExpr->getType(); 2556 if (typeExprType != typeRangeTy) { 2557 return emitError(typeExpr->getLoc(), 2558 "expected expression of `TypeRange` in type constraint"); 2559 } 2560 return success(); 2561 } 2562 2563 //===----------------------------------------------------------------------===// 2564 // Exprs 2565 2566 FailureOr<ast::CallExpr *> 2567 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2568 MutableArrayRef<ast::Expr *> arguments) { 2569 ast::Type parentType = parentExpr->getType(); 2570 2571 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2572 if (!callableDecl) { 2573 return emitError(loc, 2574 llvm::formatv("expected a reference to a callable " 2575 "`Constraint` or `Rewrite`, but got: `{0}`", 2576 parentType)); 2577 } 2578 if (parserContext == ParserContext::Rewrite) { 2579 if (isa<ast::UserConstraintDecl>(callableDecl)) 2580 return emitError( 2581 loc, "unable to invoke `Constraint` within a rewrite section"); 2582 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2583 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2584 } 2585 2586 // Verify the arguments of the call. 2587 /// Handle size mismatch. 2588 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2589 if (callArgs.size() != arguments.size()) { 2590 return emitErrorAndNote( 2591 loc, 2592 llvm::formatv("invalid number of arguments for {0} call; expected " 2593 "{1}, but got {2}", 2594 callableDecl->getCallableType(), callArgs.size(), 2595 arguments.size()), 2596 callableDecl->getLoc(), 2597 llvm::formatv("see the definition of {0} here", 2598 callableDecl->getName()->getName())); 2599 } 2600 2601 /// Handle argument type mismatch. 2602 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2603 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2604 callableDecl->getName()->getName()), 2605 callableDecl->getLoc()); 2606 }; 2607 for (auto it : llvm::zip(callArgs, arguments)) { 2608 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2609 attachDiagFn))) 2610 return failure(); 2611 } 2612 2613 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2614 callableDecl->getResultType()); 2615 } 2616 2617 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2618 ast::Decl *decl) { 2619 // Check the type of decl being referenced. 2620 ast::Type declType; 2621 if (isa<ast::ConstraintDecl>(decl)) 2622 declType = ast::ConstraintType::get(ctx); 2623 else if (isa<ast::UserRewriteDecl>(decl)) 2624 declType = ast::RewriteType::get(ctx); 2625 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2626 declType = varDecl->getType(); 2627 else 2628 return emitError(loc, "invalid reference to `" + 2629 decl->getName()->getName() + "`"); 2630 2631 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2632 } 2633 2634 FailureOr<ast::DeclRefExpr *> 2635 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2636 ArrayRef<ast::ConstraintRef> constraints) { 2637 FailureOr<ast::VariableDecl *> decl = 2638 defineVariableDecl(name, loc, type, constraints); 2639 if (failed(decl)) 2640 return failure(); 2641 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2642 } 2643 2644 FailureOr<ast::MemberAccessExpr *> 2645 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2646 SMRange loc) { 2647 // Validate the member name for the given parent expression. 2648 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2649 if (failed(memberType)) 2650 return failure(); 2651 2652 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2653 } 2654 2655 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2656 StringRef name, SMRange loc) { 2657 ast::Type parentType = parentExpr->getType(); 2658 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2659 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2660 return valueRangeTy; 2661 2662 // Verify member access based on the operation type. 2663 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2664 auto results = odsOp->getResults(); 2665 2666 // Handle indexed results. 2667 unsigned index = 0; 2668 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2669 index < results.size()) { 2670 return results[index].isVariadic() ? valueRangeTy : valueTy; 2671 } 2672 2673 // Handle named results. 2674 const auto *it = llvm::find_if(results, [&](const auto &result) { 2675 return result.getName() == name; 2676 }); 2677 if (it != results.end()) 2678 return it->isVariadic() ? valueRangeTy : valueTy; 2679 } 2680 2681 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2682 // Handle indexed results. 2683 unsigned index = 0; 2684 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2685 index < tupleType.size()) { 2686 return tupleType.getElementTypes()[index]; 2687 } 2688 2689 // Handle named results. 2690 auto elementNames = tupleType.getElementNames(); 2691 const auto *it = llvm::find(elementNames, name); 2692 if (it != elementNames.end()) 2693 return tupleType.getElementTypes()[it - elementNames.begin()]; 2694 } 2695 return emitError( 2696 loc, 2697 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2698 name, parentType)); 2699 } 2700 2701 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2702 SMRange loc, const ast::OpNameDecl *name, 2703 MutableArrayRef<ast::Expr *> operands, 2704 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2705 MutableArrayRef<ast::Expr *> results) { 2706 Optional<StringRef> opNameRef = name->getName(); 2707 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2708 2709 // Verify the inputs operands. 2710 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2711 return failure(); 2712 2713 // Verify the attribute list. 2714 for (ast::NamedAttributeDecl *attr : attributes) { 2715 // Check for an attribute type, or a type awaiting resolution. 2716 ast::Type attrType = attr->getValue()->getType(); 2717 if (!attrType.isa<ast::AttributeType>()) { 2718 return emitError( 2719 attr->getValue()->getLoc(), 2720 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2721 } 2722 } 2723 2724 // Verify the result types. 2725 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2726 return failure(); 2727 2728 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2729 attributes); 2730 } 2731 2732 LogicalResult 2733 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2734 const ods::Operation *odsOp, 2735 MutableArrayRef<ast::Expr *> operands) { 2736 return validateOperationOperandsOrResults( 2737 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2738 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2739 valueRangeTy); 2740 } 2741 2742 LogicalResult 2743 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2744 const ods::Operation *odsOp, 2745 MutableArrayRef<ast::Expr *> results) { 2746 return validateOperationOperandsOrResults( 2747 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2748 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2749 } 2750 2751 LogicalResult Parser::validateOperationOperandsOrResults( 2752 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2753 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2754 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2755 ast::Type rangeTy) { 2756 // All operation types accept a single range parameter. 2757 if (values.size() == 1) { 2758 if (failed(convertExpressionTo(values[0], rangeTy))) 2759 return failure(); 2760 return success(); 2761 } 2762 2763 /// If the operation has ODS information, we can more accurately verify the 2764 /// values. 2765 if (odsOpLoc) { 2766 if (odsValues.size() != values.size()) { 2767 return emitErrorAndNote( 2768 loc, 2769 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2770 "{2}, but got {3}", 2771 groupName, *name, odsValues.size(), values.size()), 2772 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2773 } 2774 auto diagFn = [&](ast::Diagnostic &diag) { 2775 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2776 *odsOpLoc); 2777 }; 2778 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2779 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2780 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2781 return failure(); 2782 } 2783 return success(); 2784 } 2785 2786 // Otherwise, accept the value groups as they have been defined and just 2787 // ensure they are one of the expected types. 2788 for (ast::Expr *&valueExpr : values) { 2789 ast::Type valueExprType = valueExpr->getType(); 2790 2791 // Check if this is one of the expected types. 2792 if (valueExprType == rangeTy || valueExprType == singleTy) 2793 continue; 2794 2795 // If the operand is an Operation, allow converting to a Value or 2796 // ValueRange. This situations arises quite often with nested operation 2797 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2798 if (singleTy == valueTy) { 2799 if (valueExprType.isa<ast::OperationType>()) { 2800 valueExpr = convertOpToValue(valueExpr); 2801 continue; 2802 } 2803 } 2804 2805 return emitError( 2806 valueExpr->getLoc(), 2807 llvm::formatv( 2808 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2809 singleTy, rangeTy, valueExprType)); 2810 } 2811 return success(); 2812 } 2813 2814 FailureOr<ast::TupleExpr *> 2815 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2816 ArrayRef<StringRef> elementNames) { 2817 for (const ast::Expr *element : elements) { 2818 ast::Type eleTy = element->getType(); 2819 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2820 return emitError( 2821 element->getLoc(), 2822 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2823 } 2824 } 2825 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2826 } 2827 2828 //===----------------------------------------------------------------------===// 2829 // Stmts 2830 2831 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2832 ast::Expr *rootOp) { 2833 // Check that root is an Operation. 2834 ast::Type rootType = rootOp->getType(); 2835 if (!rootType.isa<ast::OperationType>()) 2836 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2837 2838 return ast::EraseStmt::create(ctx, loc, rootOp); 2839 } 2840 2841 FailureOr<ast::ReplaceStmt *> 2842 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2843 MutableArrayRef<ast::Expr *> replValues) { 2844 // Check that root is an Operation. 2845 ast::Type rootType = rootOp->getType(); 2846 if (!rootType.isa<ast::OperationType>()) { 2847 return emitError( 2848 rootOp->getLoc(), 2849 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2850 } 2851 2852 // If there are multiple replacement values, we implicitly convert any Op 2853 // expressions to the value form. 2854 bool shouldConvertOpToValues = replValues.size() > 1; 2855 for (ast::Expr *&replExpr : replValues) { 2856 ast::Type replType = replExpr->getType(); 2857 2858 // Check that replExpr is an Operation, Value, or ValueRange. 2859 if (replType.isa<ast::OperationType>()) { 2860 if (shouldConvertOpToValues) 2861 replExpr = convertOpToValue(replExpr); 2862 continue; 2863 } 2864 2865 if (replType != valueTy && replType != valueRangeTy) { 2866 return emitError(replExpr->getLoc(), 2867 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2868 "expression, but got `{0}`", 2869 replType)); 2870 } 2871 } 2872 2873 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2874 } 2875 2876 FailureOr<ast::RewriteStmt *> 2877 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2878 ast::CompoundStmt *rewriteBody) { 2879 // Check that root is an Operation. 2880 ast::Type rootType = rootOp->getType(); 2881 if (!rootType.isa<ast::OperationType>()) { 2882 return emitError( 2883 rootOp->getLoc(), 2884 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2885 } 2886 2887 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2888 } 2889 2890 //===----------------------------------------------------------------------===// 2891 // Code Completion 2892 //===----------------------------------------------------------------------===// 2893 2894 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2895 ast::Type parentType = parentExpr->getType(); 2896 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2897 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2898 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2899 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2900 return failure(); 2901 } 2902 2903 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2904 if (opName) 2905 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2906 return failure(); 2907 } 2908 2909 LogicalResult 2910 Parser::codeCompleteConstraintName(ast::Type inferredType, 2911 bool allowNonCoreConstraints, 2912 bool allowInlineTypeConstraints) { 2913 codeCompleteContext->codeCompleteConstraintName( 2914 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2915 curDeclScope); 2916 return failure(); 2917 } 2918 2919 LogicalResult Parser::codeCompleteDialectName() { 2920 codeCompleteContext->codeCompleteDialectName(); 2921 return failure(); 2922 } 2923 2924 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2925 codeCompleteContext->codeCompleteOperationName(dialectName); 2926 return failure(); 2927 } 2928 2929 LogicalResult Parser::codeCompletePatternMetadata() { 2930 codeCompleteContext->codeCompletePatternMetadata(); 2931 return failure(); 2932 } 2933 2934 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 2935 codeCompleteContext->codeCompleteIncludeFilename(curPath); 2936 return failure(); 2937 } 2938 2939 void Parser::codeCompleteCallSignature(ast::Node *parent, 2940 unsigned currentNumArgs) { 2941 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2942 if (!callableDecl) 2943 return; 2944 2945 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2946 } 2947 2948 void Parser::codeCompleteOperationOperandsSignature( 2949 Optional<StringRef> opName, unsigned currentNumOperands) { 2950 codeCompleteContext->codeCompleteOperationOperandsSignature( 2951 opName, currentNumOperands); 2952 } 2953 2954 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2955 unsigned currentNumResults) { 2956 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2957 currentNumResults); 2958 } 2959 2960 //===----------------------------------------------------------------------===// 2961 // Parser 2962 //===----------------------------------------------------------------------===// 2963 2964 FailureOr<ast::Module *> 2965 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2966 CodeCompleteContext *codeCompleteContext) { 2967 Parser parser(ctx, sourceMgr, codeCompleteContext); 2968 return parser.parseModule(); 2969 } 2970