1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 
typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 
95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 
129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. 
`isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 
208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. 
`typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 
274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 
312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. 
`inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 
377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 
418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 428 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 429 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 430 unsigned currentNumOperands); 431 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 432 unsigned currentNumResults); 433 434 //===--------------------------------------------------------------------===// 435 // Lexer Utilities 436 //===--------------------------------------------------------------------===// 437 438 /// If the current token has the specified kind, consume it and return true. 439 /// If not, return false. 440 bool consumeIf(Token::Kind kind) { 441 if (curToken.isNot(kind)) 442 return false; 443 consumeToken(kind); 444 return true; 445 } 446 447 /// Advance the current lexer onto the next token. 448 void consumeToken() { 449 assert(curToken.isNot(Token::eof, Token::error) && 450 "shouldn't advance past EOF or errors"); 451 curToken = lexer.lexToken(); 452 } 453 454 /// Advance the current lexer onto the next token, asserting what the expected 455 /// current token is. This is preferred to the above method because it leads 456 /// to more self-documenting code with better checking. 457 void consumeToken(Token::Kind kind) { 458 assert(curToken.is(kind) && "consumed an unexpected token"); 459 consumeToken(); 460 } 461 462 /// Reset the lexer to the location at the given position. 
463 void resetToken(SMRange tokLoc) { 464 lexer.resetPointer(tokLoc.Start.getPointer()); 465 curToken = lexer.lexToken(); 466 } 467 468 /// Consume the specified token if present and return success. On failure, 469 /// output a diagnostic and return failure. 470 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 471 if (curToken.getKind() != kind) 472 return emitError(curToken.getLoc(), msg); 473 consumeToken(); 474 return success(); 475 } 476 LogicalResult emitError(SMRange loc, const Twine &msg) { 477 lexer.emitError(loc, msg); 478 return failure(); 479 } 480 LogicalResult emitError(const Twine &msg) { 481 return emitError(curToken.getLoc(), msg); 482 } 483 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 484 const Twine ¬e) { 485 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 486 return failure(); 487 } 488 489 //===--------------------------------------------------------------------===// 490 // Fields 491 //===--------------------------------------------------------------------===// 492 493 /// The owning AST context. 494 ast::Context &ctx; 495 496 /// The lexer of this parser. 497 Lexer lexer; 498 499 /// The current token within the lexer. 500 Token curToken; 501 502 /// The most recently defined decl scope. 503 ast::DeclScope *curDeclScope = nullptr; 504 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 505 506 /// The current context of the parser. 507 ParserContext parserContext = ParserContext::Global; 508 509 /// Cached types to simplify verification and expression creation. 510 ast::Type valueTy, valueRangeTy; 511 ast::Type typeTy, typeRangeTy; 512 ast::Type attrTy; 513 514 /// A counter used when naming anonymous constraints and rewrites. 515 unsigned anonymousDeclNameCounter = 0; 516 517 /// The optional code completion context. 
518 CodeCompleteContext *codeCompleteContext; 519 }; 520 } // namespace 521 522 FailureOr<ast::Module *> Parser::parseModule() { 523 SMLoc moduleLoc = curToken.getStartLoc(); 524 pushDeclScope(); 525 526 // Parse the top-level decls of the module. 527 SmallVector<ast::Decl *> decls; 528 if (failed(parseModuleBody(decls))) 529 return popDeclScope(), failure(); 530 531 popDeclScope(); 532 return ast::Module::create(ctx, moduleLoc, decls); 533 } 534 535 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 536 while (curToken.isNot(Token::eof)) { 537 if (curToken.is(Token::directive)) { 538 if (failed(parseDirective(decls))) 539 return failure(); 540 continue; 541 } 542 543 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 544 if (failed(decl)) 545 return failure(); 546 decls.push_back(*decl); 547 } 548 return success(); 549 } 550 551 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 552 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 553 valueRangeTy); 554 } 555 556 LogicalResult Parser::convertExpressionTo( 557 ast::Expr *&expr, ast::Type type, 558 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 559 ast::Type exprType = expr->getType(); 560 if (exprType == type) 561 return success(); 562 563 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 564 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 565 expr->getLoc(), llvm::formatv("unable to convert expression of type " 566 "`{0}` to the expected type of " 567 "`{1}`", 568 exprType, type)); 569 if (noteAttachFn) 570 noteAttachFn(*diag); 571 return diag; 572 }; 573 574 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 575 // Two operation types are compatible if they have the same name, or if the 576 // expected type is more general. 
577 if (auto opType = type.dyn_cast<ast::OperationType>()) { 578 if (opType.getName()) 579 return emitConvertError(); 580 return success(); 581 } 582 583 // An operation can always convert to a ValueRange. 584 if (type == valueRangeTy) { 585 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 586 valueRangeTy); 587 return success(); 588 } 589 590 // Allow conversion to a single value by constraining the result range. 591 if (type == valueTy) { 592 // If the operation is registered, we can verify if it can ever have a 593 // single result. 594 Optional<StringRef> opName = exprOpType.getName(); 595 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 596 if (odsOp->getResults().empty()) { 597 return emitConvertError()->attachNote( 598 llvm::formatv("see the definition of `{0}`, which was defined " 599 "with zero results", 600 odsOp->getName()), 601 odsOp->getLoc()); 602 } 603 604 unsigned numSingleResults = llvm::count_if( 605 odsOp->getResults(), [](const ods::OperandOrResult &result) { 606 return result.getVariableLengthKind() == 607 ods::VariableLengthKind::Single; 608 }); 609 if (numSingleResults > 1) { 610 return emitConvertError()->attachNote( 611 llvm::formatv("see the definition of `{0}`, which was defined " 612 "with at least {1} results", 613 odsOp->getName(), numSingleResults), 614 odsOp->getLoc()); 615 } 616 } 617 618 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 619 valueTy); 620 return success(); 621 } 622 return emitConvertError(); 623 } 624 625 // FIXME: Decide how to allow/support converting a single result to multiple, 626 // and multiple to a single result. For now, we just allow Single->Range, 627 // but this isn't something really supported in the PDL dialect. We should 628 // figure out some way to support both. 
629 if ((exprType == valueTy || exprType == valueRangeTy) && 630 (type == valueTy || type == valueRangeTy)) 631 return success(); 632 if ((exprType == typeTy || exprType == typeRangeTy) && 633 (type == typeTy || type == typeRangeTy)) 634 return success(); 635 636 // Handle tuple types. 637 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 638 auto tupleType = type.dyn_cast<ast::TupleType>(); 639 if (!tupleType || tupleType.size() != exprTupleType.size()) 640 return emitConvertError(); 641 642 // Build a new tuple expression using each of the elements of the current 643 // tuple. 644 SmallVector<ast::Expr *> newExprs; 645 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 646 newExprs.push_back(ast::MemberAccessExpr::create( 647 ctx, expr->getLoc(), expr, llvm::to_string(i), 648 exprTupleType.getElementTypes()[i])); 649 650 auto diagFn = [&](ast::Diagnostic &diag) { 651 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 652 i, exprTupleType)); 653 if (noteAttachFn) 654 noteAttachFn(diag); 655 }; 656 if (failed(convertExpressionTo(newExprs.back(), 657 tupleType.getElementTypes()[i], diagFn))) 658 return failure(); 659 } 660 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 661 tupleType.getElementNames()); 662 return success(); 663 } 664 665 return emitConvertError(); 666 } 667 668 //===----------------------------------------------------------------------===// 669 // Directives 670 671 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 672 StringRef directive = curToken.getSpelling(); 673 if (directive == "#include") 674 return parseInclude(decls); 675 676 return emitError("unknown directive `" + directive + "`"); 677 } 678 679 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 680 SMRange loc = curToken.getLoc(); 681 consumeToken(Token::directive); 682 683 // Parse the file being included. 
// Continuation of Parser::parseInclude (its beginning precedes this chunk).
// Expects the current token to be the string file name of the include, then
// dispatches on its extension: `.pdll` files are lexed and parsed into the
// current module, `.td` files are handed to parseTdInclude, and any other
// extension is reported as an error.
  if (!curToken.isString())
    return emitError(loc,
                     "expected string file name after `include` directive");
  SMRange fileLoc = curToken.getLoc();
  // `filename` is a non-owning view of `filenameStr`, which owns the string.
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Check the type of include. If ending with `.pdll`, this is another pdl file
  // to be parsed along with the current module.
  if (filename.endswith(".pdll")) {
    if (failed(lexer.pushInclude(filename, fileLoc)))
      return emitError(fileLoc,
                       "unable to open include file `" + filename + "`");

    // If we added the include successfully, parse it into the current module.
    // Make sure to update to the next token after we finish parsing the nested
    // file.
    curToken = lexer.lexToken();
    LogicalResult result = parseModuleBody(decls);
    curToken = lexer.lexToken();
    return result;
  }

  // Otherwise, this must be a `.td` include.
  if (filename.endswith(".td"))
    return parseTdInclude(filename, fileLoc, decls);

  return emitError(fileLoc,
                   "expected include filename to end with `.pdll` or `.td`");
}

/// Parse a TableGen (`.td`) include.
///
/// Opens `filename` through the parser's source manager, parses it with
/// TableGen (using the parser's include directories), and appends the PDLL
/// constraint decls derived from its records to `decls`. Any diagnostic that
/// TableGen reports through `llvm::SrcMgr` is re-emitted at `fileLoc`,
/// prefixed with the include file name.
///
/// Returns failure if the file cannot be opened or if TableGenParseFile
/// reports an error.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager.
  // NOTE(review): the handler is installed on the global `llvm::SrcMgr` with a
  // pointer to the stack-local `handlerContext`, and nothing in this function
  // resets it before returning. Confirm that no diagnostic can be emitted
  // through `llvm::SrcMgr` after this frame is gone (e.g. that
  // TableGenParseFile resets the global source manager).
  llvm::SrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        // The handler cannot propagate a result, so the one returned by
        // emitError is discarded.
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Use the source manager to open the file, but don't yet add it.
  // `includedFile` is OpenIncludeFile's out-parameter for the resolved path;
  // it is not used further here.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Callback invoked with the parsed records. It always returns `false`. A
  // `true` result from TableGenParseFile is treated as failure below, so this
  // presumably reports success (TODO confirm against TableGenParserFn).
  auto processFn = [&](llvm::RecordKeeper &records) {
    processTdIncludeRecords(records, decls);

    // After we are done processing, move all of the tablegen source buffers to
    // the main parser source mgr. This allows for directly using source
    // locations from the .td files without needing to remap them.
    parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr, fileLoc.End);
    return false;
  };
  if (llvm::TableGenParseFile(std::move(*includeBuffer),
                              parserSrcMgr.getIncludeDirs(), processFn))
    return failure();

  return success();
}

/// Translate the records of a parsed TableGen file into ODS/PDLL entities:
///  - every `Op` def is registered with the ODS context, together with its
///    attributes, operands, and results. Defs for which insertOperation
///    reports "not inserted" are skipped.
///  - every non-anonymous `Attr` and `Type` def whose name is not found by
///    `curDeclScope->lookup` becomes a native PDLL constraint decl, built from
///    the def's condition template.
///  - every non-anonymous `OpInterface`, `AttrInterface`, or `TypeInterface`
///    def whose name is not found by `curDeclScope->lookup`, and which is not
///    a `DeclareInterfaceMethods` subclass, becomes a native PDLL constraint
///    decl whose body is an `llvm::isa` check against the interface class.
/// Every created constraint decl is appended to `decls`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value.
  // An optional value is classified as Optional before variadic is checked.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  // Passes the constraint's unique def name, summary, and C++ class name.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Build a one-character source range starting at `loc`. Decls need an
  // SMRange, while the records only supply an SMLoc here.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) =
        odsContext.insertOperation(op.getOperationName(), op.getSummary(),
                                   op.getDescription(), op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  /// Attr constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              tblgen::AttrConstraint(def),
              convertLocToRange(def->getLoc().front()), attrTy));
    }
  }
  /// Type constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              tblgen::TypeConstraint(def),
              convertLocToRange(def->getLoc().front()), typeTy));
    }
  }
  /// Interfaces.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    // The native constraint body is a C++ `llvm::isa` check of `self`
    // against the interface's namespace-qualified class name.
    StringRef className = def->getValueAsString("cppClassName");
    StringRef cppNamespace = def->getValueAsString("cppNamespace");
    std::string codeBlock =
        llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
            .str();

    // Interface defs that are none of the three kinds below produce no decl.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy));
    }
  }
}

/// Create a native user constraint named `name` whose native body is
/// `codeBlock`. The constraint has a single argument, `self`, of type `type`,
/// constrained by a `ConstraintT` core constraint, and it has no results. The
/// decl is added to the current decl scope and also returned.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
                                          SMRange loc, ast::Type type) {
  // Build the single input parameter.
  // `self` is added to a freshly pushed scope that is popped right away, so it
  // does not land in the enclosing scope.
  ast::DeclScope *argScope = pushDeclScope();
  auto *paramVar = ast::VariableDecl::create(
      ctx, ast::Name::create(ctx, "self", loc), type,
      /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
  argScope->add(paramVar);
  popDeclScope();

  // Build the native constraint.
  auto *constraintDecl = ast::UserConstraintDecl::createNative(
      ctx, ast::Name::create(ctx, name, loc), paramVar,
      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
  curDeclScope->add(constraintDecl);
  return constraintDecl;
}

/// Create a native user constraint from a TableGen constraint.
///
/// The constraint's condition template is formatted with the "self"
/// placeholder bound to `self`, the name of the generated argument. The
/// result is then forwarded, together with the constraint's unique def name,
/// to the overload above.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                          SMRange loc, ast::Type type) {
  // Format the condition template.
  tblgen::FmtContext fmtContext;
  fmtContext.withSelf("self");
  std::string codeBlock =
      tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);

  return createODSNativePDLLConstraintDecl<ConstraintT>(
      constraint.getUniqueDefName(), codeBlock, loc, type);
}

//===----------------------------------------------------------------------===//
// Decls

/// Parse one top-level declaration: a `Constraint`, `Pattern`, or `Rewrite`.
/// If the parsed decl has a name, it is checked for redefinition and added to
/// the current decl scope.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  FailureOr<ast::Decl *> decl;
  switch (curToken.getKind()) {
  case Token::kw_Constraint:
    decl = parseUserConstraintDecl();
    break;
  case Token::kw_Pattern:
    decl = parsePatternDecl();
    break;
  case Token::kw_Rewrite:
    decl = parseUserRewriteDecl();
    break;
  default:
    return emitError("expected top-level declaration, such as a `Pattern`");
  }
  if (failed(decl))
    return failure();

  // If the decl has a name, add it to the current scope.
  if (const ast::Name *name = (*decl)->getName()) {
    if (failed(checkDefineNamedDecl(*name)))
      return failure();
    curDeclScope->add(*decl);
  }
  return decl;
}

/// Parse a named attribute, `<name> [= <expr>]`.
///
/// The name may be a string literal, an identifier, or any keyword. When no
/// `= <expr>` follows, the value is an AttributeExpr with the text "unit".
/// `parentOpName` is only used here to drive attribute-name code completion.
FailureOr<ast::NamedAttributeDecl *>
Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
  // Check for name code completion.
  if (curToken.is(Token::code_complete))
    return codeCompleteAttributeName(parentOpName);

  std::string attrNameStr;
  if (curToken.isString())
    attrNameStr = curToken.getStringValue();
  else if (curToken.is(Token::identifier) || curToken.isKeyword())
    attrNameStr = curToken.getSpelling().str();
  else
    return emitError("expected identifier or string attribute name");
  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
  consumeToken();

  // Check for a value of the attribute.
  ast::Expr *attrValue = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> attrExpr = parseExpr();
    if (failed(attrExpr))
      return failure();
    attrValue = *attrExpr;
  } else {
    // If there isn't a concrete value, create an expression representing a
    // UnitAttr.
    attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
  }

  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}

/// Parse a lambda body, `=> <stmt>`.
///
/// Consumes the `=>` token and parses a single statement inside a freshly
/// pushed decl scope (`expectTerminalSemicolon` is forwarded to parseStmt).
/// `processStatementFn` then validates the statement and may replace it in
/// place. The result is a CompoundStmt holding that one statement, spanning
/// from the statement's start to the start of the token that follows it.
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
    function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
    bool expectTerminalSemicolon) {
  consumeToken(Token::equal_arrow);

  // Parse the single statement of the lambda body.
  SMLoc bodyStartLoc = curToken.getStartLoc();
  pushDeclScope();
  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
  bool failedToParse =
      failed(singleStatement) || failed(processStatementFn(*singleStatement));
  // The scope is popped before the failure check, so it is balanced on both
  // the success and the error path.
  popDeclScope();
  if (failedToParse)
    return failure();

  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}

/// Parse a Constraint/Rewrite argument, `<name> : <constraint>`. The name may
/// be an identifier or a dependent keyword.
FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
  // Ensure that the argument is named.
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
    return emitError("expected identifier argument name");

  // Parse the argument similarly to a normal variable.
  StringRef name = curToken.getSpelling();
  SMRange nameLoc = curToken.getLoc();
  consumeToken();

  if (failed(
          parseToken(Token::colon, "expected `:` before argument constraint")))
    return failure();

  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl(name, nameLoc, *cst);
}

/// Parse one result of a Constraint/Rewrite signature: either
/// `<name> : <constraint>` or a bare `<constraint>` (an unnamed result).
///
/// If an identifier resolves to an existing ConstraintDecl, it is treated as
/// the start of an unnamed result rather than as a result name. Inside a
/// Rewrite this is rejected with an error. `resultNum` is not used in this
/// function body.
FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
  // Check to see if this result is named.
  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
    // Check to see if this name actually refers to a Constraint.
    ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
    if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
      // If yes, and this is a Rewrite, give a nice error message as non-Core
      // constraints are not supported on Rewrite results.
      if (parserContext == ParserContext::Rewrite) {
        return emitError(
            "`Rewrite` results are only permitted to use core constraints, "
            "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
      }

      // Otherwise, parse this as an unnamed result variable.
    } else {
      // If it wasn't a constraint, parse the result similarly to a variable. If
      // there is already an existing decl, we will emit an error when defining
      // this variable later.
      StringRef name = curToken.getSpelling();
      SMRange nameLoc = curToken.getLoc();
      consumeToken();

      if (failed(parseToken(Token::colon,
                            "expected `:` before result constraint")))
        return failure();

      FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
      if (failed(cst))
        return failure();

      return createArgOrResultVariableDecl(name, nameLoc, *cst);
    }
  }

  // If it isn't named, we parse the constraint directly and create an unnamed
  // result variable.
  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
}

/// Parse a `Constraint` decl. This delegates to the shared
/// Constraint/Rewrite parser in the Constraint parser context; unnamed inline
/// decls get anonymous names using the "constraint" prefix.
FailureOr<ast::UserConstraintDecl *>
Parser::parseUserConstraintDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      [&](auto &&...args) {
        return this->parseUserPDLLConstraintDecl(args...);
      },
      ParserContext::Constraint, "constraint", isInline);
}

/// Parse an inline (possibly unnamed) `Constraint` decl, check its name for
/// redefinition, and add it to the current decl scope.
FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
  FailureOr<ast::UserConstraintDecl *> decl =
      parseUserConstraintDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

/// Parse the body of a PDLL-defined Constraint.
///
/// Two body forms are accepted:
///  - a lambda body, `=> <expr>`: the single expression is wrapped in a
///    ReturnStmt;
///  - a compound body, `{ ... }`: any `return` must be the last statement,
///    and one must be present if results are declared.
/// The argument scope is made current while the body is parsed. On error
/// paths the function returns without popping it.
FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  pushDeclScope(argumentScope);

  // Parse the body of the constraint. The body is either defined as a compound
  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&stmt) -> LogicalResult {
          ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
          if (!stmtExpr) {
            return emitError(stmt->getLoc(),
                             "expected `Constraint` lambda body to contain a "
                             "single expression");
          }
          stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the structure of the body.
    // `bodyIt` ends at the first ReturnStmt, or at `bodyE` if there is none.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt)
      if (isa<ast::ReturnStmt>(*bodyIt))
        break;
    if (failed(validateUserConstraintOrRewriteReturn(
            "Constraint", body, bodyIt, bodyE, results, resultType)))
      return failure();
  }
  popDeclScope();

  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      name, arguments, results, resultType, body);
}

/// Parse a `Rewrite` decl. This delegates to the shared Constraint/Rewrite
/// parser in the Rewrite parser context; unnamed inline decls get anonymous
/// names using the "rewrite" prefix.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
      ParserContext::Rewrite, "rewrite", isInline);
}

/// Parse an inline (possibly unnamed) `Rewrite` decl, check its name for
/// redefinition, and add it to the current decl scope.
FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
  FailureOr<ast::UserRewriteDecl *> decl =
      parseUserRewriteDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

/// Parse the body of a PDLL-defined Rewrite.
///
/// Two body forms are accepted:
///  - a lambda body, `=> <stmt>`: an operation rewrite statement is kept
///    as-is, and an expression is wrapped in a ReturnStmt;
///  - a compound body, `{ ... }`.
/// For either form, any `return` must be the last statement, and one must be
/// present if results are declared.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  // (Direct assignment; pushDeclScope(argumentScope) does exactly this.)
  curDeclScope = argumentScope;
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&statement) -> LogicalResult {
          if (isa<ast::OpRewriteStmt>(statement))
            return success();

          ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
          if (!statementExpr) {
            return emitError(
                statement->getLoc(),
                "expected `Rewrite` lambda body to contain a single expression "
                "or an operation rewrite statement; such as `erase`, "
                "`replace`, or `rewrite`");
          }
          statement =
              ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  }
  popDeclScope();

  // Verify the structure of the body.
  // `bodyIt` ends at the first ReturnStmt, or at `bodyE` if there is none.
  auto bodyIt = body->begin(), bodyE = body->end();
  for (; bodyIt != bodyE; ++bodyIt)
    if (isa<ast::ReturnStmt>(*bodyIt))
      break;
  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
                                                   bodyE, results, resultType)))
    return failure();
  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      name, arguments, results, resultType, body);
}

/// Shared parser for `Constraint` and `Rewrite` decls.
///
/// Steps:
///  1. Consume the current (keyword) token and switch `parserContext` to
///     `declContext` for the duration of the call.
///  2. Parse the name. Only inline decls may omit it; they are then named
///     `<anonymous_{prefix}_{N}>`.
///  3. Parse the signature.
///  4. Dispatch: a following `{` or `=>` selects the PDLL body parser
///     `parseUserPDLLFn`; anything else is parsed as a native decl.
template <typename T, typename ParseUserPDLLDeclFnT>
FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
    ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
    StringRef anonymousNamePrefix, bool isInline) {
  SMRange loc = curToken.getLoc();
  consumeToken();
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);

  // Parse the name of the decl.
  const ast::Name *name = nullptr;
  if (curToken.isNot(Token::identifier)) {
    // Only inline decls can be un-named. Inline decls are similar to "lambdas"
    // in C++, so being unnamed is fine.
    if (!isInline)
      return emitError("expected identifier name");

    // Create a unique anonymous name to use, as the name for this decl is not
    // important.
    std::string anonName =
        llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
                      anonymousDeclNameCounter++)
            .str();
    name = &ast::Name::create(ctx, anonName, loc);
  } else {
    // If a name was provided, we can use it directly.
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse the functional signature of the decl.
  SmallVector<ast::VariableDecl *> arguments, results;
  ast::DeclScope *argumentScope;
  ast::Type resultType;
  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
                                                   argumentScope, resultType)))
    return failure();

  // Check to see which type of constraint this is. If the constraint contains a
  // compound body, this is a PDLL decl.
  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
    return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
                           resultType);

  // Otherwise, this is a native decl.
  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
                                                   results, resultType);
}

/// Parse the tail of a native Constraint/Rewrite decl: an optional string
/// holding the native code body, followed by `;`.
///
/// Inline native decls must supply the code string. Native Constraints may
/// not declare results.
template <typename T>
FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // If followed by a string, the native code body has also been specified.
  // NOTE(review): `optCodeStr` views the local `codeStrStorage`; this assumes
  // T::createNative copies the string into longer-lived storage. Confirm.
  std::string codeStrStorage;
  Optional<StringRef> optCodeStr;
  if (curToken.isString()) {
    codeStrStorage = curToken.getStringValue();
    optCodeStr = codeStrStorage;
    consumeToken();
  } else if (isInline) {
    return emitError(name.getLoc(),
                     "external declarations must be declared in global scope");
  }
  if (failed(parseToken(Token::semicolon,
                        "expected `;` after native declaration")))
    return failure();
  // TODO: PDL should be able to support constraint results in certain
  // situations, we should revise this.
  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
    return emitError(
        "native Constraints currently do not support returning results");
  }
  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

/// Parse a Constraint/Rewrite signature, `(<args>) [-> <result> | -> (<results>)]`.
///
/// - Arguments are parsed in a newly pushed scope, which is popped afterwards
///   and returned through `argumentScope`.
/// - Results are parsed within a temporary scope that is popped afterwards.
/// - `resultType` is computed from the results.
/// - A single result may not carry a name.
/// On error paths the function returns without popping the scope it pushed.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
    SmallVectorImpl<ast::VariableDecl *> &arguments,
    SmallVectorImpl<ast::VariableDecl *> &results,
    ast::DeclScope *&argumentScope, ast::Type &resultType) {
  // Parse the argument list of the decl.
  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
    return failure();

  argumentScope = pushDeclScope();
  if (curToken.isNot(Token::r_paren)) {
    do {
      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
      if (failed(argument))
        return failure();
      arguments.emplace_back(*argument);
    } while (consumeIf(Token::comma));
  }
  popDeclScope();
  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
    return failure();

  // Parse the results of the decl.
  pushDeclScope();
  if (consumeIf(Token::arrow)) {
    auto parseResultFn = [&]() -> LogicalResult {
      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
      if (failed(result))
        return failure();
      results.emplace_back(*result);
      return success();
    };

    // Check for a list of results.
    if (consumeIf(Token::l_paren)) {
      do {
        if (failed(parseResultFn()))
          return failure();
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
        return failure();

      // Otherwise, there is only one result.
    } else if (failed(parseResultFn())) {
      return failure();
    }
  }
  popDeclScope();

  // Compute the result type of the decl.
  resultType = createUserConstraintRewriteResultType(results);

  // Verify that results are only named if there are more than one.
  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
    return emitError(
        results.front()->getLoc(),
        "cannot create a single-element tuple with an element label");
  }
  return success();
}

/// Validate how a Constraint/Rewrite compound body uses `return`.
///
/// `bodyIt` is expected to point at the body's first ReturnStmt, or to equal
/// `bodyE` when there is none:
/// - If a `return` is present, it must be the last statement.
/// - If none is present, `results` must be empty.
/// `declType` ("Constraint" or "Rewrite" at the call sites shown) is used only
/// in diagnostics. `resultType` is only read here.
LogicalResult Parser::validateUserConstraintOrRewriteReturn(
    StringRef declType, ast::CompoundStmt *body,
    ArrayRef<ast::Stmt *>::iterator bodyIt,
    ArrayRef<ast::Stmt *>::iterator bodyE,
    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
  // Handle if a `return` was provided.
  if (bodyIt != bodyE) {
    // Emit an error if we have trailing statements after the return.
    if (std::next(bodyIt) != bodyE) {
      return emitError(
          (*std::next(bodyIt))->getLoc(),
          llvm::formatv("`return` terminated the `{0}` body, but found "
                        "trailing statements afterwards",
                        declType));
    }

    // Otherwise if a return wasn't provided, check that no results are
    // expected.
  } else if (!results.empty()) {
    // Report at the (empty) range at the very end of the body.
    return emitError(
        {body->getLoc().End, body->getLoc().End},
        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
                      declType, resultType));
  }
  return success();
}

/// Parse a Pattern lambda body, `=> <stmt>`, whose single statement must be
/// an operation rewrite statement.
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
    if (isa<ast::OpRewriteStmt>(statement))
      return success();
    return emitError(
        statement->getLoc(),
        "expected Pattern lambda body to contain a single operation "
        "rewrite statement, such as `erase`, `replace`, or `rewrite`");
  });
}

/// Parse a `Pattern` decl,
/// `Pattern [<name>] [with <metadata>] (=> <stmt> | { ... })`.
///
/// The body is parsed in the PatternMatch parser context. A compound body
/// must not contain a `return` before its first operation rewrite statement,
/// and that rewrite statement must be the body's last statement.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::PatternMatch);

  // Check for an optional identifier for the pattern name.
  const ast::Name *name = nullptr;
  if (curToken.is(Token::identifier)) {
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse any pattern metadata.
  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
    return failure();

  // Parse the pattern body.
  ast::CompoundStmt *body;

  // Handle a lambda body.
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    if (curToken.isNot(Token::l_brace))
      return emitError("expected `{` or `=>` to start pattern body");
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the body of the pattern.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt) {
      if (isa<ast::ReturnStmt>(*bodyIt)) {
        return emitError((*bodyIt)->getLoc(),
                         "`return` statements are only permitted within a "
                         "`Constraint` or `Rewrite` body");
      }
      // Break when we've found the rewrite statement.
      if (isa<ast::OpRewriteStmt>(*bodyIt))
        break;
    }
    if (bodyIt == bodyE) {
      return emitError(loc,
                       "expected Pattern body to terminate with an operation "
                       "rewrite statement, such as `erase`");
    }
    if (std::next(bodyIt) != bodyE) {
      return emitError((*std::next(bodyIt))->getLoc(),
                       "Pattern body was terminated by an operation "
                       "rewrite statement, but found trailing statements");
    }
  }

  return createPatternDecl(loc, name, metadata, body);
}

/// Parse the comma-separated metadata list that follows `with` in a Pattern.
///
/// Recognized entries are `benefit(<integer>)`, where the value must fit in a
/// `uint16_t`, and `recursion`. Each entry may appear at most once.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  Optional<SMRange> benefitLoc;
  Optional<SMRange> hasBoundedRecursionLoc;

  do {
    // Handle metadata code completion.
    if (curToken.is(Token::code_complete))
      return codeCompletePatternMetadata();

    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef metadataStr = curToken.getSpelling();
    SMRange metadataLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    // Parse the benefit metadata: benefit(<integer-value>)
    if (metadataStr == "benefit") {
      if (benefitLoc) {
        return emitErrorAndNote(metadataLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      }
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      uint16_t benefitValue = 0;
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = benefitValue;
      benefitLoc = metadataLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
      // `continue` in a do/while jumps to the `consumeIf(Token::comma)` check.
      continue;
    }

    // Parse the bounded recursion metadata: recursion
    if (metadataStr == "recursion") {
      if (hasBoundedRecursionLoc) {
        return emitErrorAndNote(
            metadataLoc,
            "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      }
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = metadataLoc;
      continue;
    }

    return emitError(metadataLoc, "unknown pattern metadata");
  } while (consumeIf(Token::comma));

  return success();
}

/// Parse an inline type constraint, `< <expr> >`, starting at the `<` token.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr) ||
      failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}

/// Emit an error, with a note at the previous definition, if
/// `curDeclScope->lookup` already finds a decl named `name`.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");
  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
    return emitErrorAndNote(
        name.getLoc(), "`" + name.getName() + "` has already been defined",
        lastDecl->getName()->getLoc(), "see previous definition here");
  }
  return success();
}

/// Create a variable decl and, unless its name is empty or `_`, check the name
/// for redefinition and add the variable to the current decl scope.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // If the name of the variable indicates a special variable, we don't add it
  // to the scope. This variable is local to the definition point.
  if (name.empty() || name == "_") {
    return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
                                     constraints);
  }
  if (failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  curDeclScope->add(varDecl);
  return varDecl;
}

/// Overload of defineVariableDecl for a variable without an initializer.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
                            constraints);
}

/// Parse the constraints of a variable decl: a single constraint, or a
/// bracketed list `[<c1>, <c2>, ...]`. Parsed constraints are appended to
/// `constraints`.
///
/// Inline type constraints and non-core constraints are permitted. All
/// entries share one `typeConstraint` tracker, so at most one inline type
/// constraint is accepted across the whole list.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  Optional<SMRange> typeConstraint;
  auto parseSingleConstraint = [&] {
    FailureOr<ast::ConstraintRef> constraint = parseConstraint(
        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
        /*allowNonCoreConstraints=*/true);
    if (failed(constraint))
      return failure();
    constraints.push_back(*constraint);
    return success();
  };

  // Check to see if this is a single constraint, or a list.
  if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference. Accepted forms:
/// - a core constraint keyword: `Attr`, `Op`, `Type`, `TypeRange`, `Value`,
///   or `ValueRange`;
/// - an inline `Constraint` decl;
/// - an identifier naming an existing ConstraintDecl.
///
/// `Attr`, `Value`, and `ValueRange` may carry an inline `<type>` constraint.
/// This is rejected unless `allowInlineTypeConstraints` is set, and only one
/// is allowed per `typeConstraint` tracker, which records its location.
/// `existingConstraints` and `allowNonCoreConstraints` are only consulted on
/// the code-completion path in this function. An identifier naming any
/// ConstraintDecl is accepted here (presumably non-core use is rejected
/// elsewhere; TODO confirm).
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

/// Parse the constraint of a Constraint/Rewrite argument or result.
///
/// Inline type constraints are never allowed here. Non-core constraints are
/// allowed only while parsing a Constraint, as determined by `parserContext`.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  // Constraint arguments may apply more complex constraints via the arguments.
  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;

  Optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
                         /*allowInlineTypeConstraints=*/false,
                         allowNonCoreConstraints);
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression.
///
/// A leading `_` is handled by parseUnderscoreExpr and returned directly,
/// without postfix parsing. Otherwise a primary expression is parsed, chosen
/// by the current token: `attr`, `Constraint`, identifier, `op`, `Rewrite`,
/// `type`, or `(`. It is then followed by any number of postfix member
/// accesses (`.`) and calls (`(`).
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
  FailureOr<ast::Expr *> lhsExpr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    lhsExpr = parseAttributeExpr();
    break;
  case Token::kw_Constraint:
    lhsExpr = parseInlineConstraintLambdaExpr();
    break;
  case Token::identifier:
    lhsExpr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    lhsExpr = parseOperationExpr();
    break;
  case Token::kw_Rewrite:
    lhsExpr = parseInlineRewriteLambdaExpr();
    break;
  case Token::kw_type:
    lhsExpr = parseTypeExpr();
    break;
  case Token::l_paren:
    lhsExpr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }
  if (failed(lhsExpr))
    return failure();

  // Check for an operator expression.
  while (true) {
    switch (curToken.getKind()) {
    case Token::dot:
      lhsExpr = parseMemberAccessExpr(*lhsExpr);
      break;
    case Token::l_paren:
      lhsExpr = parseCallExpr(*lhsExpr);
      break;
    default:
      return lhsExpr;
    }
    if (failed(lhsExpr))
      return failure();
  }
}

/// Parse an expression that starts with the `attr` keyword. (The definition
/// continues beyond this chunk.)
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
  // identifier.
1746 if (!consumeIf(Token::less)) { 1747 resetToken(loc); 1748 return parseIdentifierExpr(); 1749 } 1750 1751 if (!curToken.isString()) 1752 return emitError("expected string literal containing MLIR attribute"); 1753 std::string attrExpr = curToken.getStringValue(); 1754 consumeToken(); 1755 1756 if (failed( 1757 parseToken(Token::greater, "expected `>` after attribute literal"))) 1758 return failure(); 1759 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1760 } 1761 1762 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1763 SMRange loc = curToken.getLoc(); 1764 consumeToken(Token::l_paren); 1765 1766 // Parse the arguments of the call. 1767 SmallVector<ast::Expr *> arguments; 1768 if (curToken.isNot(Token::r_paren)) { 1769 do { 1770 // Handle code completion for the call arguments. 1771 if (curToken.is(Token::code_complete)) { 1772 codeCompleteCallSignature(parentExpr, arguments.size()); 1773 return failure(); 1774 } 1775 1776 FailureOr<ast::Expr *> argument = parseExpr(); 1777 if (failed(argument)) 1778 return failure(); 1779 arguments.push_back(*argument); 1780 } while (consumeIf(Token::comma)); 1781 } 1782 loc.End = curToken.getEndLoc(); 1783 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1784 return failure(); 1785 1786 return createCallExpr(loc, parentExpr, arguments); 1787 } 1788 1789 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1790 ast::Decl *decl = curDeclScope->lookup(name); 1791 if (!decl) 1792 return emitError(loc, "undefined reference to `" + name + "`"); 1793 1794 return createDeclRefExpr(loc, decl); 1795 } 1796 1797 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1798 StringRef name = curToken.getSpelling(); 1799 SMRange nameLoc = curToken.getLoc(); 1800 consumeToken(); 1801 1802 // Check to see if this is a decl ref expression that defines a variable 1803 // inline. 
1804 if (consumeIf(Token::colon)) { 1805 SmallVector<ast::ConstraintRef> constraints; 1806 if (failed(parseVariableDeclConstraintList(constraints))) 1807 return failure(); 1808 ast::Type type; 1809 if (failed(validateVariableConstraints(constraints, type))) 1810 return failure(); 1811 return createInlineVariableExpr(type, name, nameLoc, constraints); 1812 } 1813 1814 return parseDeclRefExpr(name, nameLoc); 1815 } 1816 1817 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1818 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1819 if (failed(decl)) 1820 return failure(); 1821 1822 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1823 ast::ConstraintType::get(ctx)); 1824 } 1825 1826 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1827 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1828 if (failed(decl)) 1829 return failure(); 1830 1831 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1832 ast::RewriteType::get(ctx)); 1833 } 1834 1835 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1836 SMRange loc = curToken.getLoc(); 1837 consumeToken(Token::dot); 1838 1839 // Check for code completion of the member name. 1840 if (curToken.is(Token::code_complete)) 1841 return codeCompleteMemberAccess(parentExpr); 1842 1843 // Parse the member name. 1844 Token memberNameTok = curToken; 1845 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1846 !memberNameTok.isKeyword()) 1847 return emitError(loc, "expected identifier or numeric member name"); 1848 StringRef memberName = memberNameTok.getSpelling(); 1849 consumeToken(); 1850 1851 return createMemberAccessExpr(parentExpr, memberName, loc); 1852 } 1853 1854 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1855 SMRange loc = curToken.getLoc(); 1856 1857 // Check for code completion for the dialect name. 
1858 if (curToken.is(Token::code_complete)) 1859 return codeCompleteDialectName(); 1860 1861 // Handle the case of an no operation name. 1862 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1863 if (allowEmptyName) 1864 return ast::OpNameDecl::create(ctx, SMRange()); 1865 return emitError("expected dialect namespace"); 1866 } 1867 StringRef name = curToken.getSpelling(); 1868 consumeToken(); 1869 1870 // Otherwise, this is a literal operation name. 1871 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1872 return failure(); 1873 1874 // Check for code completion for the operation name. 1875 if (curToken.is(Token::code_complete)) 1876 return codeCompleteOperationName(name); 1877 1878 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1879 return emitError("expected operation name after dialect namespace"); 1880 1881 name = StringRef(name.data(), name.size() + 1); 1882 do { 1883 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1884 loc.End = curToken.getEndLoc(); 1885 consumeToken(); 1886 } while (curToken.isAny(Token::identifier, Token::dot) || 1887 curToken.isKeyword()); 1888 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1889 } 1890 1891 FailureOr<ast::OpNameDecl *> 1892 Parser::parseWrappedOperationName(bool allowEmptyName) { 1893 if (!consumeIf(Token::less)) 1894 return ast::OpNameDecl::create(ctx, SMRange()); 1895 1896 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1897 if (failed(opNameDecl)) 1898 return failure(); 1899 1900 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1901 return failure(); 1902 return opNameDecl; 1903 } 1904 1905 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1906 SMRange loc = curToken.getLoc(); 1907 consumeToken(Token::kw_op); 1908 1909 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1910 // identifier. 
1911 if (curToken.isNot(Token::less)) { 1912 resetToken(loc); 1913 return parseIdentifierExpr(); 1914 } 1915 1916 // Parse the operation name. The name may be elided, in which case the 1917 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1918 // `Operation*`). Operation names within a rewrite context must be named. 1919 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1920 FailureOr<ast::OpNameDecl *> opNameDecl = 1921 parseWrappedOperationName(allowEmptyName); 1922 if (failed(opNameDecl)) 1923 return failure(); 1924 Optional<StringRef> opName = (*opNameDecl)->getName(); 1925 1926 // Functor used to create an implicit range variable, used for implicit "all" 1927 // operand or results variables. 1928 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1929 FailureOr<ast::VariableDecl *> rangeVar = 1930 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1931 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1932 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1933 }; 1934 1935 // Check for the optional list of operands. 1936 SmallVector<ast::Expr *> operands; 1937 if (!consumeIf(Token::l_paren)) { 1938 // If the operand list isn't specified and we are in a match context, define 1939 // an inplace unconstrained operand range corresponding to all of the 1940 // operands of the operation. This avoids treating zero operands the same 1941 // way as "unconstrained operands". 1942 if (parserContext != ParserContext::Rewrite) { 1943 operands.push_back(createImplicitRangeVar( 1944 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1945 } 1946 } else if (!consumeIf(Token::r_paren)) { 1947 // Check for operand signature code completion. 
1948 if (curToken.is(Token::code_complete)) { 1949 codeCompleteOperationOperandsSignature(opName, operands.size()); 1950 return failure(); 1951 } 1952 1953 // If the operand list was specified and non-empty, parse the operands. 1954 do { 1955 FailureOr<ast::Expr *> operand = parseExpr(); 1956 if (failed(operand)) 1957 return failure(); 1958 operands.push_back(*operand); 1959 } while (consumeIf(Token::comma)); 1960 1961 if (failed(parseToken(Token::r_paren, 1962 "expected `)` after operation operand list"))) 1963 return failure(); 1964 } 1965 1966 // Check for the optional list of attributes. 1967 SmallVector<ast::NamedAttributeDecl *> attributes; 1968 if (consumeIf(Token::l_brace)) { 1969 do { 1970 FailureOr<ast::NamedAttributeDecl *> decl = 1971 parseNamedAttributeDecl(opName); 1972 if (failed(decl)) 1973 return failure(); 1974 attributes.emplace_back(*decl); 1975 } while (consumeIf(Token::comma)); 1976 1977 if (failed(parseToken(Token::r_brace, 1978 "expected `}` after operation attribute list"))) 1979 return failure(); 1980 } 1981 1982 // Check for the optional list of result types. 1983 SmallVector<ast::Expr *> resultTypes; 1984 if (consumeIf(Token::arrow)) { 1985 if (failed(parseToken(Token::l_paren, 1986 "expected `(` before operation result type list"))) 1987 return failure(); 1988 1989 // Handle the case of an empty result list. 1990 if (!consumeIf(Token::r_paren)) { 1991 do { 1992 // Check for result signature code completion. 
1993 if (curToken.is(Token::code_complete)) { 1994 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 1995 return failure(); 1996 } 1997 1998 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1999 if (failed(resultTypeExpr)) 2000 return failure(); 2001 resultTypes.push_back(*resultTypeExpr); 2002 } while (consumeIf(Token::comma)); 2003 2004 if (failed(parseToken(Token::r_paren, 2005 "expected `)` after operation result type list"))) 2006 return failure(); 2007 } 2008 } else if (parserContext != ParserContext::Rewrite) { 2009 // If the result list isn't specified and we are in a match context, define 2010 // an inplace unconstrained result range corresponding to all of the results 2011 // of the operation. This avoids treating zero results the same way as 2012 // "unconstrained results". 2013 resultTypes.push_back(createImplicitRangeVar( 2014 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2015 } 2016 2017 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2018 resultTypes); 2019 } 2020 2021 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2022 SMRange loc = curToken.getLoc(); 2023 consumeToken(Token::l_paren); 2024 2025 DenseMap<StringRef, SMRange> usedNames; 2026 SmallVector<StringRef> elementNames; 2027 SmallVector<ast::Expr *> elements; 2028 if (curToken.isNot(Token::r_paren)) { 2029 do { 2030 // Check for the optional element name assignment before the value. 2031 StringRef elementName; 2032 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2033 Token elementNameTok = curToken; 2034 consumeToken(); 2035 2036 // The element name is only present if followed by an `=`. 2037 if (consumeIf(Token::equal)) { 2038 elementName = elementNameTok.getSpelling(); 2039 2040 // Check to see if this name is already used. 
2041 auto elementNameIt = 2042 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2043 if (!elementNameIt.second) { 2044 return emitErrorAndNote( 2045 elementNameTok.getLoc(), 2046 llvm::formatv("duplicate tuple element label `{0}`", 2047 elementName), 2048 elementNameIt.first->getSecond(), 2049 "see previous label use here"); 2050 } 2051 } else { 2052 // Otherwise, we treat this as part of an expression so reset the 2053 // lexer. 2054 resetToken(elementNameTok.getLoc()); 2055 } 2056 } 2057 elementNames.push_back(elementName); 2058 2059 // Parse the tuple element value. 2060 FailureOr<ast::Expr *> element = parseExpr(); 2061 if (failed(element)) 2062 return failure(); 2063 elements.push_back(*element); 2064 } while (consumeIf(Token::comma)); 2065 } 2066 loc.End = curToken.getEndLoc(); 2067 if (failed( 2068 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2069 return failure(); 2070 return createTupleExpr(loc, elements, elementNames); 2071 } 2072 2073 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2074 SMRange loc = curToken.getLoc(); 2075 consumeToken(Token::kw_type); 2076 2077 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2078 // identifier. 2079 if (!consumeIf(Token::less)) { 2080 resetToken(loc); 2081 return parseIdentifierExpr(); 2082 } 2083 2084 if (!curToken.isString()) 2085 return emitError("expected string literal containing MLIR type"); 2086 std::string attrExpr = curToken.getStringValue(); 2087 consumeToken(); 2088 2089 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2090 return failure(); 2091 return ast::TypeExpr::create(ctx, loc, attrExpr); 2092 } 2093 2094 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2095 StringRef name = curToken.getSpelling(); 2096 SMRange nameLoc = curToken.getLoc(); 2097 consumeToken(Token::underscore); 2098 2099 // Underscore expressions require a constraint list. 
2100 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2101 return failure(); 2102 2103 // Parse the constraints for the expression. 2104 SmallVector<ast::ConstraintRef> constraints; 2105 if (failed(parseVariableDeclConstraintList(constraints))) 2106 return failure(); 2107 2108 ast::Type type; 2109 if (failed(validateVariableConstraints(constraints, type))) 2110 return failure(); 2111 return createInlineVariableExpr(type, name, nameLoc, constraints); 2112 } 2113 2114 //===----------------------------------------------------------------------===// 2115 // Stmts 2116 2117 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2118 FailureOr<ast::Stmt *> stmt; 2119 switch (curToken.getKind()) { 2120 case Token::kw_erase: 2121 stmt = parseEraseStmt(); 2122 break; 2123 case Token::kw_let: 2124 stmt = parseLetStmt(); 2125 break; 2126 case Token::kw_replace: 2127 stmt = parseReplaceStmt(); 2128 break; 2129 case Token::kw_return: 2130 stmt = parseReturnStmt(); 2131 break; 2132 case Token::kw_rewrite: 2133 stmt = parseRewriteStmt(); 2134 break; 2135 default: 2136 stmt = parseExpr(); 2137 break; 2138 } 2139 if (failed(stmt) || 2140 (expectTerminalSemicolon && 2141 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2142 return failure(); 2143 return stmt; 2144 } 2145 2146 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2147 SMLoc startLoc = curToken.getStartLoc(); 2148 consumeToken(Token::l_brace); 2149 2150 // Push a new block scope and parse any nested statements. 2151 pushDeclScope(); 2152 SmallVector<ast::Stmt *> statements; 2153 while (curToken.isNot(Token::r_brace)) { 2154 FailureOr<ast::Stmt *> statement = parseStmt(); 2155 if (failed(statement)) 2156 return popDeclScope(), failure(); 2157 statements.push_back(*statement); 2158 } 2159 popDeclScope(); 2160 2161 // Consume the end brace. 
2162 SMRange location(startLoc, curToken.getEndLoc()); 2163 consumeToken(Token::r_brace); 2164 2165 return ast::CompoundStmt::create(ctx, location, statements); 2166 } 2167 2168 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2169 if (parserContext == ParserContext::Constraint) 2170 return emitError("`erase` cannot be used within a Constraint"); 2171 SMRange loc = curToken.getLoc(); 2172 consumeToken(Token::kw_erase); 2173 2174 // Parse the root operation expression. 2175 FailureOr<ast::Expr *> rootOp = parseExpr(); 2176 if (failed(rootOp)) 2177 return failure(); 2178 2179 return createEraseStmt(loc, *rootOp); 2180 } 2181 2182 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2183 SMRange loc = curToken.getLoc(); 2184 consumeToken(Token::kw_let); 2185 2186 // Parse the name of the new variable. 2187 SMRange varLoc = curToken.getLoc(); 2188 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2189 // `_` is a reserved variable name. 2190 if (curToken.is(Token::underscore)) { 2191 return emitError(varLoc, 2192 "`_` may only be used to define \"inline\" variables"); 2193 } 2194 return emitError(varLoc, 2195 "expected identifier after `let` to name a new variable"); 2196 } 2197 StringRef varName = curToken.getSpelling(); 2198 consumeToken(); 2199 2200 // Parse the optional set of constraints. 2201 SmallVector<ast::ConstraintRef> constraints; 2202 if (consumeIf(Token::colon) && 2203 failed(parseVariableDeclConstraintList(constraints))) 2204 return failure(); 2205 2206 // Parse the optional initializer expression. 2207 ast::Expr *initializer = nullptr; 2208 if (consumeIf(Token::equal)) { 2209 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2210 if (failed(initOrFailure)) 2211 return failure(); 2212 initializer = *initOrFailure; 2213 2214 // Check that the constraints are compatible with having an initializer, 2215 // e.g. type constraints cannot be used with initializers. 
2216 for (ast::ConstraintRef constraint : constraints) { 2217 LogicalResult result = 2218 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2219 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2220 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2221 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2222 return this->emitError( 2223 constraint.referenceLoc, 2224 "type constraints are not permitted on variables with " 2225 "initializers"); 2226 } 2227 return success(); 2228 }) 2229 .Default(success()); 2230 if (failed(result)) 2231 return failure(); 2232 } 2233 } 2234 2235 FailureOr<ast::VariableDecl *> varDecl = 2236 createVariableDecl(varName, varLoc, initializer, constraints); 2237 if (failed(varDecl)) 2238 return failure(); 2239 return ast::LetStmt::create(ctx, loc, *varDecl); 2240 } 2241 2242 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2243 if (parserContext == ParserContext::Constraint) 2244 return emitError("`replace` cannot be used within a Constraint"); 2245 SMRange loc = curToken.getLoc(); 2246 consumeToken(Token::kw_replace); 2247 2248 // Parse the root operation expression. 2249 FailureOr<ast::Expr *> rootOp = parseExpr(); 2250 if (failed(rootOp)) 2251 return failure(); 2252 2253 if (failed( 2254 parseToken(Token::kw_with, "expected `with` after root operation"))) 2255 return failure(); 2256 2257 // The replacement portion of this statement is within a rewrite context. 2258 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2259 ParserContext::Rewrite); 2260 2261 // Parse the replacement values. 
2262 SmallVector<ast::Expr *> replValues; 2263 if (consumeIf(Token::l_paren)) { 2264 if (consumeIf(Token::r_paren)) { 2265 return emitError( 2266 loc, "expected at least one replacement value, consider using " 2267 "`erase` if no replacement values are desired"); 2268 } 2269 2270 do { 2271 FailureOr<ast::Expr *> replExpr = parseExpr(); 2272 if (failed(replExpr)) 2273 return failure(); 2274 replValues.emplace_back(*replExpr); 2275 } while (consumeIf(Token::comma)); 2276 2277 if (failed(parseToken(Token::r_paren, 2278 "expected `)` after replacement values"))) 2279 return failure(); 2280 } else { 2281 FailureOr<ast::Expr *> replExpr = parseExpr(); 2282 if (failed(replExpr)) 2283 return failure(); 2284 replValues.emplace_back(*replExpr); 2285 } 2286 2287 return createReplaceStmt(loc, *rootOp, replValues); 2288 } 2289 2290 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2291 SMRange loc = curToken.getLoc(); 2292 consumeToken(Token::kw_return); 2293 2294 // Parse the result value. 2295 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2296 if (failed(resultExpr)) 2297 return failure(); 2298 2299 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2300 } 2301 2302 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2303 if (parserContext == ParserContext::Constraint) 2304 return emitError("`rewrite` cannot be used within a Constraint"); 2305 SMRange loc = curToken.getLoc(); 2306 consumeToken(Token::kw_rewrite); 2307 2308 // Parse the root operation. 2309 FailureOr<ast::Expr *> rootOp = parseExpr(); 2310 if (failed(rootOp)) 2311 return failure(); 2312 2313 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2314 return failure(); 2315 2316 if (curToken.isNot(Token::l_brace)) 2317 return emitError("expected `{` to start rewrite body"); 2318 2319 // The rewrite body of this statement is within a rewrite context. 
2320 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2321 ParserContext::Rewrite); 2322 2323 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2324 if (failed(rewriteBody)) 2325 return failure(); 2326 2327 // Verify the rewrite body. 2328 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2329 if (isa<ast::ReturnStmt>(stmt)) { 2330 return emitError(stmt->getLoc(), 2331 "`return` statements are only permitted within a " 2332 "`Constraint` or `Rewrite` body"); 2333 } 2334 } 2335 2336 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2337 } 2338 2339 //===----------------------------------------------------------------------===// 2340 // Creation+Analysis 2341 //===----------------------------------------------------------------------===// 2342 2343 //===----------------------------------------------------------------------===// 2344 // Decls 2345 2346 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2347 // Unwrap reference expressions. 2348 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2349 node = init->getDecl(); 2350 return dyn_cast<ast::CallableDecl>(node); 2351 } 2352 2353 FailureOr<ast::PatternDecl *> 2354 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2355 const ParsedPatternMetadata &metadata, 2356 ast::CompoundStmt *body) { 2357 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2358 metadata.hasBoundedRecursion, body); 2359 } 2360 2361 ast::Type Parser::createUserConstraintRewriteResultType( 2362 ArrayRef<ast::VariableDecl *> results) { 2363 // Single result decls use the type of the single result. 2364 if (results.size() == 1) 2365 return results[0]->getType(); 2366 2367 // Multiple results use a tuple type, with the types and names grabbed from 2368 // the result variable decls. 
2369 auto resultTypes = llvm::map_range( 2370 results, [&](const auto *result) { return result->getType(); }); 2371 auto resultNames = llvm::map_range( 2372 results, [&](const auto *result) { return result->getName().getName(); }); 2373 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2374 llvm::to_vector(resultNames)); 2375 } 2376 2377 template <typename T> 2378 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2379 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2380 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2381 ast::CompoundStmt *body) { 2382 if (!body->getChildren().empty()) { 2383 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2384 ast::Expr *resultExpr = retStmt->getResultExpr(); 2385 2386 // Process the result of the decl. If no explicit signature results 2387 // were provided, check for return type inference. Otherwise, check that 2388 // the return expression can be converted to the expected type. 2389 if (results.empty()) 2390 resultType = resultExpr->getType(); 2391 else if (failed(convertExpressionTo(resultExpr, resultType))) 2392 return failure(); 2393 else 2394 retStmt->setResultExpr(resultExpr); 2395 } 2396 } 2397 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2398 } 2399 2400 FailureOr<ast::VariableDecl *> 2401 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2402 ArrayRef<ast::ConstraintRef> constraints) { 2403 // The type of the variable, which is expected to be inferred by either a 2404 // constraint or an initializer expression. 2405 ast::Type type; 2406 if (failed(validateVariableConstraints(constraints, type))) 2407 return failure(); 2408 2409 if (initializer) { 2410 // Update the variable type based on the initializer, or try to convert the 2411 // initializer to the existing type. 
2412 if (!type) 2413 type = initializer->getType(); 2414 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2415 type = mergedType; 2416 else if (failed(convertExpressionTo(initializer, type))) 2417 return failure(); 2418 2419 // Otherwise, if there is no initializer check that the type has already 2420 // been resolved from the constraint list. 2421 } else if (!type) { 2422 return emitErrorAndNote( 2423 loc, "unable to infer type for variable `" + name + "`", loc, 2424 "the type of a variable must be inferable from the constraint " 2425 "list or the initializer"); 2426 } 2427 2428 // Constraint types cannot be used when defining variables. 2429 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2430 return emitError( 2431 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2432 } 2433 2434 // Try to define a variable with the given name. 2435 FailureOr<ast::VariableDecl *> varDecl = 2436 defineVariableDecl(name, loc, type, initializer, constraints); 2437 if (failed(varDecl)) 2438 return failure(); 2439 2440 return *varDecl; 2441 } 2442 2443 FailureOr<ast::VariableDecl *> 2444 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2445 const ast::ConstraintRef &constraint) { 2446 // Constraint arguments may apply more complex constraints via the arguments. 
2447 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2448 ast::Type argType; 2449 if (failed(validateVariableConstraint(constraint, argType, 2450 allowNonCoreConstraints))) 2451 return failure(); 2452 return defineVariableDecl(name, loc, argType, constraint); 2453 } 2454 2455 LogicalResult 2456 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2457 ast::Type &inferredType, 2458 bool allowNonCoreConstraints) { 2459 for (const ast::ConstraintRef &ref : constraints) 2460 if (failed(validateVariableConstraint(ref, inferredType, 2461 allowNonCoreConstraints))) 2462 return failure(); 2463 return success(); 2464 } 2465 2466 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2467 ast::Type &inferredType, 2468 bool allowNonCoreConstraints) { 2469 ast::Type constraintType; 2470 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2471 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2472 if (failed(validateTypeConstraintExpr(typeExpr))) 2473 return failure(); 2474 } 2475 constraintType = ast::AttributeType::get(ctx); 2476 } else if (const auto *cst = 2477 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2478 constraintType = ast::OperationType::get(ctx, cst->getName()); 2479 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2480 constraintType = typeTy; 2481 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2482 constraintType = typeRangeTy; 2483 } else if (const auto *cst = 2484 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2485 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2486 if (failed(validateTypeConstraintExpr(typeExpr))) 2487 return failure(); 2488 } 2489 constraintType = valueTy; 2490 } else if (const auto *cst = 2491 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2492 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2493 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2494 return failure(); 
2495 } 2496 constraintType = valueRangeTy; 2497 } else if (const auto *cst = 2498 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2499 if (!allowNonCoreConstraints) { 2500 return emitError(ref.referenceLoc, 2501 "`Rewrite` arguments and results are only permitted to " 2502 "use core constraints, such as `Attr`, `Op`, `Type`, " 2503 "`TypeRange`, `Value`, `ValueRange`"); 2504 } 2505 2506 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2507 if (inputs.size() != 1) { 2508 return emitErrorAndNote(ref.referenceLoc, 2509 "`Constraint`s applied via a variable constraint " 2510 "list must take a single input, but got " + 2511 Twine(inputs.size()), 2512 cst->getLoc(), 2513 "see definition of constraint here"); 2514 } 2515 constraintType = inputs.front()->getType(); 2516 } else { 2517 llvm_unreachable("unknown constraint type"); 2518 } 2519 2520 // Check that the constraint type is compatible with the current inferred 2521 // type. 2522 if (!inferredType) { 2523 inferredType = constraintType; 2524 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2525 inferredType = mergedTy; 2526 } else { 2527 return emitError(ref.referenceLoc, 2528 llvm::formatv("constraint type `{0}` is incompatible " 2529 "with the previously inferred type `{1}`", 2530 constraintType, inferredType)); 2531 } 2532 return success(); 2533 } 2534 2535 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2536 ast::Type typeExprType = typeExpr->getType(); 2537 if (typeExprType != typeTy) { 2538 return emitError(typeExpr->getLoc(), 2539 "expected expression of `Type` in type constraint"); 2540 } 2541 return success(); 2542 } 2543 2544 LogicalResult 2545 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2546 ast::Type typeExprType = typeExpr->getType(); 2547 if (typeExprType != typeRangeTy) { 2548 return emitError(typeExpr->getLoc(), 2549 "expected expression of `TypeRange` in type constraint"); 2550 } 2551 return success(); 2552 } 
2553 2554 //===----------------------------------------------------------------------===// 2555 // Exprs 2556 2557 FailureOr<ast::CallExpr *> 2558 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2559 MutableArrayRef<ast::Expr *> arguments) { 2560 ast::Type parentType = parentExpr->getType(); 2561 2562 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2563 if (!callableDecl) { 2564 return emitError(loc, 2565 llvm::formatv("expected a reference to a callable " 2566 "`Constraint` or `Rewrite`, but got: `{0}`", 2567 parentType)); 2568 } 2569 if (parserContext == ParserContext::Rewrite) { 2570 if (isa<ast::UserConstraintDecl>(callableDecl)) 2571 return emitError( 2572 loc, "unable to invoke `Constraint` within a rewrite section"); 2573 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2574 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2575 } 2576 2577 // Verify the arguments of the call. 2578 /// Handle size mismatch. 2579 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2580 if (callArgs.size() != arguments.size()) { 2581 return emitErrorAndNote( 2582 loc, 2583 llvm::formatv("invalid number of arguments for {0} call; expected " 2584 "{1}, but got {2}", 2585 callableDecl->getCallableType(), callArgs.size(), 2586 arguments.size()), 2587 callableDecl->getLoc(), 2588 llvm::formatv("see the definition of {0} here", 2589 callableDecl->getName()->getName())); 2590 } 2591 2592 /// Handle argument type mismatch. 
2593 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2594 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2595 callableDecl->getName()->getName()), 2596 callableDecl->getLoc()); 2597 }; 2598 for (auto it : llvm::zip(callArgs, arguments)) { 2599 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2600 attachDiagFn))) 2601 return failure(); 2602 } 2603 2604 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2605 callableDecl->getResultType()); 2606 } 2607 2608 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2609 ast::Decl *decl) { 2610 // Check the type of decl being referenced. 2611 ast::Type declType; 2612 if (isa<ast::ConstraintDecl>(decl)) 2613 declType = ast::ConstraintType::get(ctx); 2614 else if (isa<ast::UserRewriteDecl>(decl)) 2615 declType = ast::RewriteType::get(ctx); 2616 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2617 declType = varDecl->getType(); 2618 else 2619 return emitError(loc, "invalid reference to `" + 2620 decl->getName()->getName() + "`"); 2621 2622 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2623 } 2624 2625 FailureOr<ast::DeclRefExpr *> 2626 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2627 ArrayRef<ast::ConstraintRef> constraints) { 2628 FailureOr<ast::VariableDecl *> decl = 2629 defineVariableDecl(name, loc, type, constraints); 2630 if (failed(decl)) 2631 return failure(); 2632 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2633 } 2634 2635 FailureOr<ast::MemberAccessExpr *> 2636 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2637 SMRange loc) { 2638 // Validate the member name for the given parent expression. 
2639 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2640 if (failed(memberType)) 2641 return failure(); 2642 2643 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2644 } 2645 2646 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2647 StringRef name, SMRange loc) { 2648 ast::Type parentType = parentExpr->getType(); 2649 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2650 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2651 return valueRangeTy; 2652 2653 // Verify member access based on the operation type. 2654 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2655 auto results = odsOp->getResults(); 2656 2657 // Handle indexed results. 2658 unsigned index = 0; 2659 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2660 index < results.size()) { 2661 return results[index].isVariadic() ? valueRangeTy : valueTy; 2662 } 2663 2664 // Handle named results. 2665 const auto *it = llvm::find_if(results, [&](const auto &result) { 2666 return result.getName() == name; 2667 }); 2668 if (it != results.end()) 2669 return it->isVariadic() ? valueRangeTy : valueTy; 2670 } 2671 2672 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2673 // Handle indexed results. 2674 unsigned index = 0; 2675 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2676 index < tupleType.size()) { 2677 return tupleType.getElementTypes()[index]; 2678 } 2679 2680 // Handle named results. 
2681 auto elementNames = tupleType.getElementNames(); 2682 const auto *it = llvm::find(elementNames, name); 2683 if (it != elementNames.end()) 2684 return tupleType.getElementTypes()[it - elementNames.begin()]; 2685 } 2686 return emitError( 2687 loc, 2688 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2689 name, parentType)); 2690 } 2691 2692 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2693 SMRange loc, const ast::OpNameDecl *name, 2694 MutableArrayRef<ast::Expr *> operands, 2695 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2696 MutableArrayRef<ast::Expr *> results) { 2697 Optional<StringRef> opNameRef = name->getName(); 2698 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2699 2700 // Verify the inputs operands. 2701 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2702 return failure(); 2703 2704 // Verify the attribute list. 2705 for (ast::NamedAttributeDecl *attr : attributes) { 2706 // Check for an attribute type, or a type awaiting resolution. 2707 ast::Type attrType = attr->getValue()->getType(); 2708 if (!attrType.isa<ast::AttributeType>()) { 2709 return emitError( 2710 attr->getValue()->getLoc(), 2711 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2712 } 2713 } 2714 2715 // Verify the result types. 2716 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2717 return failure(); 2718 2719 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2720 attributes); 2721 } 2722 2723 LogicalResult 2724 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2725 const ods::Operation *odsOp, 2726 MutableArrayRef<ast::Expr *> operands) { 2727 return validateOperationOperandsOrResults( 2728 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2729 operands, odsOp ? 
odsOp->getOperands() : llvm::None, valueTy, 2730 valueRangeTy); 2731 } 2732 2733 LogicalResult 2734 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2735 const ods::Operation *odsOp, 2736 MutableArrayRef<ast::Expr *> results) { 2737 return validateOperationOperandsOrResults( 2738 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2739 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2740 } 2741 2742 LogicalResult Parser::validateOperationOperandsOrResults( 2743 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2744 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2745 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2746 ast::Type rangeTy) { 2747 // All operation types accept a single range parameter. 2748 if (values.size() == 1) { 2749 if (failed(convertExpressionTo(values[0], rangeTy))) 2750 return failure(); 2751 return success(); 2752 } 2753 2754 /// If the operation has ODS information, we can more accurately verify the 2755 /// values. 2756 if (odsOpLoc) { 2757 if (odsValues.size() != values.size()) { 2758 return emitErrorAndNote( 2759 loc, 2760 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2761 "{2}, but got {3}", 2762 groupName, *name, odsValues.size(), values.size()), 2763 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2764 } 2765 auto diagFn = [&](ast::Diagnostic &diag) { 2766 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2767 *odsOpLoc); 2768 }; 2769 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2770 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2771 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2772 return failure(); 2773 } 2774 return success(); 2775 } 2776 2777 // Otherwise, accept the value groups as they have been defined and just 2778 // ensure they are one of the expected types. 
2779 for (ast::Expr *&valueExpr : values) { 2780 ast::Type valueExprType = valueExpr->getType(); 2781 2782 // Check if this is one of the expected types. 2783 if (valueExprType == rangeTy || valueExprType == singleTy) 2784 continue; 2785 2786 // If the operand is an Operation, allow converting to a Value or 2787 // ValueRange. This situations arises quite often with nested operation 2788 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2789 if (singleTy == valueTy) { 2790 if (valueExprType.isa<ast::OperationType>()) { 2791 valueExpr = convertOpToValue(valueExpr); 2792 continue; 2793 } 2794 } 2795 2796 return emitError( 2797 valueExpr->getLoc(), 2798 llvm::formatv( 2799 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2800 singleTy, rangeTy, valueExprType)); 2801 } 2802 return success(); 2803 } 2804 2805 FailureOr<ast::TupleExpr *> 2806 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2807 ArrayRef<StringRef> elementNames) { 2808 for (const ast::Expr *element : elements) { 2809 ast::Type eleTy = element->getType(); 2810 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2811 return emitError( 2812 element->getLoc(), 2813 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2814 } 2815 } 2816 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2817 } 2818 2819 //===----------------------------------------------------------------------===// 2820 // Stmts 2821 2822 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2823 ast::Expr *rootOp) { 2824 // Check that root is an Operation. 
2825 ast::Type rootType = rootOp->getType(); 2826 if (!rootType.isa<ast::OperationType>()) 2827 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2828 2829 return ast::EraseStmt::create(ctx, loc, rootOp); 2830 } 2831 2832 FailureOr<ast::ReplaceStmt *> 2833 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2834 MutableArrayRef<ast::Expr *> replValues) { 2835 // Check that root is an Operation. 2836 ast::Type rootType = rootOp->getType(); 2837 if (!rootType.isa<ast::OperationType>()) { 2838 return emitError( 2839 rootOp->getLoc(), 2840 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2841 } 2842 2843 // If there are multiple replacement values, we implicitly convert any Op 2844 // expressions to the value form. 2845 bool shouldConvertOpToValues = replValues.size() > 1; 2846 for (ast::Expr *&replExpr : replValues) { 2847 ast::Type replType = replExpr->getType(); 2848 2849 // Check that replExpr is an Operation, Value, or ValueRange. 2850 if (replType.isa<ast::OperationType>()) { 2851 if (shouldConvertOpToValues) 2852 replExpr = convertOpToValue(replExpr); 2853 continue; 2854 } 2855 2856 if (replType != valueTy && replType != valueRangeTy) { 2857 return emitError(replExpr->getLoc(), 2858 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2859 "expression, but got `{0}`", 2860 replType)); 2861 } 2862 } 2863 2864 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2865 } 2866 2867 FailureOr<ast::RewriteStmt *> 2868 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2869 ast::CompoundStmt *rewriteBody) { 2870 // Check that root is an Operation. 
2871 ast::Type rootType = rootOp->getType(); 2872 if (!rootType.isa<ast::OperationType>()) { 2873 return emitError( 2874 rootOp->getLoc(), 2875 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2876 } 2877 2878 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2879 } 2880 2881 //===----------------------------------------------------------------------===// 2882 // Code Completion 2883 //===----------------------------------------------------------------------===// 2884 2885 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2886 ast::Type parentType = parentExpr->getType(); 2887 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2888 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2889 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2890 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2891 return failure(); 2892 } 2893 2894 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2895 if (opName) 2896 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2897 return failure(); 2898 } 2899 2900 LogicalResult 2901 Parser::codeCompleteConstraintName(ast::Type inferredType, 2902 bool allowNonCoreConstraints, 2903 bool allowInlineTypeConstraints) { 2904 codeCompleteContext->codeCompleteConstraintName( 2905 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2906 curDeclScope); 2907 return failure(); 2908 } 2909 2910 LogicalResult Parser::codeCompleteDialectName() { 2911 codeCompleteContext->codeCompleteDialectName(); 2912 return failure(); 2913 } 2914 2915 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2916 codeCompleteContext->codeCompleteOperationName(dialectName); 2917 return failure(); 2918 } 2919 2920 LogicalResult Parser::codeCompletePatternMetadata() { 2921 codeCompleteContext->codeCompletePatternMetadata(); 2922 return failure(); 2923 } 2924 2925 void 
Parser::codeCompleteCallSignature(ast::Node *parent, 2926 unsigned currentNumArgs) { 2927 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2928 if (!callableDecl) 2929 return; 2930 2931 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2932 } 2933 2934 void Parser::codeCompleteOperationOperandsSignature( 2935 Optional<StringRef> opName, unsigned currentNumOperands) { 2936 codeCompleteContext->codeCompleteOperationOperandsSignature( 2937 opName, currentNumOperands); 2938 } 2939 2940 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2941 unsigned currentNumResults) { 2942 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2943 currentNumResults); 2944 } 2945 2946 //===----------------------------------------------------------------------===// 2947 // Parser 2948 //===----------------------------------------------------------------------===// 2949 2950 FailureOr<ast::Module *> 2951 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2952 CodeCompleteContext *codeCompleteContext) { 2953 Parser parser(ctx, sourceMgr, codeCompleteContext); 2954 return parser.parseModule(); 2955 } 2956