1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 
typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 
95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 
129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. 
`isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 
208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. 
`typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 
274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 
312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. 
`inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 
377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 
418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 428 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 429 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 430 unsigned currentNumOperands); 431 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 432 unsigned currentNumResults); 433 434 //===--------------------------------------------------------------------===// 435 // Lexer Utilities 436 //===--------------------------------------------------------------------===// 437 438 /// If the current token has the specified kind, consume it and return true. 439 /// If not, return false. 440 bool consumeIf(Token::Kind kind) { 441 if (curToken.isNot(kind)) 442 return false; 443 consumeToken(kind); 444 return true; 445 } 446 447 /// Advance the current lexer onto the next token. 448 void consumeToken() { 449 assert(curToken.isNot(Token::eof, Token::error) && 450 "shouldn't advance past EOF or errors"); 451 curToken = lexer.lexToken(); 452 } 453 454 /// Advance the current lexer onto the next token, asserting what the expected 455 /// current token is. This is preferred to the above method because it leads 456 /// to more self-documenting code with better checking. 457 void consumeToken(Token::Kind kind) { 458 assert(curToken.is(kind) && "consumed an unexpected token"); 459 consumeToken(); 460 } 461 462 /// Reset the lexer to the location at the given position. 
463 void resetToken(SMRange tokLoc) { 464 lexer.resetPointer(tokLoc.Start.getPointer()); 465 curToken = lexer.lexToken(); 466 } 467 468 /// Consume the specified token if present and return success. On failure, 469 /// output a diagnostic and return failure. 470 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 471 if (curToken.getKind() != kind) 472 return emitError(curToken.getLoc(), msg); 473 consumeToken(); 474 return success(); 475 } 476 LogicalResult emitError(SMRange loc, const Twine &msg) { 477 lexer.emitError(loc, msg); 478 return failure(); 479 } 480 LogicalResult emitError(const Twine &msg) { 481 return emitError(curToken.getLoc(), msg); 482 } 483 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 484 const Twine ¬e) { 485 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 486 return failure(); 487 } 488 489 //===--------------------------------------------------------------------===// 490 // Fields 491 //===--------------------------------------------------------------------===// 492 493 /// The owning AST context. 494 ast::Context &ctx; 495 496 /// The lexer of this parser. 497 Lexer lexer; 498 499 /// The current token within the lexer. 500 Token curToken; 501 502 /// The most recently defined decl scope. 503 ast::DeclScope *curDeclScope = nullptr; 504 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 505 506 /// The current context of the parser. 507 ParserContext parserContext = ParserContext::Global; 508 509 /// Cached types to simplify verification and expression creation. 510 ast::Type valueTy, valueRangeTy; 511 ast::Type typeTy, typeRangeTy; 512 ast::Type attrTy; 513 514 /// A counter used when naming anonymous constraints and rewrites. 515 unsigned anonymousDeclNameCounter = 0; 516 517 /// The optional code completion context. 
518 CodeCompleteContext *codeCompleteContext; 519 }; 520 } // namespace 521 522 FailureOr<ast::Module *> Parser::parseModule() { 523 SMLoc moduleLoc = curToken.getStartLoc(); 524 pushDeclScope(); 525 526 // Parse the top-level decls of the module. 527 SmallVector<ast::Decl *> decls; 528 if (failed(parseModuleBody(decls))) 529 return popDeclScope(), failure(); 530 531 popDeclScope(); 532 return ast::Module::create(ctx, moduleLoc, decls); 533 } 534 535 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 536 while (curToken.isNot(Token::eof)) { 537 if (curToken.is(Token::directive)) { 538 if (failed(parseDirective(decls))) 539 return failure(); 540 continue; 541 } 542 543 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 544 if (failed(decl)) 545 return failure(); 546 decls.push_back(*decl); 547 } 548 return success(); 549 } 550 551 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 552 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 553 valueRangeTy); 554 } 555 556 LogicalResult Parser::convertExpressionTo( 557 ast::Expr *&expr, ast::Type type, 558 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 559 ast::Type exprType = expr->getType(); 560 if (exprType == type) 561 return success(); 562 563 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 564 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 565 expr->getLoc(), llvm::formatv("unable to convert expression of type " 566 "`{0}` to the expected type of " 567 "`{1}`", 568 exprType, type)); 569 if (noteAttachFn) 570 noteAttachFn(*diag); 571 return diag; 572 }; 573 574 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 575 // Two operation types are compatible if they have the same name, or if the 576 // expected type is more general. 
577 if (auto opType = type.dyn_cast<ast::OperationType>()) { 578 if (opType.getName()) 579 return emitConvertError(); 580 return success(); 581 } 582 583 // An operation can always convert to a ValueRange. 584 if (type == valueRangeTy) { 585 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 586 valueRangeTy); 587 return success(); 588 } 589 590 // Allow conversion to a single value by constraining the result range. 591 if (type == valueTy) { 592 // If the operation is registered, we can verify if it can ever have a 593 // single result. 594 Optional<StringRef> opName = exprOpType.getName(); 595 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 596 if (odsOp->getResults().empty()) { 597 return emitConvertError()->attachNote( 598 llvm::formatv("see the definition of `{0}`, which was defined " 599 "with zero results", 600 odsOp->getName()), 601 odsOp->getLoc()); 602 } 603 604 unsigned numSingleResults = llvm::count_if( 605 odsOp->getResults(), [](const ods::OperandOrResult &result) { 606 return result.getVariableLengthKind() == 607 ods::VariableLengthKind::Single; 608 }); 609 if (numSingleResults > 1) { 610 return emitConvertError()->attachNote( 611 llvm::formatv("see the definition of `{0}`, which was defined " 612 "with at least {1} results", 613 odsOp->getName(), numSingleResults), 614 odsOp->getLoc()); 615 } 616 } 617 618 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 619 valueTy); 620 return success(); 621 } 622 return emitConvertError(); 623 } 624 625 // FIXME: Decide how to allow/support converting a single result to multiple, 626 // and multiple to a single result. For now, we just allow Single->Range, 627 // but this isn't something really supported in the PDL dialect. We should 628 // figure out some way to support both. 
629 if ((exprType == valueTy || exprType == valueRangeTy) && 630 (type == valueTy || type == valueRangeTy)) 631 return success(); 632 if ((exprType == typeTy || exprType == typeRangeTy) && 633 (type == typeTy || type == typeRangeTy)) 634 return success(); 635 636 // Handle tuple types. 637 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 638 auto tupleType = type.dyn_cast<ast::TupleType>(); 639 if (!tupleType || tupleType.size() != exprTupleType.size()) 640 return emitConvertError(); 641 642 // Build a new tuple expression using each of the elements of the current 643 // tuple. 644 SmallVector<ast::Expr *> newExprs; 645 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 646 newExprs.push_back(ast::MemberAccessExpr::create( 647 ctx, expr->getLoc(), expr, llvm::to_string(i), 648 exprTupleType.getElementTypes()[i])); 649 650 auto diagFn = [&](ast::Diagnostic &diag) { 651 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 652 i, exprTupleType)); 653 if (noteAttachFn) 654 noteAttachFn(diag); 655 }; 656 if (failed(convertExpressionTo(newExprs.back(), 657 tupleType.getElementTypes()[i], diagFn))) 658 return failure(); 659 } 660 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 661 tupleType.getElementNames()); 662 return success(); 663 } 664 665 return emitConvertError(); 666 } 667 668 //===----------------------------------------------------------------------===// 669 // Directives 670 671 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 672 StringRef directive = curToken.getSpelling(); 673 if (directive == "#include") 674 return parseInclude(decls); 675 676 return emitError("unknown directive `" + directive + "`"); 677 } 678 679 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 680 SMRange loc = curToken.getLoc(); 681 consumeToken(Token::directive); 682 683 // Parse the file being included. 
684 if (!curToken.isString()) 685 return emitError(loc, 686 "expected string file name after `include` directive"); 687 SMRange fileLoc = curToken.getLoc(); 688 std::string filenameStr = curToken.getStringValue(); 689 StringRef filename = filenameStr; 690 consumeToken(); 691 692 // Check the type of include. If ending with `.pdll`, this is another pdl file 693 // to be parsed along with the current module. 694 if (filename.endswith(".pdll")) { 695 if (failed(lexer.pushInclude(filename))) 696 return emitError(fileLoc, 697 "unable to open include file `" + filename + "`"); 698 699 // If we added the include successfully, parse it into the current module. 700 // Make sure to save the current token so that we can restore it when we 701 // finish parsing the nested file. 702 Token oldToken = curToken; 703 curToken = lexer.lexToken(); 704 LogicalResult result = parseModuleBody(decls); 705 curToken = oldToken; 706 return result; 707 } 708 709 // Otherwise, this must be a `.td` include. 710 if (filename.endswith(".td")) 711 return parseTdInclude(filename, fileLoc, decls); 712 713 return emitError(fileLoc, 714 "expected include filename to end with `.pdll` or `.td`"); 715 } 716 717 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 718 SmallVectorImpl<ast::Decl *> &decls) { 719 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 720 721 // This class provides a context argument for the llvm::SourceMgr diagnostic 722 // handler. 723 struct DiagHandlerContext { 724 Parser &parser; 725 StringRef filename; 726 llvm::SMRange loc; 727 } handlerContext{*this, filename, fileLoc}; 728 729 // Set the diagnostic handler for the tablegen source manager. 
730 llvm::SrcMgr.setDiagHandler( 731 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 732 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 733 (void)ctx->parser.emitError( 734 ctx->loc, 735 llvm::formatv("error while processing include file `{0}`: {1}", 736 ctx->filename, diag.getMessage())); 737 }, 738 &handlerContext); 739 740 // Use the source manager to open the file, but don't yet add it. 741 std::string includedFile; 742 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 743 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 744 if (!includeBuffer) 745 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 746 747 auto processFn = [&](llvm::RecordKeeper &records) { 748 processTdIncludeRecords(records, decls); 749 750 // After we are done processing, move all of the tablegen source buffers to 751 // the main parser source mgr. This allows for directly using source 752 // locations from the .td files without needing to remap them. 753 parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr); 754 return false; 755 }; 756 if (llvm::TableGenParseFile(std::move(*includeBuffer), 757 parserSrcMgr.getIncludeDirs(), processFn)) 758 return failure(); 759 760 return success(); 761 } 762 763 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 764 SmallVectorImpl<ast::Decl *> &decls) { 765 // Return the length kind of the given value. 766 auto getLengthKind = [](const auto &value) { 767 if (value.isOptional()) 768 return ods::VariableLengthKind::Optional; 769 return value.isVariadic() ? ods::VariableLengthKind::Variadic 770 : ods::VariableLengthKind::Single; 771 }; 772 773 // Insert a type constraint into the ODS context. 
774 ods::Context &odsContext = ctx.getODSContext(); 775 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 776 -> const ods::TypeConstraint & { 777 return odsContext.insertTypeConstraint(cst.constraint.getDefName(), 778 cst.constraint.getSummary(), 779 cst.constraint.getCPPClassName()); 780 }; 781 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 782 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 783 }; 784 785 // Process the parsed tablegen records to build ODS information. 786 /// Operations. 787 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 788 tblgen::Operator op(def); 789 790 bool inserted = false; 791 ods::Operation *odsOp = nullptr; 792 std::tie(odsOp, inserted) = 793 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 794 op.getDescription(), op.getLoc().front()); 795 796 // Ignore operations that have already been added. 797 if (!inserted) 798 continue; 799 800 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 801 odsOp->appendAttribute( 802 attr.name, attr.attr.isOptional(), 803 odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(), 804 attr.attr.getSummary(), 805 attr.attr.getStorageType())); 806 } 807 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 808 odsOp->appendOperand(operand.name, getLengthKind(operand), 809 addTypeConstraint(operand)); 810 } 811 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 812 odsOp->appendResult(result.name, getLengthKind(result), 813 addTypeConstraint(result)); 814 } 815 } 816 /// Attr constraints. 817 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 818 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 819 decls.push_back( 820 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 821 tblgen::AttrConstraint(def), 822 convertLocToRange(def->getLoc().front()), attrTy)); 823 } 824 } 825 /// Type constraints. 
826 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 827 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 828 decls.push_back( 829 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 830 tblgen::TypeConstraint(def), 831 convertLocToRange(def->getLoc().front()), typeTy)); 832 } 833 } 834 /// Interfaces. 835 ast::Type opTy = ast::OperationType::get(ctx); 836 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 837 StringRef name = def->getName(); 838 if (def->isAnonymous() || curDeclScope->lookup(name) || 839 def->isSubClassOf("DeclareInterfaceMethods")) 840 continue; 841 SMRange loc = convertLocToRange(def->getLoc().front()); 842 843 StringRef className = def->getValueAsString("cppClassName"); 844 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 845 std::string codeBlock = 846 llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) 847 .str(); 848 849 if (def->isSubClassOf("OpInterface")) { 850 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 851 name, codeBlock, loc, opTy)); 852 } else if (def->isSubClassOf("AttrInterface")) { 853 decls.push_back( 854 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 855 name, codeBlock, loc, attrTy)); 856 } else if (def->isSubClassOf("TypeInterface")) { 857 decls.push_back( 858 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 859 name, codeBlock, loc, typeTy)); 860 } 861 } 862 } 863 864 template <typename ConstraintT> 865 ast::Decl * 866 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 867 SMRange loc, ast::Type type) { 868 // Build the single input parameter. 
869 ast::DeclScope *argScope = pushDeclScope(); 870 auto *paramVar = ast::VariableDecl::create( 871 ctx, ast::Name::create(ctx, "self", loc), type, 872 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 873 argScope->add(paramVar); 874 popDeclScope(); 875 876 // Build the native constraint. 877 auto *constraintDecl = ast::UserConstraintDecl::createNative( 878 ctx, ast::Name::create(ctx, name, loc), paramVar, 879 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 880 curDeclScope->add(constraintDecl); 881 return constraintDecl; 882 } 883 884 template <typename ConstraintT> 885 ast::Decl * 886 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 887 SMRange loc, ast::Type type) { 888 // Format the condition template. 889 tblgen::FmtContext fmtContext; 890 fmtContext.withSelf("self"); 891 std::string codeBlock = 892 tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); 893 894 return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(), 895 codeBlock, loc, type); 896 } 897 898 //===----------------------------------------------------------------------===// 899 // Decls 900 901 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 902 FailureOr<ast::Decl *> decl; 903 switch (curToken.getKind()) { 904 case Token::kw_Constraint: 905 decl = parseUserConstraintDecl(); 906 break; 907 case Token::kw_Pattern: 908 decl = parsePatternDecl(); 909 break; 910 case Token::kw_Rewrite: 911 decl = parseUserRewriteDecl(); 912 break; 913 default: 914 return emitError("expected top-level declaration, such as a `Pattern`"); 915 } 916 if (failed(decl)) 917 return failure(); 918 919 // If the decl has a name, add it to the current scope. 
920 if (const ast::Name *name = (*decl)->getName()) { 921 if (failed(checkDefineNamedDecl(*name))) 922 return failure(); 923 curDeclScope->add(*decl); 924 } 925 return decl; 926 } 927 928 FailureOr<ast::NamedAttributeDecl *> 929 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 930 // Check for name code completion. 931 if (curToken.is(Token::code_complete)) 932 return codeCompleteAttributeName(parentOpName); 933 934 std::string attrNameStr; 935 if (curToken.isString()) 936 attrNameStr = curToken.getStringValue(); 937 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 938 attrNameStr = curToken.getSpelling().str(); 939 else 940 return emitError("expected identifier or string attribute name"); 941 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 942 consumeToken(); 943 944 // Check for a value of the attribute. 945 ast::Expr *attrValue = nullptr; 946 if (consumeIf(Token::equal)) { 947 FailureOr<ast::Expr *> attrExpr = parseExpr(); 948 if (failed(attrExpr)) 949 return failure(); 950 attrValue = *attrExpr; 951 } else { 952 // If there isn't a concrete value, create an expression representing a 953 // UnitAttr. 954 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 955 } 956 957 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 958 } 959 960 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 961 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 962 bool expectTerminalSemicolon) { 963 consumeToken(Token::equal_arrow); 964 965 // Parse the single statement of the lambda body. 
966 SMLoc bodyStartLoc = curToken.getStartLoc(); 967 pushDeclScope(); 968 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 969 bool failedToParse = 970 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 971 popDeclScope(); 972 if (failedToParse) 973 return failure(); 974 975 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 976 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 977 } 978 979 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 980 // Ensure that the argument is named. 981 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 982 return emitError("expected identifier argument name"); 983 984 // Parse the argument similarly to a normal variable. 985 StringRef name = curToken.getSpelling(); 986 SMRange nameLoc = curToken.getLoc(); 987 consumeToken(); 988 989 if (failed( 990 parseToken(Token::colon, "expected `:` before argument constraint"))) 991 return failure(); 992 993 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 994 if (failed(cst)) 995 return failure(); 996 997 return createArgOrResultVariableDecl(name, nameLoc, *cst); 998 } 999 1000 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1001 // Check to see if this result is named. 1002 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1003 // Check to see if this name actually refers to a Constraint. 1004 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1005 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1006 // If yes, and this is a Rewrite, give a nice error message as non-Core 1007 // constraints are not supported on Rewrite results. 
1008 if (parserContext == ParserContext::Rewrite) { 1009 return emitError( 1010 "`Rewrite` results are only permitted to use core constraints, " 1011 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1012 } 1013 1014 // Otherwise, parse this as an unnamed result variable. 1015 } else { 1016 // If it wasn't a constraint, parse the result similarly to a variable. If 1017 // there is already an existing decl, we will emit an error when defining 1018 // this variable later. 1019 StringRef name = curToken.getSpelling(); 1020 SMRange nameLoc = curToken.getLoc(); 1021 consumeToken(); 1022 1023 if (failed(parseToken(Token::colon, 1024 "expected `:` before result constraint"))) 1025 return failure(); 1026 1027 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1028 if (failed(cst)) 1029 return failure(); 1030 1031 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1032 } 1033 } 1034 1035 // If it isn't named, we parse the constraint directly and create an unnamed 1036 // result variable. 1037 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1038 if (failed(cst)) 1039 return failure(); 1040 1041 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1042 } 1043 1044 FailureOr<ast::UserConstraintDecl *> 1045 Parser::parseUserConstraintDecl(bool isInline) { 1046 // Constraints and rewrites have very similar formats, dispatch to a shared 1047 // interface for parsing. 
1048 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1049 [&](auto &&...args) { 1050 return this->parseUserPDLLConstraintDecl(args...); 1051 }, 1052 ParserContext::Constraint, "constraint", isInline); 1053 } 1054 1055 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1056 FailureOr<ast::UserConstraintDecl *> decl = 1057 parseUserConstraintDecl(/*isInline=*/true); 1058 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1059 return failure(); 1060 1061 curDeclScope->add(*decl); 1062 return decl; 1063 } 1064 1065 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1066 const ast::Name &name, bool isInline, 1067 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1068 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1069 // Push the argument scope back onto the list, so that the body can 1070 // reference arguments. 1071 pushDeclScope(argumentScope); 1072 1073 // Parse the body of the constraint. The body is either defined as a compound 1074 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1075 ast::CompoundStmt *body; 1076 if (curToken.is(Token::equal_arrow)) { 1077 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1078 [&](ast::Stmt *&stmt) -> LogicalResult { 1079 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1080 if (!stmtExpr) { 1081 return emitError(stmt->getLoc(), 1082 "expected `Constraint` lambda body to contain a " 1083 "single expression"); 1084 } 1085 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1086 return success(); 1087 }, 1088 /*expectTerminalSemicolon=*/!isInline); 1089 if (failed(bodyResult)) 1090 return failure(); 1091 body = *bodyResult; 1092 } else { 1093 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1094 if (failed(bodyResult)) 1095 return failure(); 1096 body = *bodyResult; 1097 1098 // Verify the structure of the body. 
1099 auto bodyIt = body->begin(), bodyE = body->end(); 1100 for (; bodyIt != bodyE; ++bodyIt) 1101 if (isa<ast::ReturnStmt>(*bodyIt)) 1102 break; 1103 if (failed(validateUserConstraintOrRewriteReturn( 1104 "Constraint", body, bodyIt, bodyE, results, resultType))) 1105 return failure(); 1106 } 1107 popDeclScope(); 1108 1109 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1110 name, arguments, results, resultType, body); 1111 } 1112 1113 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1114 // Constraints and rewrites have very similar formats, dispatch to a shared 1115 // interface for parsing. 1116 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1117 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1118 ParserContext::Rewrite, "rewrite", isInline); 1119 } 1120 1121 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1122 FailureOr<ast::UserRewriteDecl *> decl = 1123 parseUserRewriteDecl(/*isInline=*/true); 1124 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1125 return failure(); 1126 1127 curDeclScope->add(*decl); 1128 return decl; 1129 } 1130 1131 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1132 const ast::Name &name, bool isInline, 1133 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1134 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1135 // Push the argument scope back onto the list, so that the body can 1136 // reference arguments. 
1137 curDeclScope = argumentScope; 1138 ast::CompoundStmt *body; 1139 if (curToken.is(Token::equal_arrow)) { 1140 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1141 [&](ast::Stmt *&statement) -> LogicalResult { 1142 if (isa<ast::OpRewriteStmt>(statement)) 1143 return success(); 1144 1145 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1146 if (!statementExpr) { 1147 return emitError( 1148 statement->getLoc(), 1149 "expected `Rewrite` lambda body to contain a single expression " 1150 "or an operation rewrite statement; such as `erase`, " 1151 "`replace`, or `rewrite`"); 1152 } 1153 statement = 1154 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1155 return success(); 1156 }, 1157 /*expectTerminalSemicolon=*/!isInline); 1158 if (failed(bodyResult)) 1159 return failure(); 1160 body = *bodyResult; 1161 } else { 1162 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1163 if (failed(bodyResult)) 1164 return failure(); 1165 body = *bodyResult; 1166 } 1167 popDeclScope(); 1168 1169 // Verify the structure of the body. 1170 auto bodyIt = body->begin(), bodyE = body->end(); 1171 for (; bodyIt != bodyE; ++bodyIt) 1172 if (isa<ast::ReturnStmt>(*bodyIt)) 1173 break; 1174 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1175 bodyE, results, resultType))) 1176 return failure(); 1177 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1178 name, arguments, results, resultType, body); 1179 } 1180 1181 template <typename T, typename ParseUserPDLLDeclFnT> 1182 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1183 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1184 StringRef anonymousNamePrefix, bool isInline) { 1185 SMRange loc = curToken.getLoc(); 1186 consumeToken(); 1187 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1188 1189 // Parse the name of the decl. 
1190 const ast::Name *name = nullptr; 1191 if (curToken.isNot(Token::identifier)) { 1192 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1193 // in C++, so being unnamed is fine. 1194 if (!isInline) 1195 return emitError("expected identifier name"); 1196 1197 // Create a unique anonymous name to use, as the name for this decl is not 1198 // important. 1199 std::string anonName = 1200 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1201 anonymousDeclNameCounter++) 1202 .str(); 1203 name = &ast::Name::create(ctx, anonName, loc); 1204 } else { 1205 // If a name was provided, we can use it directly. 1206 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1207 consumeToken(Token::identifier); 1208 } 1209 1210 // Parse the functional signature of the decl. 1211 SmallVector<ast::VariableDecl *> arguments, results; 1212 ast::DeclScope *argumentScope; 1213 ast::Type resultType; 1214 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1215 argumentScope, resultType))) 1216 return failure(); 1217 1218 // Check to see which type of constraint this is. If the constraint contains a 1219 // compound body, this is a PDLL decl. 1220 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1221 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1222 resultType); 1223 1224 // Otherwise, this is a native decl. 1225 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1226 results, resultType); 1227 } 1228 1229 template <typename T> 1230 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1231 const ast::Name &name, bool isInline, 1232 ArrayRef<ast::VariableDecl *> arguments, 1233 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1234 // If followed by a string, the native code body has also been specified. 
1235 std::string codeStrStorage; 1236 Optional<StringRef> optCodeStr; 1237 if (curToken.isString()) { 1238 codeStrStorage = curToken.getStringValue(); 1239 optCodeStr = codeStrStorage; 1240 consumeToken(); 1241 } else if (isInline) { 1242 return emitError(name.getLoc(), 1243 "external declarations must be declared in global scope"); 1244 } 1245 if (failed(parseToken(Token::semicolon, 1246 "expected `;` after native declaration"))) 1247 return failure(); 1248 // TODO: PDL should be able to support constraint results in certain 1249 // situations, we should revise this. 1250 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1251 return emitError( 1252 "native Constraints currently do not support returning results"); 1253 } 1254 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1255 } 1256 1257 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1258 SmallVectorImpl<ast::VariableDecl *> &arguments, 1259 SmallVectorImpl<ast::VariableDecl *> &results, 1260 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1261 // Parse the argument list of the decl. 1262 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1263 return failure(); 1264 1265 argumentScope = pushDeclScope(); 1266 if (curToken.isNot(Token::r_paren)) { 1267 do { 1268 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1269 if (failed(argument)) 1270 return failure(); 1271 arguments.emplace_back(*argument); 1272 } while (consumeIf(Token::comma)); 1273 } 1274 popDeclScope(); 1275 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1276 return failure(); 1277 1278 // Parse the results of the decl. 
1279 pushDeclScope(); 1280 if (consumeIf(Token::arrow)) { 1281 auto parseResultFn = [&]() -> LogicalResult { 1282 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1283 if (failed(result)) 1284 return failure(); 1285 results.emplace_back(*result); 1286 return success(); 1287 }; 1288 1289 // Check for a list of results. 1290 if (consumeIf(Token::l_paren)) { 1291 do { 1292 if (failed(parseResultFn())) 1293 return failure(); 1294 } while (consumeIf(Token::comma)); 1295 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1296 return failure(); 1297 1298 // Otherwise, there is only one result. 1299 } else if (failed(parseResultFn())) { 1300 return failure(); 1301 } 1302 } 1303 popDeclScope(); 1304 1305 // Compute the result type of the decl. 1306 resultType = createUserConstraintRewriteResultType(results); 1307 1308 // Verify that results are only named if there are more than one. 1309 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1310 return emitError( 1311 results.front()->getLoc(), 1312 "cannot create a single-element tuple with an element label"); 1313 } 1314 return success(); 1315 } 1316 1317 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1318 StringRef declType, ast::CompoundStmt *body, 1319 ArrayRef<ast::Stmt *>::iterator bodyIt, 1320 ArrayRef<ast::Stmt *>::iterator bodyE, 1321 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1322 // Handle if a `return` was provided. 1323 if (bodyIt != bodyE) { 1324 // Emit an error if we have trailing statements after the return. 1325 if (std::next(bodyIt) != bodyE) { 1326 return emitError( 1327 (*std::next(bodyIt))->getLoc(), 1328 llvm::formatv("`return` terminated the `{0}` body, but found " 1329 "trailing statements afterwards", 1330 declType)); 1331 } 1332 1333 // Otherwise if a return wasn't provided, check that no results are 1334 // expected. 
1335 } else if (!results.empty()) { 1336 return emitError( 1337 {body->getLoc().End, body->getLoc().End}, 1338 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1339 declType, resultType)); 1340 } 1341 return success(); 1342 } 1343 1344 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1345 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1346 if (isa<ast::OpRewriteStmt>(statement)) 1347 return success(); 1348 return emitError( 1349 statement->getLoc(), 1350 "expected Pattern lambda body to contain a single operation " 1351 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1352 }); 1353 } 1354 1355 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1356 SMRange loc = curToken.getLoc(); 1357 consumeToken(Token::kw_Pattern); 1358 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1359 ParserContext::PatternMatch); 1360 1361 // Check for an optional identifier for the pattern name. 1362 const ast::Name *name = nullptr; 1363 if (curToken.is(Token::identifier)) { 1364 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1365 consumeToken(Token::identifier); 1366 } 1367 1368 // Parse any pattern metadata. 1369 ParsedPatternMetadata metadata; 1370 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1371 return failure(); 1372 1373 // Parse the pattern body. 1374 ast::CompoundStmt *body; 1375 1376 // Handle a lambda body. 1377 if (curToken.is(Token::equal_arrow)) { 1378 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1379 if (failed(bodyResult)) 1380 return failure(); 1381 body = *bodyResult; 1382 } else { 1383 if (curToken.isNot(Token::l_brace)) 1384 return emitError("expected `{` or `=>` to start pattern body"); 1385 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1386 if (failed(bodyResult)) 1387 return failure(); 1388 body = *bodyResult; 1389 1390 // Verify the body of the pattern. 
1391 auto bodyIt = body->begin(), bodyE = body->end(); 1392 for (; bodyIt != bodyE; ++bodyIt) { 1393 if (isa<ast::ReturnStmt>(*bodyIt)) { 1394 return emitError((*bodyIt)->getLoc(), 1395 "`return` statements are only permitted within a " 1396 "`Constraint` or `Rewrite` body"); 1397 } 1398 // Break when we've found the rewrite statement. 1399 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1400 break; 1401 } 1402 if (bodyIt == bodyE) { 1403 return emitError(loc, 1404 "expected Pattern body to terminate with an operation " 1405 "rewrite statement, such as `erase`"); 1406 } 1407 if (std::next(bodyIt) != bodyE) { 1408 return emitError((*std::next(bodyIt))->getLoc(), 1409 "Pattern body was terminated by an operation " 1410 "rewrite statement, but found trailing statements"); 1411 } 1412 } 1413 1414 return createPatternDecl(loc, name, metadata, body); 1415 } 1416 1417 LogicalResult 1418 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1419 Optional<SMRange> benefitLoc; 1420 Optional<SMRange> hasBoundedRecursionLoc; 1421 1422 do { 1423 // Handle metadata code completion. 
1424 if (curToken.is(Token::code_complete)) 1425 return codeCompletePatternMetadata(); 1426 1427 if (curToken.isNot(Token::identifier)) 1428 return emitError("expected pattern metadata identifier"); 1429 StringRef metadataStr = curToken.getSpelling(); 1430 SMRange metadataLoc = curToken.getLoc(); 1431 consumeToken(Token::identifier); 1432 1433 // Parse the benefit metadata: benefit(<integer-value>) 1434 if (metadataStr == "benefit") { 1435 if (benefitLoc) { 1436 return emitErrorAndNote(metadataLoc, 1437 "pattern benefit has already been specified", 1438 *benefitLoc, "see previous definition here"); 1439 } 1440 if (failed(parseToken(Token::l_paren, 1441 "expected `(` before pattern benefit"))) 1442 return failure(); 1443 1444 uint16_t benefitValue = 0; 1445 if (curToken.isNot(Token::integer)) 1446 return emitError("expected integral pattern benefit"); 1447 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1448 return emitError( 1449 "expected pattern benefit to fit within a 16-bit integer"); 1450 consumeToken(Token::integer); 1451 1452 metadata.benefit = benefitValue; 1453 benefitLoc = metadataLoc; 1454 1455 if (failed( 1456 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1457 return failure(); 1458 continue; 1459 } 1460 1461 // Parse the bounded recursion metadata: recursion 1462 if (metadataStr == "recursion") { 1463 if (hasBoundedRecursionLoc) { 1464 return emitErrorAndNote( 1465 metadataLoc, 1466 "pattern recursion metadata has already been specified", 1467 *hasBoundedRecursionLoc, "see previous definition here"); 1468 } 1469 metadata.hasBoundedRecursion = true; 1470 hasBoundedRecursionLoc = metadataLoc; 1471 continue; 1472 } 1473 1474 return emitError(metadataLoc, "unknown pattern metadata"); 1475 } while (consumeIf(Token::comma)); 1476 1477 return success(); 1478 } 1479 1480 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1481 consumeToken(Token::less); 1482 1483 FailureOr<ast::Expr *> typeExpr = parseExpr(); 
1484 if (failed(typeExpr) || 1485 failed(parseToken(Token::greater, 1486 "expected `>` after variable type constraint"))) 1487 return failure(); 1488 return typeExpr; 1489 } 1490 1491 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1492 assert(curDeclScope && "defining decl outside of a decl scope"); 1493 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1494 return emitErrorAndNote( 1495 name.getLoc(), "`" + name.getName() + "` has already been defined", 1496 lastDecl->getName()->getLoc(), "see previous definition here"); 1497 } 1498 return success(); 1499 } 1500 1501 FailureOr<ast::VariableDecl *> 1502 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1503 ast::Expr *initExpr, 1504 ArrayRef<ast::ConstraintRef> constraints) { 1505 assert(curDeclScope && "defining variable outside of decl scope"); 1506 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1507 1508 // If the name of the variable indicates a special variable, we don't add it 1509 // to the scope. This variable is local to the definition point. 
1510 if (name.empty() || name == "_") { 1511 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1512 constraints); 1513 } 1514 if (failed(checkDefineNamedDecl(nameDecl))) 1515 return failure(); 1516 1517 auto *varDecl = 1518 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1519 curDeclScope->add(varDecl); 1520 return varDecl; 1521 } 1522 1523 FailureOr<ast::VariableDecl *> 1524 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1525 ArrayRef<ast::ConstraintRef> constraints) { 1526 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1527 constraints); 1528 } 1529 1530 LogicalResult Parser::parseVariableDeclConstraintList( 1531 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1532 Optional<SMRange> typeConstraint; 1533 auto parseSingleConstraint = [&] { 1534 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1535 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1536 /*allowNonCoreConstraints=*/true); 1537 if (failed(constraint)) 1538 return failure(); 1539 constraints.push_back(*constraint); 1540 return success(); 1541 }; 1542 1543 // Check to see if this is a single constraint, or a list. 
if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  // Otherwise, parse a bracketed, comma-separated list of one or more
  // constraints.
  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference. Accepted forms:
///   * a core constraint: `Attr`, `Op`, `Type`, `TypeRange`, `Value`, or
///     `ValueRange`; `Attr`, `Value`, and `ValueRange` may carry an inline
///     `<type>` constraint;
///   * an inline `Constraint` declaration;
///   * an identifier naming an existing constraint decl.
///
/// `typeConstraint` records the location of an inline type constraint already
/// applied to the variable. A second inline type constraint is diagnosed, and
/// a newly parsed one updates this value. `allowInlineTypeConstraints` controls
/// whether inline type constraints are accepted at all. `existingConstraints`
/// and `allowNonCoreConstraints` are used here by the code completion path.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse an inline `<type>` constraint into `typeExpr`, enforcing that inline
  // type constraints are allowed and that the variable has no prior one.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a decl that is not a constraint.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

/// Parse the constraint of a `Constraint`/`Rewrite` argument or result. Inline
/// type constraints are never accepted here. `allowNonCoreConstraints` is only
/// set when parsing within a `Constraint`.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  // Constraint arguments may apply more complex constraints via the arguments.
  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;

  Optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
                         /*allowInlineTypeConstraints=*/false,
                         allowNonCoreConstraints);
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression. This is either `_`, or a primary expression (attribute,
/// inline constraint or rewrite lambda, identifier, operation, type, or tuple)
/// followed by any number of member accesses (`.`) and calls (`(`).
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
  FailureOr<ast::Expr *> lhsExpr;
  switch (curToken.getKind()) {
  case Token::kw_attr:
    lhsExpr = parseAttributeExpr();
    break;
  case Token::kw_Constraint:
    lhsExpr = parseInlineConstraintLambdaExpr();
    break;
  case Token::identifier:
    lhsExpr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    lhsExpr = parseOperationExpr();
    break;
  case Token::kw_Rewrite:
    lhsExpr = parseInlineRewriteLambdaExpr();
    break;
  case Token::kw_type:
    lhsExpr = parseTypeExpr();
    break;
  case Token::l_paren:
    lhsExpr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }
  if (failed(lhsExpr))
    return failure();

  // Check for an operator expression. Postfix operators apply left to right,
  // each wrapping the expression parsed so far.
  while (true) {
    switch (curToken.getKind()) {
    case Token::dot:
      lhsExpr = parseMemberAccessExpr(*lhsExpr);
      break;
    case Token::l_paren:
      lhsExpr = parseCallExpr(*lhsExpr);
      break;
    default:
      return lhsExpr;
    }
    if (failed(lhsExpr))
      return failure();
  }
}

FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
  // identifier.
1747 if (!consumeIf(Token::less)) { 1748 resetToken(loc); 1749 return parseIdentifierExpr(); 1750 } 1751 1752 if (!curToken.isString()) 1753 return emitError("expected string literal containing MLIR attribute"); 1754 std::string attrExpr = curToken.getStringValue(); 1755 consumeToken(); 1756 1757 if (failed( 1758 parseToken(Token::greater, "expected `>` after attribute literal"))) 1759 return failure(); 1760 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1761 } 1762 1763 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1764 SMRange loc = curToken.getLoc(); 1765 consumeToken(Token::l_paren); 1766 1767 // Parse the arguments of the call. 1768 SmallVector<ast::Expr *> arguments; 1769 if (curToken.isNot(Token::r_paren)) { 1770 do { 1771 // Handle code completion for the call arguments. 1772 if (curToken.is(Token::code_complete)) { 1773 codeCompleteCallSignature(parentExpr, arguments.size()); 1774 return failure(); 1775 } 1776 1777 FailureOr<ast::Expr *> argument = parseExpr(); 1778 if (failed(argument)) 1779 return failure(); 1780 arguments.push_back(*argument); 1781 } while (consumeIf(Token::comma)); 1782 } 1783 loc.End = curToken.getEndLoc(); 1784 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1785 return failure(); 1786 1787 return createCallExpr(loc, parentExpr, arguments); 1788 } 1789 1790 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1791 ast::Decl *decl = curDeclScope->lookup(name); 1792 if (!decl) 1793 return emitError(loc, "undefined reference to `" + name + "`"); 1794 1795 return createDeclRefExpr(loc, decl); 1796 } 1797 1798 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1799 StringRef name = curToken.getSpelling(); 1800 SMRange nameLoc = curToken.getLoc(); 1801 consumeToken(); 1802 1803 // Check to see if this is a decl ref expression that defines a variable 1804 // inline. 
1805 if (consumeIf(Token::colon)) { 1806 SmallVector<ast::ConstraintRef> constraints; 1807 if (failed(parseVariableDeclConstraintList(constraints))) 1808 return failure(); 1809 ast::Type type; 1810 if (failed(validateVariableConstraints(constraints, type))) 1811 return failure(); 1812 return createInlineVariableExpr(type, name, nameLoc, constraints); 1813 } 1814 1815 return parseDeclRefExpr(name, nameLoc); 1816 } 1817 1818 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1819 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1820 if (failed(decl)) 1821 return failure(); 1822 1823 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1824 ast::ConstraintType::get(ctx)); 1825 } 1826 1827 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1828 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1829 if (failed(decl)) 1830 return failure(); 1831 1832 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1833 ast::RewriteType::get(ctx)); 1834 } 1835 1836 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1837 SMRange loc = curToken.getLoc(); 1838 consumeToken(Token::dot); 1839 1840 // Check for code completion of the member name. 1841 if (curToken.is(Token::code_complete)) 1842 return codeCompleteMemberAccess(parentExpr); 1843 1844 // Parse the member name. 1845 Token memberNameTok = curToken; 1846 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1847 !memberNameTok.isKeyword()) 1848 return emitError(loc, "expected identifier or numeric member name"); 1849 StringRef memberName = memberNameTok.getSpelling(); 1850 consumeToken(); 1851 1852 return createMemberAccessExpr(parentExpr, memberName, loc); 1853 } 1854 1855 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1856 SMRange loc = curToken.getLoc(); 1857 1858 // Check for code completion for the dialect name. 
1859 if (curToken.is(Token::code_complete)) 1860 return codeCompleteDialectName(); 1861 1862 // Handle the case of an no operation name. 1863 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1864 if (allowEmptyName) 1865 return ast::OpNameDecl::create(ctx, SMRange()); 1866 return emitError("expected dialect namespace"); 1867 } 1868 StringRef name = curToken.getSpelling(); 1869 consumeToken(); 1870 1871 // Otherwise, this is a literal operation name. 1872 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1873 return failure(); 1874 1875 // Check for code completion for the operation name. 1876 if (curToken.is(Token::code_complete)) 1877 return codeCompleteOperationName(name); 1878 1879 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1880 return emitError("expected operation name after dialect namespace"); 1881 1882 name = StringRef(name.data(), name.size() + 1); 1883 do { 1884 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1885 loc.End = curToken.getEndLoc(); 1886 consumeToken(); 1887 } while (curToken.isAny(Token::identifier, Token::dot) || 1888 curToken.isKeyword()); 1889 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1890 } 1891 1892 FailureOr<ast::OpNameDecl *> 1893 Parser::parseWrappedOperationName(bool allowEmptyName) { 1894 if (!consumeIf(Token::less)) 1895 return ast::OpNameDecl::create(ctx, SMRange()); 1896 1897 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1898 if (failed(opNameDecl)) 1899 return failure(); 1900 1901 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1902 return failure(); 1903 return opNameDecl; 1904 } 1905 1906 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1907 SMRange loc = curToken.getLoc(); 1908 consumeToken(Token::kw_op); 1909 1910 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1911 // identifier. 
1912 if (curToken.isNot(Token::less)) { 1913 resetToken(loc); 1914 return parseIdentifierExpr(); 1915 } 1916 1917 // Parse the operation name. The name may be elided, in which case the 1918 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1919 // `Operation*`). Operation names within a rewrite context must be named. 1920 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1921 FailureOr<ast::OpNameDecl *> opNameDecl = 1922 parseWrappedOperationName(allowEmptyName); 1923 if (failed(opNameDecl)) 1924 return failure(); 1925 Optional<StringRef> opName = (*opNameDecl)->getName(); 1926 1927 // Functor used to create an implicit range variable, used for implicit "all" 1928 // operand or results variables. 1929 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1930 FailureOr<ast::VariableDecl *> rangeVar = 1931 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1932 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1933 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1934 }; 1935 1936 // Check for the optional list of operands. 1937 SmallVector<ast::Expr *> operands; 1938 if (!consumeIf(Token::l_paren)) { 1939 // If the operand list isn't specified and we are in a match context, define 1940 // an inplace unconstrained operand range corresponding to all of the 1941 // operands of the operation. This avoids treating zero operands the same 1942 // way as "unconstrained operands". 1943 if (parserContext != ParserContext::Rewrite) { 1944 operands.push_back(createImplicitRangeVar( 1945 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1946 } 1947 } else if (!consumeIf(Token::r_paren)) { 1948 // Check for operand signature code completion. 
1949 if (curToken.is(Token::code_complete)) { 1950 codeCompleteOperationOperandsSignature(opName, operands.size()); 1951 return failure(); 1952 } 1953 1954 // If the operand list was specified and non-empty, parse the operands. 1955 do { 1956 FailureOr<ast::Expr *> operand = parseExpr(); 1957 if (failed(operand)) 1958 return failure(); 1959 operands.push_back(*operand); 1960 } while (consumeIf(Token::comma)); 1961 1962 if (failed(parseToken(Token::r_paren, 1963 "expected `)` after operation operand list"))) 1964 return failure(); 1965 } 1966 1967 // Check for the optional list of attributes. 1968 SmallVector<ast::NamedAttributeDecl *> attributes; 1969 if (consumeIf(Token::l_brace)) { 1970 do { 1971 FailureOr<ast::NamedAttributeDecl *> decl = 1972 parseNamedAttributeDecl(opName); 1973 if (failed(decl)) 1974 return failure(); 1975 attributes.emplace_back(*decl); 1976 } while (consumeIf(Token::comma)); 1977 1978 if (failed(parseToken(Token::r_brace, 1979 "expected `}` after operation attribute list"))) 1980 return failure(); 1981 } 1982 1983 // Check for the optional list of result types. 1984 SmallVector<ast::Expr *> resultTypes; 1985 if (consumeIf(Token::arrow)) { 1986 if (failed(parseToken(Token::l_paren, 1987 "expected `(` before operation result type list"))) 1988 return failure(); 1989 1990 // Handle the case of an empty result list. 1991 if (!consumeIf(Token::r_paren)) { 1992 do { 1993 // Check for result signature code completion. 
1994 if (curToken.is(Token::code_complete)) { 1995 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 1996 return failure(); 1997 } 1998 1999 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2000 if (failed(resultTypeExpr)) 2001 return failure(); 2002 resultTypes.push_back(*resultTypeExpr); 2003 } while (consumeIf(Token::comma)); 2004 2005 if (failed(parseToken(Token::r_paren, 2006 "expected `)` after operation result type list"))) 2007 return failure(); 2008 } 2009 } else if (parserContext != ParserContext::Rewrite) { 2010 // If the result list isn't specified and we are in a match context, define 2011 // an inplace unconstrained result range corresponding to all of the results 2012 // of the operation. This avoids treating zero results the same way as 2013 // "unconstrained results". 2014 resultTypes.push_back(createImplicitRangeVar( 2015 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2016 } 2017 2018 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2019 resultTypes); 2020 } 2021 2022 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2023 SMRange loc = curToken.getLoc(); 2024 consumeToken(Token::l_paren); 2025 2026 DenseMap<StringRef, SMRange> usedNames; 2027 SmallVector<StringRef> elementNames; 2028 SmallVector<ast::Expr *> elements; 2029 if (curToken.isNot(Token::r_paren)) { 2030 do { 2031 // Check for the optional element name assignment before the value. 2032 StringRef elementName; 2033 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2034 Token elementNameTok = curToken; 2035 consumeToken(); 2036 2037 // The element name is only present if followed by an `=`. 2038 if (consumeIf(Token::equal)) { 2039 elementName = elementNameTok.getSpelling(); 2040 2041 // Check to see if this name is already used. 
2042 auto elementNameIt = 2043 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2044 if (!elementNameIt.second) { 2045 return emitErrorAndNote( 2046 elementNameTok.getLoc(), 2047 llvm::formatv("duplicate tuple element label `{0}`", 2048 elementName), 2049 elementNameIt.first->getSecond(), 2050 "see previous label use here"); 2051 } 2052 } else { 2053 // Otherwise, we treat this as part of an expression so reset the 2054 // lexer. 2055 resetToken(elementNameTok.getLoc()); 2056 } 2057 } 2058 elementNames.push_back(elementName); 2059 2060 // Parse the tuple element value. 2061 FailureOr<ast::Expr *> element = parseExpr(); 2062 if (failed(element)) 2063 return failure(); 2064 elements.push_back(*element); 2065 } while (consumeIf(Token::comma)); 2066 } 2067 loc.End = curToken.getEndLoc(); 2068 if (failed( 2069 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2070 return failure(); 2071 return createTupleExpr(loc, elements, elementNames); 2072 } 2073 2074 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2075 SMRange loc = curToken.getLoc(); 2076 consumeToken(Token::kw_type); 2077 2078 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2079 // identifier. 2080 if (!consumeIf(Token::less)) { 2081 resetToken(loc); 2082 return parseIdentifierExpr(); 2083 } 2084 2085 if (!curToken.isString()) 2086 return emitError("expected string literal containing MLIR type"); 2087 std::string attrExpr = curToken.getStringValue(); 2088 consumeToken(); 2089 2090 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2091 return failure(); 2092 return ast::TypeExpr::create(ctx, loc, attrExpr); 2093 } 2094 2095 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2096 StringRef name = curToken.getSpelling(); 2097 SMRange nameLoc = curToken.getLoc(); 2098 consumeToken(Token::underscore); 2099 2100 // Underscore expressions require a constraint list. 
2101 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2102 return failure(); 2103 2104 // Parse the constraints for the expression. 2105 SmallVector<ast::ConstraintRef> constraints; 2106 if (failed(parseVariableDeclConstraintList(constraints))) 2107 return failure(); 2108 2109 ast::Type type; 2110 if (failed(validateVariableConstraints(constraints, type))) 2111 return failure(); 2112 return createInlineVariableExpr(type, name, nameLoc, constraints); 2113 } 2114 2115 //===----------------------------------------------------------------------===// 2116 // Stmts 2117 2118 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2119 FailureOr<ast::Stmt *> stmt; 2120 switch (curToken.getKind()) { 2121 case Token::kw_erase: 2122 stmt = parseEraseStmt(); 2123 break; 2124 case Token::kw_let: 2125 stmt = parseLetStmt(); 2126 break; 2127 case Token::kw_replace: 2128 stmt = parseReplaceStmt(); 2129 break; 2130 case Token::kw_return: 2131 stmt = parseReturnStmt(); 2132 break; 2133 case Token::kw_rewrite: 2134 stmt = parseRewriteStmt(); 2135 break; 2136 default: 2137 stmt = parseExpr(); 2138 break; 2139 } 2140 if (failed(stmt) || 2141 (expectTerminalSemicolon && 2142 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2143 return failure(); 2144 return stmt; 2145 } 2146 2147 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2148 SMLoc startLoc = curToken.getStartLoc(); 2149 consumeToken(Token::l_brace); 2150 2151 // Push a new block scope and parse any nested statements. 2152 pushDeclScope(); 2153 SmallVector<ast::Stmt *> statements; 2154 while (curToken.isNot(Token::r_brace)) { 2155 FailureOr<ast::Stmt *> statement = parseStmt(); 2156 if (failed(statement)) 2157 return popDeclScope(), failure(); 2158 statements.push_back(*statement); 2159 } 2160 popDeclScope(); 2161 2162 // Consume the end brace. 
2163 SMRange location(startLoc, curToken.getEndLoc()); 2164 consumeToken(Token::r_brace); 2165 2166 return ast::CompoundStmt::create(ctx, location, statements); 2167 } 2168 2169 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2170 if (parserContext == ParserContext::Constraint) 2171 return emitError("`erase` cannot be used within a Constraint"); 2172 SMRange loc = curToken.getLoc(); 2173 consumeToken(Token::kw_erase); 2174 2175 // Parse the root operation expression. 2176 FailureOr<ast::Expr *> rootOp = parseExpr(); 2177 if (failed(rootOp)) 2178 return failure(); 2179 2180 return createEraseStmt(loc, *rootOp); 2181 } 2182 2183 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2184 SMRange loc = curToken.getLoc(); 2185 consumeToken(Token::kw_let); 2186 2187 // Parse the name of the new variable. 2188 SMRange varLoc = curToken.getLoc(); 2189 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2190 // `_` is a reserved variable name. 2191 if (curToken.is(Token::underscore)) { 2192 return emitError(varLoc, 2193 "`_` may only be used to define \"inline\" variables"); 2194 } 2195 return emitError(varLoc, 2196 "expected identifier after `let` to name a new variable"); 2197 } 2198 StringRef varName = curToken.getSpelling(); 2199 consumeToken(); 2200 2201 // Parse the optional set of constraints. 2202 SmallVector<ast::ConstraintRef> constraints; 2203 if (consumeIf(Token::colon) && 2204 failed(parseVariableDeclConstraintList(constraints))) 2205 return failure(); 2206 2207 // Parse the optional initializer expression. 2208 ast::Expr *initializer = nullptr; 2209 if (consumeIf(Token::equal)) { 2210 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2211 if (failed(initOrFailure)) 2212 return failure(); 2213 initializer = *initOrFailure; 2214 2215 // Check that the constraints are compatible with having an initializer, 2216 // e.g. type constraints cannot be used with initializers. 
2217 for (ast::ConstraintRef constraint : constraints) { 2218 LogicalResult result = 2219 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2220 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2221 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2222 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2223 return this->emitError( 2224 constraint.referenceLoc, 2225 "type constraints are not permitted on variables with " 2226 "initializers"); 2227 } 2228 return success(); 2229 }) 2230 .Default(success()); 2231 if (failed(result)) 2232 return failure(); 2233 } 2234 } 2235 2236 FailureOr<ast::VariableDecl *> varDecl = 2237 createVariableDecl(varName, varLoc, initializer, constraints); 2238 if (failed(varDecl)) 2239 return failure(); 2240 return ast::LetStmt::create(ctx, loc, *varDecl); 2241 } 2242 2243 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2244 if (parserContext == ParserContext::Constraint) 2245 return emitError("`replace` cannot be used within a Constraint"); 2246 SMRange loc = curToken.getLoc(); 2247 consumeToken(Token::kw_replace); 2248 2249 // Parse the root operation expression. 2250 FailureOr<ast::Expr *> rootOp = parseExpr(); 2251 if (failed(rootOp)) 2252 return failure(); 2253 2254 if (failed( 2255 parseToken(Token::kw_with, "expected `with` after root operation"))) 2256 return failure(); 2257 2258 // The replacement portion of this statement is within a rewrite context. 2259 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2260 ParserContext::Rewrite); 2261 2262 // Parse the replacement values. 
2263 SmallVector<ast::Expr *> replValues; 2264 if (consumeIf(Token::l_paren)) { 2265 if (consumeIf(Token::r_paren)) { 2266 return emitError( 2267 loc, "expected at least one replacement value, consider using " 2268 "`erase` if no replacement values are desired"); 2269 } 2270 2271 do { 2272 FailureOr<ast::Expr *> replExpr = parseExpr(); 2273 if (failed(replExpr)) 2274 return failure(); 2275 replValues.emplace_back(*replExpr); 2276 } while (consumeIf(Token::comma)); 2277 2278 if (failed(parseToken(Token::r_paren, 2279 "expected `)` after replacement values"))) 2280 return failure(); 2281 } else { 2282 FailureOr<ast::Expr *> replExpr = parseExpr(); 2283 if (failed(replExpr)) 2284 return failure(); 2285 replValues.emplace_back(*replExpr); 2286 } 2287 2288 return createReplaceStmt(loc, *rootOp, replValues); 2289 } 2290 2291 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2292 SMRange loc = curToken.getLoc(); 2293 consumeToken(Token::kw_return); 2294 2295 // Parse the result value. 2296 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2297 if (failed(resultExpr)) 2298 return failure(); 2299 2300 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2301 } 2302 2303 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2304 if (parserContext == ParserContext::Constraint) 2305 return emitError("`rewrite` cannot be used within a Constraint"); 2306 SMRange loc = curToken.getLoc(); 2307 consumeToken(Token::kw_rewrite); 2308 2309 // Parse the root operation. 2310 FailureOr<ast::Expr *> rootOp = parseExpr(); 2311 if (failed(rootOp)) 2312 return failure(); 2313 2314 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2315 return failure(); 2316 2317 if (curToken.isNot(Token::l_brace)) 2318 return emitError("expected `{` to start rewrite body"); 2319 2320 // The rewrite body of this statement is within a rewrite context. 
2321 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2322 ParserContext::Rewrite); 2323 2324 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2325 if (failed(rewriteBody)) 2326 return failure(); 2327 2328 // Verify the rewrite body. 2329 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2330 if (isa<ast::ReturnStmt>(stmt)) { 2331 return emitError(stmt->getLoc(), 2332 "`return` statements are only permitted within a " 2333 "`Constraint` or `Rewrite` body"); 2334 } 2335 } 2336 2337 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2338 } 2339 2340 //===----------------------------------------------------------------------===// 2341 // Creation+Analysis 2342 //===----------------------------------------------------------------------===// 2343 2344 //===----------------------------------------------------------------------===// 2345 // Decls 2346 2347 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2348 // Unwrap reference expressions. 2349 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2350 node = init->getDecl(); 2351 return dyn_cast<ast::CallableDecl>(node); 2352 } 2353 2354 FailureOr<ast::PatternDecl *> 2355 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2356 const ParsedPatternMetadata &metadata, 2357 ast::CompoundStmt *body) { 2358 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2359 metadata.hasBoundedRecursion, body); 2360 } 2361 2362 ast::Type Parser::createUserConstraintRewriteResultType( 2363 ArrayRef<ast::VariableDecl *> results) { 2364 // Single result decls use the type of the single result. 2365 if (results.size() == 1) 2366 return results[0]->getType(); 2367 2368 // Multiple results use a tuple type, with the types and names grabbed from 2369 // the result variable decls. 
2370 auto resultTypes = llvm::map_range( 2371 results, [&](const auto *result) { return result->getType(); }); 2372 auto resultNames = llvm::map_range( 2373 results, [&](const auto *result) { return result->getName().getName(); }); 2374 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2375 llvm::to_vector(resultNames)); 2376 } 2377 2378 template <typename T> 2379 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2380 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2381 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2382 ast::CompoundStmt *body) { 2383 if (!body->getChildren().empty()) { 2384 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2385 ast::Expr *resultExpr = retStmt->getResultExpr(); 2386 2387 // Process the result of the decl. If no explicit signature results 2388 // were provided, check for return type inference. Otherwise, check that 2389 // the return expression can be converted to the expected type. 2390 if (results.empty()) 2391 resultType = resultExpr->getType(); 2392 else if (failed(convertExpressionTo(resultExpr, resultType))) 2393 return failure(); 2394 else 2395 retStmt->setResultExpr(resultExpr); 2396 } 2397 } 2398 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2399 } 2400 2401 FailureOr<ast::VariableDecl *> 2402 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2403 ArrayRef<ast::ConstraintRef> constraints) { 2404 // The type of the variable, which is expected to be inferred by either a 2405 // constraint or an initializer expression. 2406 ast::Type type; 2407 if (failed(validateVariableConstraints(constraints, type))) 2408 return failure(); 2409 2410 if (initializer) { 2411 // Update the variable type based on the initializer, or try to convert the 2412 // initializer to the existing type. 
2413 if (!type) 2414 type = initializer->getType(); 2415 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2416 type = mergedType; 2417 else if (failed(convertExpressionTo(initializer, type))) 2418 return failure(); 2419 2420 // Otherwise, if there is no initializer check that the type has already 2421 // been resolved from the constraint list. 2422 } else if (!type) { 2423 return emitErrorAndNote( 2424 loc, "unable to infer type for variable `" + name + "`", loc, 2425 "the type of a variable must be inferable from the constraint " 2426 "list or the initializer"); 2427 } 2428 2429 // Constraint types cannot be used when defining variables. 2430 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2431 return emitError( 2432 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2433 } 2434 2435 // Try to define a variable with the given name. 2436 FailureOr<ast::VariableDecl *> varDecl = 2437 defineVariableDecl(name, loc, type, initializer, constraints); 2438 if (failed(varDecl)) 2439 return failure(); 2440 2441 return *varDecl; 2442 } 2443 2444 FailureOr<ast::VariableDecl *> 2445 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2446 const ast::ConstraintRef &constraint) { 2447 // Constraint arguments may apply more complex constraints via the arguments. 
2448 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2449 ast::Type argType; 2450 if (failed(validateVariableConstraint(constraint, argType, 2451 allowNonCoreConstraints))) 2452 return failure(); 2453 return defineVariableDecl(name, loc, argType, constraint); 2454 } 2455 2456 LogicalResult 2457 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2458 ast::Type &inferredType, 2459 bool allowNonCoreConstraints) { 2460 for (const ast::ConstraintRef &ref : constraints) 2461 if (failed(validateVariableConstraint(ref, inferredType, 2462 allowNonCoreConstraints))) 2463 return failure(); 2464 return success(); 2465 } 2466 2467 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2468 ast::Type &inferredType, 2469 bool allowNonCoreConstraints) { 2470 ast::Type constraintType; 2471 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2472 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2473 if (failed(validateTypeConstraintExpr(typeExpr))) 2474 return failure(); 2475 } 2476 constraintType = ast::AttributeType::get(ctx); 2477 } else if (const auto *cst = 2478 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2479 constraintType = ast::OperationType::get(ctx, cst->getName()); 2480 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2481 constraintType = typeTy; 2482 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2483 constraintType = typeRangeTy; 2484 } else if (const auto *cst = 2485 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2486 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2487 if (failed(validateTypeConstraintExpr(typeExpr))) 2488 return failure(); 2489 } 2490 constraintType = valueTy; 2491 } else if (const auto *cst = 2492 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2493 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2494 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2495 return failure(); 
2496 } 2497 constraintType = valueRangeTy; 2498 } else if (const auto *cst = 2499 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2500 if (!allowNonCoreConstraints) { 2501 return emitError(ref.referenceLoc, 2502 "`Rewrite` arguments and results are only permitted to " 2503 "use core constraints, such as `Attr`, `Op`, `Type`, " 2504 "`TypeRange`, `Value`, `ValueRange`"); 2505 } 2506 2507 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2508 if (inputs.size() != 1) { 2509 return emitErrorAndNote(ref.referenceLoc, 2510 "`Constraint`s applied via a variable constraint " 2511 "list must take a single input, but got " + 2512 Twine(inputs.size()), 2513 cst->getLoc(), 2514 "see definition of constraint here"); 2515 } 2516 constraintType = inputs.front()->getType(); 2517 } else { 2518 llvm_unreachable("unknown constraint type"); 2519 } 2520 2521 // Check that the constraint type is compatible with the current inferred 2522 // type. 2523 if (!inferredType) { 2524 inferredType = constraintType; 2525 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2526 inferredType = mergedTy; 2527 } else { 2528 return emitError(ref.referenceLoc, 2529 llvm::formatv("constraint type `{0}` is incompatible " 2530 "with the previously inferred type `{1}`", 2531 constraintType, inferredType)); 2532 } 2533 return success(); 2534 } 2535 2536 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2537 ast::Type typeExprType = typeExpr->getType(); 2538 if (typeExprType != typeTy) { 2539 return emitError(typeExpr->getLoc(), 2540 "expected expression of `Type` in type constraint"); 2541 } 2542 return success(); 2543 } 2544 2545 LogicalResult 2546 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2547 ast::Type typeExprType = typeExpr->getType(); 2548 if (typeExprType != typeRangeTy) { 2549 return emitError(typeExpr->getLoc(), 2550 "expected expression of `TypeRange` in type constraint"); 2551 } 2552 return success(); 2553 } 
2554 2555 //===----------------------------------------------------------------------===// 2556 // Exprs 2557 2558 FailureOr<ast::CallExpr *> 2559 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2560 MutableArrayRef<ast::Expr *> arguments) { 2561 ast::Type parentType = parentExpr->getType(); 2562 2563 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2564 if (!callableDecl) { 2565 return emitError(loc, 2566 llvm::formatv("expected a reference to a callable " 2567 "`Constraint` or `Rewrite`, but got: `{0}`", 2568 parentType)); 2569 } 2570 if (parserContext == ParserContext::Rewrite) { 2571 if (isa<ast::UserConstraintDecl>(callableDecl)) 2572 return emitError( 2573 loc, "unable to invoke `Constraint` within a rewrite section"); 2574 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2575 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2576 } 2577 2578 // Verify the arguments of the call. 2579 /// Handle size mismatch. 2580 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2581 if (callArgs.size() != arguments.size()) { 2582 return emitErrorAndNote( 2583 loc, 2584 llvm::formatv("invalid number of arguments for {0} call; expected " 2585 "{1}, but got {2}", 2586 callableDecl->getCallableType(), callArgs.size(), 2587 arguments.size()), 2588 callableDecl->getLoc(), 2589 llvm::formatv("see the definition of {0} here", 2590 callableDecl->getName()->getName())); 2591 } 2592 2593 /// Handle argument type mismatch. 
2594 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2595 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2596 callableDecl->getName()->getName()), 2597 callableDecl->getLoc()); 2598 }; 2599 for (auto it : llvm::zip(callArgs, arguments)) { 2600 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2601 attachDiagFn))) 2602 return failure(); 2603 } 2604 2605 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2606 callableDecl->getResultType()); 2607 } 2608 2609 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2610 ast::Decl *decl) { 2611 // Check the type of decl being referenced. 2612 ast::Type declType; 2613 if (isa<ast::ConstraintDecl>(decl)) 2614 declType = ast::ConstraintType::get(ctx); 2615 else if (isa<ast::UserRewriteDecl>(decl)) 2616 declType = ast::RewriteType::get(ctx); 2617 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2618 declType = varDecl->getType(); 2619 else 2620 return emitError(loc, "invalid reference to `" + 2621 decl->getName()->getName() + "`"); 2622 2623 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2624 } 2625 2626 FailureOr<ast::DeclRefExpr *> 2627 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2628 ArrayRef<ast::ConstraintRef> constraints) { 2629 FailureOr<ast::VariableDecl *> decl = 2630 defineVariableDecl(name, loc, type, constraints); 2631 if (failed(decl)) 2632 return failure(); 2633 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2634 } 2635 2636 FailureOr<ast::MemberAccessExpr *> 2637 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2638 SMRange loc) { 2639 // Validate the member name for the given parent expression. 
2640 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2641 if (failed(memberType)) 2642 return failure(); 2643 2644 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2645 } 2646 2647 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2648 StringRef name, SMRange loc) { 2649 ast::Type parentType = parentExpr->getType(); 2650 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2651 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2652 return valueRangeTy; 2653 2654 // Verify member access based on the operation type. 2655 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2656 auto results = odsOp->getResults(); 2657 2658 // Handle indexed results. 2659 unsigned index = 0; 2660 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2661 index < results.size()) { 2662 return results[index].isVariadic() ? valueRangeTy : valueTy; 2663 } 2664 2665 // Handle named results. 2666 const auto *it = llvm::find_if(results, [&](const auto &result) { 2667 return result.getName() == name; 2668 }); 2669 if (it != results.end()) 2670 return it->isVariadic() ? valueRangeTy : valueTy; 2671 } 2672 2673 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2674 // Handle indexed results. 2675 unsigned index = 0; 2676 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2677 index < tupleType.size()) { 2678 return tupleType.getElementTypes()[index]; 2679 } 2680 2681 // Handle named results. 
2682 auto elementNames = tupleType.getElementNames(); 2683 const auto *it = llvm::find(elementNames, name); 2684 if (it != elementNames.end()) 2685 return tupleType.getElementTypes()[it - elementNames.begin()]; 2686 } 2687 return emitError( 2688 loc, 2689 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2690 name, parentType)); 2691 } 2692 2693 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2694 SMRange loc, const ast::OpNameDecl *name, 2695 MutableArrayRef<ast::Expr *> operands, 2696 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2697 MutableArrayRef<ast::Expr *> results) { 2698 Optional<StringRef> opNameRef = name->getName(); 2699 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2700 2701 // Verify the inputs operands. 2702 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2703 return failure(); 2704 2705 // Verify the attribute list. 2706 for (ast::NamedAttributeDecl *attr : attributes) { 2707 // Check for an attribute type, or a type awaiting resolution. 2708 ast::Type attrType = attr->getValue()->getType(); 2709 if (!attrType.isa<ast::AttributeType>()) { 2710 return emitError( 2711 attr->getValue()->getLoc(), 2712 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2713 } 2714 } 2715 2716 // Verify the result types. 2717 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2718 return failure(); 2719 2720 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2721 attributes); 2722 } 2723 2724 LogicalResult 2725 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2726 const ods::Operation *odsOp, 2727 MutableArrayRef<ast::Expr *> operands) { 2728 return validateOperationOperandsOrResults( 2729 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2730 operands, odsOp ? 
odsOp->getOperands() : llvm::None, valueTy, 2731 valueRangeTy); 2732 } 2733 2734 LogicalResult 2735 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2736 const ods::Operation *odsOp, 2737 MutableArrayRef<ast::Expr *> results) { 2738 return validateOperationOperandsOrResults( 2739 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2740 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2741 } 2742 2743 LogicalResult Parser::validateOperationOperandsOrResults( 2744 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2745 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2746 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2747 ast::Type rangeTy) { 2748 // All operation types accept a single range parameter. 2749 if (values.size() == 1) { 2750 if (failed(convertExpressionTo(values[0], rangeTy))) 2751 return failure(); 2752 return success(); 2753 } 2754 2755 /// If the operation has ODS information, we can more accurately verify the 2756 /// values. 2757 if (odsOpLoc) { 2758 if (odsValues.size() != values.size()) { 2759 return emitErrorAndNote( 2760 loc, 2761 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2762 "{2}, but got {3}", 2763 groupName, *name, odsValues.size(), values.size()), 2764 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2765 } 2766 auto diagFn = [&](ast::Diagnostic &diag) { 2767 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2768 *odsOpLoc); 2769 }; 2770 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2771 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2772 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2773 return failure(); 2774 } 2775 return success(); 2776 } 2777 2778 // Otherwise, accept the value groups as they have been defined and just 2779 // ensure they are one of the expected types. 
2780 for (ast::Expr *&valueExpr : values) { 2781 ast::Type valueExprType = valueExpr->getType(); 2782 2783 // Check if this is one of the expected types. 2784 if (valueExprType == rangeTy || valueExprType == singleTy) 2785 continue; 2786 2787 // If the operand is an Operation, allow converting to a Value or 2788 // ValueRange. This situations arises quite often with nested operation 2789 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2790 if (singleTy == valueTy) { 2791 if (valueExprType.isa<ast::OperationType>()) { 2792 valueExpr = convertOpToValue(valueExpr); 2793 continue; 2794 } 2795 } 2796 2797 return emitError( 2798 valueExpr->getLoc(), 2799 llvm::formatv( 2800 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2801 singleTy, rangeTy, valueExprType)); 2802 } 2803 return success(); 2804 } 2805 2806 FailureOr<ast::TupleExpr *> 2807 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2808 ArrayRef<StringRef> elementNames) { 2809 for (const ast::Expr *element : elements) { 2810 ast::Type eleTy = element->getType(); 2811 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2812 return emitError( 2813 element->getLoc(), 2814 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2815 } 2816 } 2817 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2818 } 2819 2820 //===----------------------------------------------------------------------===// 2821 // Stmts 2822 2823 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2824 ast::Expr *rootOp) { 2825 // Check that root is an Operation. 
2826 ast::Type rootType = rootOp->getType(); 2827 if (!rootType.isa<ast::OperationType>()) 2828 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2829 2830 return ast::EraseStmt::create(ctx, loc, rootOp); 2831 } 2832 2833 FailureOr<ast::ReplaceStmt *> 2834 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2835 MutableArrayRef<ast::Expr *> replValues) { 2836 // Check that root is an Operation. 2837 ast::Type rootType = rootOp->getType(); 2838 if (!rootType.isa<ast::OperationType>()) { 2839 return emitError( 2840 rootOp->getLoc(), 2841 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2842 } 2843 2844 // If there are multiple replacement values, we implicitly convert any Op 2845 // expressions to the value form. 2846 bool shouldConvertOpToValues = replValues.size() > 1; 2847 for (ast::Expr *&replExpr : replValues) { 2848 ast::Type replType = replExpr->getType(); 2849 2850 // Check that replExpr is an Operation, Value, or ValueRange. 2851 if (replType.isa<ast::OperationType>()) { 2852 if (shouldConvertOpToValues) 2853 replExpr = convertOpToValue(replExpr); 2854 continue; 2855 } 2856 2857 if (replType != valueTy && replType != valueRangeTy) { 2858 return emitError(replExpr->getLoc(), 2859 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2860 "expression, but got `{0}`", 2861 replType)); 2862 } 2863 } 2864 2865 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2866 } 2867 2868 FailureOr<ast::RewriteStmt *> 2869 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2870 ast::CompoundStmt *rewriteBody) { 2871 // Check that root is an Operation. 
2872 ast::Type rootType = rootOp->getType(); 2873 if (!rootType.isa<ast::OperationType>()) { 2874 return emitError( 2875 rootOp->getLoc(), 2876 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2877 } 2878 2879 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2880 } 2881 2882 //===----------------------------------------------------------------------===// 2883 // Code Completion 2884 //===----------------------------------------------------------------------===// 2885 2886 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2887 ast::Type parentType = parentExpr->getType(); 2888 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2889 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2890 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2891 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2892 return failure(); 2893 } 2894 2895 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2896 if (opName) 2897 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2898 return failure(); 2899 } 2900 2901 LogicalResult 2902 Parser::codeCompleteConstraintName(ast::Type inferredType, 2903 bool allowNonCoreConstraints, 2904 bool allowInlineTypeConstraints) { 2905 codeCompleteContext->codeCompleteConstraintName( 2906 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2907 curDeclScope); 2908 return failure(); 2909 } 2910 2911 LogicalResult Parser::codeCompleteDialectName() { 2912 codeCompleteContext->codeCompleteDialectName(); 2913 return failure(); 2914 } 2915 2916 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2917 codeCompleteContext->codeCompleteOperationName(dialectName); 2918 return failure(); 2919 } 2920 2921 LogicalResult Parser::codeCompletePatternMetadata() { 2922 codeCompleteContext->codeCompletePatternMetadata(); 2923 return failure(); 2924 } 2925 2926 void 
Parser::codeCompleteCallSignature(ast::Node *parent, 2927 unsigned currentNumArgs) { 2928 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2929 if (!callableDecl) 2930 return; 2931 2932 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2933 } 2934 2935 void Parser::codeCompleteOperationOperandsSignature( 2936 Optional<StringRef> opName, unsigned currentNumOperands) { 2937 codeCompleteContext->codeCompleteOperationOperandsSignature( 2938 opName, currentNumOperands); 2939 } 2940 2941 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2942 unsigned currentNumResults) { 2943 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2944 currentNumResults); 2945 } 2946 2947 //===----------------------------------------------------------------------===// 2948 // Parser 2949 //===----------------------------------------------------------------------===// 2950 2951 FailureOr<ast::Module *> 2952 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2953 CodeCompleteContext *codeCompleteContext) { 2954 Parser parser(ctx, sourceMgr, codeCompleteContext); 2955 return parser.parseModule(); 2956 } 2957