1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. `typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. `inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 428 429 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 430 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 431 unsigned currentNumOperands); 432 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 433 unsigned currentNumResults); 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 /// If the current token has the specified kind, consume it and return true. 440 /// If not, return false. 441 bool consumeIf(Token::Kind kind) { 442 if (curToken.isNot(kind)) 443 return false; 444 consumeToken(kind); 445 return true; 446 } 447 448 /// Advance the current lexer onto the next token. 449 void consumeToken() { 450 assert(curToken.isNot(Token::eof, Token::error) && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 /// Advance the current lexer onto the next token, asserting what the expected 456 /// current token is. This is preferred to the above method because it leads 457 /// to more self-documenting code with better checking. 458 void consumeToken(Token::Kind kind) { 459 assert(curToken.is(kind) && "consumed an unexpected token"); 460 consumeToken(); 461 } 462 463 /// Reset the lexer to the location at the given position. 464 void resetToken(SMRange tokLoc) { 465 lexer.resetPointer(tokLoc.Start.getPointer()); 466 curToken = lexer.lexToken(); 467 } 468 469 /// Consume the specified token if present and return success. On failure, 470 /// output a diagnostic and return failure. 471 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 472 if (curToken.getKind() != kind) 473 return emitError(curToken.getLoc(), msg); 474 consumeToken(); 475 return success(); 476 } 477 LogicalResult emitError(SMRange loc, const Twine &msg) { 478 lexer.emitError(loc, msg); 479 return failure(); 480 } 481 LogicalResult emitError(const Twine &msg) { 482 return emitError(curToken.getLoc(), msg); 483 } 484 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 485 const Twine ¬e) { 486 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 487 return failure(); 488 } 489 490 //===--------------------------------------------------------------------===// 491 // Fields 492 //===--------------------------------------------------------------------===// 493 494 /// The owning AST context. 495 ast::Context &ctx; 496 497 /// The lexer of this parser. 498 Lexer lexer; 499 500 /// The current token within the lexer. 501 Token curToken; 502 503 /// The most recently defined decl scope. 504 ast::DeclScope *curDeclScope = nullptr; 505 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 506 507 /// The current context of the parser. 508 ParserContext parserContext = ParserContext::Global; 509 510 /// Cached types to simplify verification and expression creation. 511 ast::Type valueTy, valueRangeTy; 512 ast::Type typeTy, typeRangeTy; 513 ast::Type attrTy; 514 515 /// A counter used when naming anonymous constraints and rewrites. 516 unsigned anonymousDeclNameCounter = 0; 517 518 /// The optional code completion context. 519 CodeCompleteContext *codeCompleteContext; 520 }; 521 } // namespace 522 523 FailureOr<ast::Module *> Parser::parseModule() { 524 SMLoc moduleLoc = curToken.getStartLoc(); 525 pushDeclScope(); 526 527 // Parse the top-level decls of the module. 528 SmallVector<ast::Decl *> decls; 529 if (failed(parseModuleBody(decls))) 530 return popDeclScope(), failure(); 531 532 popDeclScope(); 533 return ast::Module::create(ctx, moduleLoc, decls); 534 } 535 536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 537 while (curToken.isNot(Token::eof)) { 538 if (curToken.is(Token::directive)) { 539 if (failed(parseDirective(decls))) 540 return failure(); 541 continue; 542 } 543 544 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 545 if (failed(decl)) 546 return failure(); 547 decls.push_back(*decl); 548 } 549 return success(); 550 } 551 552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 553 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 554 valueRangeTy); 555 } 556 557 LogicalResult Parser::convertExpressionTo( 558 ast::Expr *&expr, ast::Type type, 559 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 560 ast::Type exprType = expr->getType(); 561 if (exprType == type) 562 return success(); 563 564 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 565 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 566 expr->getLoc(), llvm::formatv("unable to convert expression of type " 567 "`{0}` to the expected type of " 568 "`{1}`", 569 exprType, type)); 570 if (noteAttachFn) 571 noteAttachFn(*diag); 572 return diag; 573 }; 574 575 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 576 // Two operation types are compatible if they have the same name, or if the 577 // expected type is more general. 578 if (auto opType = type.dyn_cast<ast::OperationType>()) { 579 if (opType.getName()) 580 return emitConvertError(); 581 return success(); 582 } 583 584 // An operation can always convert to a ValueRange. 585 if (type == valueRangeTy) { 586 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 587 valueRangeTy); 588 return success(); 589 } 590 591 // Allow conversion to a single value by constraining the result range. 592 if (type == valueTy) { 593 // If the operation is registered, we can verify if it can ever have a 594 // single result. 595 Optional<StringRef> opName = exprOpType.getName(); 596 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 597 if (odsOp->getResults().empty()) { 598 return emitConvertError()->attachNote( 599 llvm::formatv("see the definition of `{0}`, which was defined " 600 "with zero results", 601 odsOp->getName()), 602 odsOp->getLoc()); 603 } 604 605 unsigned numSingleResults = llvm::count_if( 606 odsOp->getResults(), [](const ods::OperandOrResult &result) { 607 return result.getVariableLengthKind() == 608 ods::VariableLengthKind::Single; 609 }); 610 if (numSingleResults > 1) { 611 return emitConvertError()->attachNote( 612 llvm::formatv("see the definition of `{0}`, which was defined " 613 "with at least {1} results", 614 odsOp->getName(), numSingleResults), 615 odsOp->getLoc()); 616 } 617 } 618 619 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 620 valueTy); 621 return success(); 622 } 623 return emitConvertError(); 624 } 625 626 // FIXME: Decide how to allow/support converting a single result to multiple, 627 // and multiple to a single result. For now, we just allow Single->Range, 628 // but this isn't something really supported in the PDL dialect. We should 629 // figure out some way to support both. 630 if ((exprType == valueTy || exprType == valueRangeTy) && 631 (type == valueTy || type == valueRangeTy)) 632 return success(); 633 if ((exprType == typeTy || exprType == typeRangeTy) && 634 (type == typeTy || type == typeRangeTy)) 635 return success(); 636 637 // Handle tuple types. 638 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 639 auto tupleType = type.dyn_cast<ast::TupleType>(); 640 if (!tupleType || tupleType.size() != exprTupleType.size()) 641 return emitConvertError(); 642 643 // Build a new tuple expression using each of the elements of the current 644 // tuple. 645 SmallVector<ast::Expr *> newExprs; 646 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 647 newExprs.push_back(ast::MemberAccessExpr::create( 648 ctx, expr->getLoc(), expr, llvm::to_string(i), 649 exprTupleType.getElementTypes()[i])); 650 651 auto diagFn = [&](ast::Diagnostic &diag) { 652 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 653 i, exprTupleType)); 654 if (noteAttachFn) 655 noteAttachFn(diag); 656 }; 657 if (failed(convertExpressionTo(newExprs.back(), 658 tupleType.getElementTypes()[i], diagFn))) 659 return failure(); 660 } 661 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 662 tupleType.getElementNames()); 663 return success(); 664 } 665 666 return emitConvertError(); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Directives 671 672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 673 StringRef directive = curToken.getSpelling(); 674 if (directive == "#include") 675 return parseInclude(decls); 676 677 return emitError("unknown directive `" + directive + "`"); 678 } 679 680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 681 SMRange loc = curToken.getLoc(); 682 consumeToken(Token::directive); 683 684 // Handle code completion of the include file path. 685 if (curToken.is(Token::code_complete_string)) 686 return codeCompleteIncludeFilename(curToken.getStringValue()); 687 688 // Parse the file being included. 689 if (!curToken.isString()) 690 return emitError(loc, 691 "expected string file name after `include` directive"); 692 SMRange fileLoc = curToken.getLoc(); 693 std::string filenameStr = curToken.getStringValue(); 694 StringRef filename = filenameStr; 695 consumeToken(); 696 697 // Check the type of include. If ending with `.pdll`, this is another pdl file 698 // to be parsed along with the current module. 699 if (filename.endswith(".pdll")) { 700 if (failed(lexer.pushInclude(filename, fileLoc))) 701 return emitError(fileLoc, 702 "unable to open include file `" + filename + "`"); 703 704 // If we added the include successfully, parse it into the current module. 705 // Make sure to update to the next token after we finish parsing the nested 706 // file. 707 curToken = lexer.lexToken(); 708 LogicalResult result = parseModuleBody(decls); 709 curToken = lexer.lexToken(); 710 return result; 711 } 712 713 // Otherwise, this must be a `.td` include. 714 if (filename.endswith(".td")) 715 return parseTdInclude(filename, fileLoc, decls); 716 717 return emitError(fileLoc, 718 "expected include filename to end with `.pdll` or `.td`"); 719 } 720 721 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 722 SmallVectorImpl<ast::Decl *> &decls) { 723 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 724 725 // This class provides a context argument for the llvm::SourceMgr diagnostic 726 // handler. 727 struct DiagHandlerContext { 728 Parser &parser; 729 StringRef filename; 730 llvm::SMRange loc; 731 } handlerContext{*this, filename, fileLoc}; 732 733 // Set the diagnostic handler for the tablegen source manager. 734 llvm::SrcMgr.setDiagHandler( 735 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 736 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 737 (void)ctx->parser.emitError( 738 ctx->loc, 739 llvm::formatv("error while processing include file `{0}`: {1}", 740 ctx->filename, diag.getMessage())); 741 }, 742 &handlerContext); 743 744 // Use the source manager to open the file, but don't yet add it. 745 std::string includedFile; 746 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 747 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 748 if (!includeBuffer) 749 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 750 751 auto processFn = [&](llvm::RecordKeeper &records) { 752 processTdIncludeRecords(records, decls); 753 754 // After we are done processing, move all of the tablegen source buffers to 755 // the main parser source mgr. This allows for directly using source 756 // locations from the .td files without needing to remap them. 757 parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr, fileLoc.End); 758 return false; 759 }; 760 if (llvm::TableGenParseFile(std::move(*includeBuffer), 761 parserSrcMgr.getIncludeDirs(), processFn)) 762 return failure(); 763 764 return success(); 765 } 766 767 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 768 SmallVectorImpl<ast::Decl *> &decls) { 769 // Return the length kind of the given value. 770 auto getLengthKind = [](const auto &value) { 771 if (value.isOptional()) 772 return ods::VariableLengthKind::Optional; 773 return value.isVariadic() ? ods::VariableLengthKind::Variadic 774 : ods::VariableLengthKind::Single; 775 }; 776 777 // Insert a type constraint into the ODS context. 778 ods::Context &odsContext = ctx.getODSContext(); 779 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 780 -> const ods::TypeConstraint & { 781 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 782 cst.constraint.getSummary(), 783 cst.constraint.getCPPClassName()); 784 }; 785 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 786 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 787 }; 788 789 // Process the parsed tablegen records to build ODS information. 790 /// Operations. 791 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 792 tblgen::Operator op(def); 793 794 bool inserted = false; 795 ods::Operation *odsOp = nullptr; 796 std::tie(odsOp, inserted) = 797 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 798 op.getDescription(), op.getLoc().front()); 799 800 // Ignore operations that have already been added. 801 if (!inserted) 802 continue; 803 804 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 805 odsOp->appendAttribute( 806 attr.name, attr.attr.isOptional(), 807 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 808 attr.attr.getSummary(), 809 attr.attr.getStorageType())); 810 } 811 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 812 odsOp->appendOperand(operand.name, getLengthKind(operand), 813 addTypeConstraint(operand)); 814 } 815 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 816 odsOp->appendResult(result.name, getLengthKind(result), 817 addTypeConstraint(result)); 818 } 819 } 820 /// Attr constraints. 821 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 822 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 823 decls.push_back( 824 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 825 tblgen::AttrConstraint(def), 826 convertLocToRange(def->getLoc().front()), attrTy)); 827 } 828 } 829 /// Type constraints. 830 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 831 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 832 decls.push_back( 833 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 834 tblgen::TypeConstraint(def), 835 convertLocToRange(def->getLoc().front()), typeTy)); 836 } 837 } 838 /// Interfaces. 839 ast::Type opTy = ast::OperationType::get(ctx); 840 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 841 StringRef name = def->getName(); 842 if (def->isAnonymous() || curDeclScope->lookup(name) || 843 def->isSubClassOf("DeclareInterfaceMethods")) 844 continue; 845 SMRange loc = convertLocToRange(def->getLoc().front()); 846 847 StringRef className = def->getValueAsString("cppClassName"); 848 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 849 std::string codeBlock = 850 llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) 851 .str(); 852 853 if (def->isSubClassOf("OpInterface")) { 854 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 855 name, codeBlock, loc, opTy)); 856 } else if (def->isSubClassOf("AttrInterface")) { 857 decls.push_back( 858 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 859 name, codeBlock, loc, attrTy)); 860 } else if (def->isSubClassOf("TypeInterface")) { 861 decls.push_back( 862 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 863 name, codeBlock, loc, typeTy)); 864 } 865 } 866 } 867 868 template <typename ConstraintT> 869 ast::Decl * 870 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 871 SMRange loc, ast::Type type) { 872 // Build the single input parameter. 873 ast::DeclScope *argScope = pushDeclScope(); 874 auto *paramVar = ast::VariableDecl::create( 875 ctx, ast::Name::create(ctx, "self", loc), type, 876 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 877 argScope->add(paramVar); 878 popDeclScope(); 879 880 // Build the native constraint. 881 auto *constraintDecl = ast::UserConstraintDecl::createNative( 882 ctx, ast::Name::create(ctx, name, loc), paramVar, 883 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 884 curDeclScope->add(constraintDecl); 885 return constraintDecl; 886 } 887 888 template <typename ConstraintT> 889 ast::Decl * 890 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 891 SMRange loc, ast::Type type) { 892 // Format the condition template. 893 tblgen::FmtContext fmtContext; 894 fmtContext.withSelf("self"); 895 std::string codeBlock = 896 tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); 897 898 return createODSNativePDLLConstraintDecl<ConstraintT>( 899 constraint.getUniqueDefName(), codeBlock, loc, type); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // Decls 904 905 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 906 FailureOr<ast::Decl *> decl; 907 switch (curToken.getKind()) { 908 case Token::kw_Constraint: 909 decl = parseUserConstraintDecl(); 910 break; 911 case Token::kw_Pattern: 912 decl = parsePatternDecl(); 913 break; 914 case Token::kw_Rewrite: 915 decl = parseUserRewriteDecl(); 916 break; 917 default: 918 return emitError("expected top-level declaration, such as a `Pattern`"); 919 } 920 if (failed(decl)) 921 return failure(); 922 923 // If the decl has a name, add it to the current scope. 924 if (const ast::Name *name = (*decl)->getName()) { 925 if (failed(checkDefineNamedDecl(*name))) 926 return failure(); 927 curDeclScope->add(*decl); 928 } 929 return decl; 930 } 931 932 FailureOr<ast::NamedAttributeDecl *> 933 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 934 // Check for name code completion. 935 if (curToken.is(Token::code_complete)) 936 return codeCompleteAttributeName(parentOpName); 937 938 std::string attrNameStr; 939 if (curToken.isString()) 940 attrNameStr = curToken.getStringValue(); 941 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 942 attrNameStr = curToken.getSpelling().str(); 943 else 944 return emitError("expected identifier or string attribute name"); 945 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 946 consumeToken(); 947 948 // Check for a value of the attribute. 949 ast::Expr *attrValue = nullptr; 950 if (consumeIf(Token::equal)) { 951 FailureOr<ast::Expr *> attrExpr = parseExpr(); 952 if (failed(attrExpr)) 953 return failure(); 954 attrValue = *attrExpr; 955 } else { 956 // If there isn't a concrete value, create an expression representing a 957 // UnitAttr. 958 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 959 } 960 961 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 962 } 963 964 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 965 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 966 bool expectTerminalSemicolon) { 967 consumeToken(Token::equal_arrow); 968 969 // Parse the single statement of the lambda body. 970 SMLoc bodyStartLoc = curToken.getStartLoc(); 971 pushDeclScope(); 972 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 973 bool failedToParse = 974 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 975 popDeclScope(); 976 if (failedToParse) 977 return failure(); 978 979 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 980 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 981 } 982 983 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 984 // Ensure that the argument is named. 985 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 986 return emitError("expected identifier argument name"); 987 988 // Parse the argument similarly to a normal variable. 989 StringRef name = curToken.getSpelling(); 990 SMRange nameLoc = curToken.getLoc(); 991 consumeToken(); 992 993 if (failed( 994 parseToken(Token::colon, "expected `:` before argument constraint"))) 995 return failure(); 996 997 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 998 if (failed(cst)) 999 return failure(); 1000 1001 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1002 } 1003 1004 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1005 // Check to see if this result is named. 1006 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1007 // Check to see if this name actually refers to a Constraint. 1008 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1009 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1010 // If yes, and this is a Rewrite, give a nice error message as non-Core 1011 // constraints are not supported on Rewrite results. 1012 if (parserContext == ParserContext::Rewrite) { 1013 return emitError( 1014 "`Rewrite` results are only permitted to use core constraints, " 1015 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1016 } 1017 1018 // Otherwise, parse this as an unnamed result variable. 1019 } else { 1020 // If it wasn't a constraint, parse the result similarly to a variable. If 1021 // there is already an existing decl, we will emit an error when defining 1022 // this variable later. 1023 StringRef name = curToken.getSpelling(); 1024 SMRange nameLoc = curToken.getLoc(); 1025 consumeToken(); 1026 1027 if (failed(parseToken(Token::colon, 1028 "expected `:` before result constraint"))) 1029 return failure(); 1030 1031 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1032 if (failed(cst)) 1033 return failure(); 1034 1035 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1036 } 1037 } 1038 1039 // If it isn't named, we parse the constraint directly and create an unnamed 1040 // result variable. 1041 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1042 if (failed(cst)) 1043 return failure(); 1044 1045 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1046 } 1047 1048 FailureOr<ast::UserConstraintDecl *> 1049 Parser::parseUserConstraintDecl(bool isInline) { 1050 // Constraints and rewrites have very similar formats, dispatch to a shared 1051 // interface for parsing. 1052 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1053 [&](auto &&...args) { 1054 return this->parseUserPDLLConstraintDecl(args...); 1055 }, 1056 ParserContext::Constraint, "constraint", isInline); 1057 } 1058 1059 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1060 FailureOr<ast::UserConstraintDecl *> decl = 1061 parseUserConstraintDecl(/*isInline=*/true); 1062 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1063 return failure(); 1064 1065 curDeclScope->add(*decl); 1066 return decl; 1067 } 1068 1069 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1070 const ast::Name &name, bool isInline, 1071 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1072 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1073 // Push the argument scope back onto the list, so that the body can 1074 // reference arguments. 1075 pushDeclScope(argumentScope); 1076 1077 // Parse the body of the constraint. The body is either defined as a compound 1078 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1079 ast::CompoundStmt *body; 1080 if (curToken.is(Token::equal_arrow)) { 1081 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1082 [&](ast::Stmt *&stmt) -> LogicalResult { 1083 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1084 if (!stmtExpr) { 1085 return emitError(stmt->getLoc(), 1086 "expected `Constraint` lambda body to contain a " 1087 "single expression"); 1088 } 1089 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1090 return success(); 1091 }, 1092 /*expectTerminalSemicolon=*/!isInline); 1093 if (failed(bodyResult)) 1094 return failure(); 1095 body = *bodyResult; 1096 } else { 1097 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1098 if (failed(bodyResult)) 1099 return failure(); 1100 body = *bodyResult; 1101 1102 // Verify the structure of the body. 1103 auto bodyIt = body->begin(), bodyE = body->end(); 1104 for (; bodyIt != bodyE; ++bodyIt) 1105 if (isa<ast::ReturnStmt>(*bodyIt)) 1106 break; 1107 if (failed(validateUserConstraintOrRewriteReturn( 1108 "Constraint", body, bodyIt, bodyE, results, resultType))) 1109 return failure(); 1110 } 1111 popDeclScope(); 1112 1113 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1114 name, arguments, results, resultType, body); 1115 } 1116 1117 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1118 // Constraints and rewrites have very similar formats, dispatch to a shared 1119 // interface for parsing. 1120 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1121 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1122 ParserContext::Rewrite, "rewrite", isInline); 1123 } 1124 1125 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1126 FailureOr<ast::UserRewriteDecl *> decl = 1127 parseUserRewriteDecl(/*isInline=*/true); 1128 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1129 return failure(); 1130 1131 curDeclScope->add(*decl); 1132 return decl; 1133 } 1134 1135 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1136 const ast::Name &name, bool isInline, 1137 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1138 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1139 // Push the argument scope back onto the list, so that the body can 1140 // reference arguments. 1141 curDeclScope = argumentScope; 1142 ast::CompoundStmt *body; 1143 if (curToken.is(Token::equal_arrow)) { 1144 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1145 [&](ast::Stmt *&statement) -> LogicalResult { 1146 if (isa<ast::OpRewriteStmt>(statement)) 1147 return success(); 1148 1149 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1150 if (!statementExpr) { 1151 return emitError( 1152 statement->getLoc(), 1153 "expected `Rewrite` lambda body to contain a single expression " 1154 "or an operation rewrite statement; such as `erase`, " 1155 "`replace`, or `rewrite`"); 1156 } 1157 statement = 1158 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1159 return success(); 1160 }, 1161 /*expectTerminalSemicolon=*/!isInline); 1162 if (failed(bodyResult)) 1163 return failure(); 1164 body = *bodyResult; 1165 } else { 1166 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1167 if (failed(bodyResult)) 1168 return failure(); 1169 body = *bodyResult; 1170 } 1171 popDeclScope(); 1172 1173 // Verify the structure of the body. 1174 auto bodyIt = body->begin(), bodyE = body->end(); 1175 for (; bodyIt != bodyE; ++bodyIt) 1176 if (isa<ast::ReturnStmt>(*bodyIt)) 1177 break; 1178 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1179 bodyE, results, resultType))) 1180 return failure(); 1181 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1182 name, arguments, results, resultType, body); 1183 } 1184 1185 template <typename T, typename ParseUserPDLLDeclFnT> 1186 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1187 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1188 StringRef anonymousNamePrefix, bool isInline) { 1189 SMRange loc = curToken.getLoc(); 1190 consumeToken(); 1191 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1192 1193 // Parse the name of the decl. 1194 const ast::Name *name = nullptr; 1195 if (curToken.isNot(Token::identifier)) { 1196 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1197 // in C++, so being unnamed is fine. 1198 if (!isInline) 1199 return emitError("expected identifier name"); 1200 1201 // Create a unique anonymous name to use, as the name for this decl is not 1202 // important. 1203 std::string anonName = 1204 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1205 anonymousDeclNameCounter++) 1206 .str(); 1207 name = &ast::Name::create(ctx, anonName, loc); 1208 } else { 1209 // If a name was provided, we can use it directly. 1210 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1211 consumeToken(Token::identifier); 1212 } 1213 1214 // Parse the functional signature of the decl. 1215 SmallVector<ast::VariableDecl *> arguments, results; 1216 ast::DeclScope *argumentScope; 1217 ast::Type resultType; 1218 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1219 argumentScope, resultType))) 1220 return failure(); 1221 1222 // Check to see which type of constraint this is. If the constraint contains a 1223 // compound body, this is a PDLL decl. 1224 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1225 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1226 resultType); 1227 1228 // Otherwise, this is a native decl. 1229 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1230 results, resultType); 1231 } 1232 1233 template <typename T> 1234 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1235 const ast::Name &name, bool isInline, 1236 ArrayRef<ast::VariableDecl *> arguments, 1237 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1238 // If followed by a string, the native code body has also been specified. 1239 std::string codeStrStorage; 1240 Optional<StringRef> optCodeStr; 1241 if (curToken.isString()) { 1242 codeStrStorage = curToken.getStringValue(); 1243 optCodeStr = codeStrStorage; 1244 consumeToken(); 1245 } else if (isInline) { 1246 return emitError(name.getLoc(), 1247 "external declarations must be declared in global scope"); 1248 } 1249 if (failed(parseToken(Token::semicolon, 1250 "expected `;` after native declaration"))) 1251 return failure(); 1252 // TODO: PDL should be able to support constraint results in certain 1253 // situations, we should revise this. 1254 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1255 return emitError( 1256 "native Constraints currently do not support returning results"); 1257 } 1258 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1259 } 1260 1261 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1262 SmallVectorImpl<ast::VariableDecl *> &arguments, 1263 SmallVectorImpl<ast::VariableDecl *> &results, 1264 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1265 // Parse the argument list of the decl. 1266 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1267 return failure(); 1268 1269 argumentScope = pushDeclScope(); 1270 if (curToken.isNot(Token::r_paren)) { 1271 do { 1272 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1273 if (failed(argument)) 1274 return failure(); 1275 arguments.emplace_back(*argument); 1276 } while (consumeIf(Token::comma)); 1277 } 1278 popDeclScope(); 1279 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1280 return failure(); 1281 1282 // Parse the results of the decl. 1283 pushDeclScope(); 1284 if (consumeIf(Token::arrow)) { 1285 auto parseResultFn = [&]() -> LogicalResult { 1286 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1287 if (failed(result)) 1288 return failure(); 1289 results.emplace_back(*result); 1290 return success(); 1291 }; 1292 1293 // Check for a list of results. 1294 if (consumeIf(Token::l_paren)) { 1295 do { 1296 if (failed(parseResultFn())) 1297 return failure(); 1298 } while (consumeIf(Token::comma)); 1299 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1300 return failure(); 1301 1302 // Otherwise, there is only one result. 1303 } else if (failed(parseResultFn())) { 1304 return failure(); 1305 } 1306 } 1307 popDeclScope(); 1308 1309 // Compute the result type of the decl. 1310 resultType = createUserConstraintRewriteResultType(results); 1311 1312 // Verify that results are only named if there are more than one. 1313 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1314 return emitError( 1315 results.front()->getLoc(), 1316 "cannot create a single-element tuple with an element label"); 1317 } 1318 return success(); 1319 } 1320 1321 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1322 StringRef declType, ast::CompoundStmt *body, 1323 ArrayRef<ast::Stmt *>::iterator bodyIt, 1324 ArrayRef<ast::Stmt *>::iterator bodyE, 1325 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1326 // Handle if a `return` was provided. 1327 if (bodyIt != bodyE) { 1328 // Emit an error if we have trailing statements after the return. 1329 if (std::next(bodyIt) != bodyE) { 1330 return emitError( 1331 (*std::next(bodyIt))->getLoc(), 1332 llvm::formatv("`return` terminated the `{0}` body, but found " 1333 "trailing statements afterwards", 1334 declType)); 1335 } 1336 1337 // Otherwise if a return wasn't provided, check that no results are 1338 // expected. 1339 } else if (!results.empty()) { 1340 return emitError( 1341 {body->getLoc().End, body->getLoc().End}, 1342 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1343 declType, resultType)); 1344 } 1345 return success(); 1346 } 1347 1348 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1349 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1350 if (isa<ast::OpRewriteStmt>(statement)) 1351 return success(); 1352 return emitError( 1353 statement->getLoc(), 1354 "expected Pattern lambda body to contain a single operation " 1355 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1356 }); 1357 } 1358 1359 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1360 SMRange loc = curToken.getLoc(); 1361 consumeToken(Token::kw_Pattern); 1362 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1363 ParserContext::PatternMatch); 1364 1365 // Check for an optional identifier for the pattern name. 1366 const ast::Name *name = nullptr; 1367 if (curToken.is(Token::identifier)) { 1368 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1369 consumeToken(Token::identifier); 1370 } 1371 1372 // Parse any pattern metadata. 1373 ParsedPatternMetadata metadata; 1374 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1375 return failure(); 1376 1377 // Parse the pattern body. 1378 ast::CompoundStmt *body; 1379 1380 // Handle a lambda body. 1381 if (curToken.is(Token::equal_arrow)) { 1382 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1383 if (failed(bodyResult)) 1384 return failure(); 1385 body = *bodyResult; 1386 } else { 1387 if (curToken.isNot(Token::l_brace)) 1388 return emitError("expected `{` or `=>` to start pattern body"); 1389 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1390 if (failed(bodyResult)) 1391 return failure(); 1392 body = *bodyResult; 1393 1394 // Verify the body of the pattern. 1395 auto bodyIt = body->begin(), bodyE = body->end(); 1396 for (; bodyIt != bodyE; ++bodyIt) { 1397 if (isa<ast::ReturnStmt>(*bodyIt)) { 1398 return emitError((*bodyIt)->getLoc(), 1399 "`return` statements are only permitted within a " 1400 "`Constraint` or `Rewrite` body"); 1401 } 1402 // Break when we've found the rewrite statement. 1403 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1404 break; 1405 } 1406 if (bodyIt == bodyE) { 1407 return emitError(loc, 1408 "expected Pattern body to terminate with an operation " 1409 "rewrite statement, such as `erase`"); 1410 } 1411 if (std::next(bodyIt) != bodyE) { 1412 return emitError((*std::next(bodyIt))->getLoc(), 1413 "Pattern body was terminated by an operation " 1414 "rewrite statement, but found trailing statements"); 1415 } 1416 } 1417 1418 return createPatternDecl(loc, name, metadata, body); 1419 } 1420 1421 LogicalResult 1422 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1423 Optional<SMRange> benefitLoc; 1424 Optional<SMRange> hasBoundedRecursionLoc; 1425 1426 do { 1427 // Handle metadata code completion. 1428 if (curToken.is(Token::code_complete)) 1429 return codeCompletePatternMetadata(); 1430 1431 if (curToken.isNot(Token::identifier)) 1432 return emitError("expected pattern metadata identifier"); 1433 StringRef metadataStr = curToken.getSpelling(); 1434 SMRange metadataLoc = curToken.getLoc(); 1435 consumeToken(Token::identifier); 1436 1437 // Parse the benefit metadata: benefit(<integer-value>) 1438 if (metadataStr == "benefit") { 1439 if (benefitLoc) { 1440 return emitErrorAndNote(metadataLoc, 1441 "pattern benefit has already been specified", 1442 *benefitLoc, "see previous definition here"); 1443 } 1444 if (failed(parseToken(Token::l_paren, 1445 "expected `(` before pattern benefit"))) 1446 return failure(); 1447 1448 uint16_t benefitValue = 0; 1449 if (curToken.isNot(Token::integer)) 1450 return emitError("expected integral pattern benefit"); 1451 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1452 return emitError( 1453 "expected pattern benefit to fit within a 16-bit integer"); 1454 consumeToken(Token::integer); 1455 1456 metadata.benefit = benefitValue; 1457 benefitLoc = metadataLoc; 1458 1459 if (failed( 1460 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1461 return failure(); 1462 continue; 1463 } 1464 1465 // Parse the bounded recursion metadata: recursion 1466 if (metadataStr == "recursion") { 1467 if (hasBoundedRecursionLoc) { 1468 return emitErrorAndNote( 1469 metadataLoc, 1470 "pattern recursion metadata has already been specified", 1471 *hasBoundedRecursionLoc, "see previous definition here"); 1472 } 1473 metadata.hasBoundedRecursion = true; 1474 hasBoundedRecursionLoc = metadataLoc; 1475 continue; 1476 } 1477 1478 return emitError(metadataLoc, "unknown pattern metadata"); 1479 } while (consumeIf(Token::comma)); 1480 1481 return success(); 1482 } 1483 1484 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1485 consumeToken(Token::less); 1486 1487 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1488 if (failed(typeExpr) || 1489 failed(parseToken(Token::greater, 1490 "expected `>` after variable type constraint"))) 1491 return failure(); 1492 return typeExpr; 1493 } 1494 1495 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1496 assert(curDeclScope && "defining decl outside of a decl scope"); 1497 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1498 return emitErrorAndNote( 1499 name.getLoc(), "`" + name.getName() + "` has already been defined", 1500 lastDecl->getName()->getLoc(), "see previous definition here"); 1501 } 1502 return success(); 1503 } 1504 1505 FailureOr<ast::VariableDecl *> 1506 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1507 ast::Expr *initExpr, 1508 ArrayRef<ast::ConstraintRef> constraints) { 1509 assert(curDeclScope && "defining variable outside of decl scope"); 1510 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1511 1512 // If the name of the variable indicates a special variable, we don't add it 1513 // to the scope. This variable is local to the definition point. 1514 if (name.empty() || name == "_") { 1515 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1516 constraints); 1517 } 1518 if (failed(checkDefineNamedDecl(nameDecl))) 1519 return failure(); 1520 1521 auto *varDecl = 1522 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1523 curDeclScope->add(varDecl); 1524 return varDecl; 1525 } 1526 1527 FailureOr<ast::VariableDecl *> 1528 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1529 ArrayRef<ast::ConstraintRef> constraints) { 1530 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1531 constraints); 1532 } 1533 1534 LogicalResult Parser::parseVariableDeclConstraintList( 1535 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1536 Optional<SMRange> typeConstraint; 1537 auto parseSingleConstraint = [&] { 1538 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1539 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1540 /*allowNonCoreConstraints=*/true); 1541 if (failed(constraint)) 1542 return failure(); 1543 constraints.push_back(*constraint); 1544 return success(); 1545 }; 1546 1547 // Check to see if this is a single constraint, or a list. 1548 if (!consumeIf(Token::l_square)) 1549 return parseSingleConstraint(); 1550 1551 do { 1552 if (failed(parseSingleConstraint())) 1553 return failure(); 1554 } while (consumeIf(Token::comma)); 1555 return parseToken(Token::r_square, "expected `]` after constraint list"); 1556 } 1557 1558 FailureOr<ast::ConstraintRef> 1559 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1560 ArrayRef<ast::ConstraintRef> existingConstraints, 1561 bool allowInlineTypeConstraints, 1562 bool allowNonCoreConstraints) { 1563 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1564 if (!allowInlineTypeConstraints) { 1565 return emitError( 1566 curToken.getLoc(), 1567 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1568 "permitted on arguments or results"); 1569 } 1570 if (typeConstraint) 1571 return emitErrorAndNote( 1572 curToken.getLoc(), 1573 "the type of this variable has already been constrained", 1574 *typeConstraint, "see previous constraint location here"); 1575 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1576 if (failed(constraintExpr)) 1577 return failure(); 1578 typeExpr = *constraintExpr; 1579 typeConstraint = typeExpr->getLoc(); 1580 return success(); 1581 }; 1582 1583 SMRange loc = curToken.getLoc(); 1584 switch (curToken.getKind()) { 1585 case Token::kw_Attr: { 1586 consumeToken(Token::kw_Attr); 1587 1588 // Check for a type constraint. 1589 ast::Expr *typeExpr = nullptr; 1590 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1591 return failure(); 1592 return ast::ConstraintRef( 1593 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1594 } 1595 case Token::kw_Op: { 1596 consumeToken(Token::kw_Op); 1597 1598 // Parse an optional operation name. If the name isn't provided, this refers 1599 // to "any" operation. 1600 FailureOr<ast::OpNameDecl *> opName = 1601 parseWrappedOperationName(/*allowEmptyName=*/true); 1602 if (failed(opName)) 1603 return failure(); 1604 1605 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1606 loc); 1607 } 1608 case Token::kw_Type: 1609 consumeToken(Token::kw_Type); 1610 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1611 case Token::kw_TypeRange: 1612 consumeToken(Token::kw_TypeRange); 1613 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1614 loc); 1615 case Token::kw_Value: { 1616 consumeToken(Token::kw_Value); 1617 1618 // Check for a type constraint. 1619 ast::Expr *typeExpr = nullptr; 1620 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1621 return failure(); 1622 1623 return ast::ConstraintRef( 1624 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1625 } 1626 case Token::kw_ValueRange: { 1627 consumeToken(Token::kw_ValueRange); 1628 1629 // Check for a type constraint. 1630 ast::Expr *typeExpr = nullptr; 1631 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1632 return failure(); 1633 1634 return ast::ConstraintRef( 1635 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1636 } 1637 1638 case Token::kw_Constraint: { 1639 // Handle an inline constraint. 1640 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1641 if (failed(decl)) 1642 return failure(); 1643 return ast::ConstraintRef(*decl, loc); 1644 } 1645 case Token::identifier: { 1646 StringRef constraintName = curToken.getSpelling(); 1647 consumeToken(Token::identifier); 1648 1649 // Lookup the referenced constraint. 1650 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1651 if (!cstDecl) { 1652 return emitError(loc, "unknown reference to constraint `" + 1653 constraintName + "`"); 1654 } 1655 1656 // Handle a reference to a proper constraint. 1657 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1658 return ast::ConstraintRef(cst, loc); 1659 1660 return emitErrorAndNote( 1661 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1662 "see the definition of `" + constraintName + "` here"); 1663 } 1664 // Handle single entity constraint code completion. 1665 case Token::code_complete: { 1666 // Try to infer the current type for use by code completion. 1667 ast::Type inferredType; 1668 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1669 allowNonCoreConstraints))) 1670 return failure(); 1671 1672 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1673 allowInlineTypeConstraints); 1674 } 1675 default: 1676 break; 1677 } 1678 return emitError(loc, "expected identifier constraint"); 1679 } 1680 1681 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1682 // Constraint arguments may apply more complex constraints via the arguments. 1683 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1684 1685 Optional<SMRange> typeConstraint; 1686 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1687 /*allowInlineTypeConstraints=*/false, 1688 allowNonCoreConstraints); 1689 } 1690 1691 //===----------------------------------------------------------------------===// 1692 // Exprs 1693 1694 FailureOr<ast::Expr *> Parser::parseExpr() { 1695 if (curToken.is(Token::underscore)) 1696 return parseUnderscoreExpr(); 1697 1698 // Parse the LHS expression. 1699 FailureOr<ast::Expr *> lhsExpr; 1700 switch (curToken.getKind()) { 1701 case Token::kw_attr: 1702 lhsExpr = parseAttributeExpr(); 1703 break; 1704 case Token::kw_Constraint: 1705 lhsExpr = parseInlineConstraintLambdaExpr(); 1706 break; 1707 case Token::identifier: 1708 lhsExpr = parseIdentifierExpr(); 1709 break; 1710 case Token::kw_op: 1711 lhsExpr = parseOperationExpr(); 1712 break; 1713 case Token::kw_Rewrite: 1714 lhsExpr = parseInlineRewriteLambdaExpr(); 1715 break; 1716 case Token::kw_type: 1717 lhsExpr = parseTypeExpr(); 1718 break; 1719 case Token::l_paren: 1720 lhsExpr = parseTupleExpr(); 1721 break; 1722 default: 1723 return emitError("expected expression"); 1724 } 1725 if (failed(lhsExpr)) 1726 return failure(); 1727 1728 // Check for an operator expression. 1729 while (true) { 1730 switch (curToken.getKind()) { 1731 case Token::dot: 1732 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1733 break; 1734 case Token::l_paren: 1735 lhsExpr = parseCallExpr(*lhsExpr); 1736 break; 1737 default: 1738 return lhsExpr; 1739 } 1740 if (failed(lhsExpr)) 1741 return failure(); 1742 } 1743 } 1744 1745 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1746 SMRange loc = curToken.getLoc(); 1747 consumeToken(Token::kw_attr); 1748 1749 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1750 // identifier. 1751 if (!consumeIf(Token::less)) { 1752 resetToken(loc); 1753 return parseIdentifierExpr(); 1754 } 1755 1756 if (!curToken.isString()) 1757 return emitError("expected string literal containing MLIR attribute"); 1758 std::string attrExpr = curToken.getStringValue(); 1759 consumeToken(); 1760 1761 if (failed( 1762 parseToken(Token::greater, "expected `>` after attribute literal"))) 1763 return failure(); 1764 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1765 } 1766 1767 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1768 SMRange loc = curToken.getLoc(); 1769 consumeToken(Token::l_paren); 1770 1771 // Parse the arguments of the call. 1772 SmallVector<ast::Expr *> arguments; 1773 if (curToken.isNot(Token::r_paren)) { 1774 do { 1775 // Handle code completion for the call arguments. 1776 if (curToken.is(Token::code_complete)) { 1777 codeCompleteCallSignature(parentExpr, arguments.size()); 1778 return failure(); 1779 } 1780 1781 FailureOr<ast::Expr *> argument = parseExpr(); 1782 if (failed(argument)) 1783 return failure(); 1784 arguments.push_back(*argument); 1785 } while (consumeIf(Token::comma)); 1786 } 1787 loc.End = curToken.getEndLoc(); 1788 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1789 return failure(); 1790 1791 return createCallExpr(loc, parentExpr, arguments); 1792 } 1793 1794 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1795 ast::Decl *decl = curDeclScope->lookup(name); 1796 if (!decl) 1797 return emitError(loc, "undefined reference to `" + name + "`"); 1798 1799 return createDeclRefExpr(loc, decl); 1800 } 1801 1802 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1803 StringRef name = curToken.getSpelling(); 1804 SMRange nameLoc = curToken.getLoc(); 1805 consumeToken(); 1806 1807 // Check to see if this is a decl ref expression that defines a variable 1808 // inline. 1809 if (consumeIf(Token::colon)) { 1810 SmallVector<ast::ConstraintRef> constraints; 1811 if (failed(parseVariableDeclConstraintList(constraints))) 1812 return failure(); 1813 ast::Type type; 1814 if (failed(validateVariableConstraints(constraints, type))) 1815 return failure(); 1816 return createInlineVariableExpr(type, name, nameLoc, constraints); 1817 } 1818 1819 return parseDeclRefExpr(name, nameLoc); 1820 } 1821 1822 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1823 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1824 if (failed(decl)) 1825 return failure(); 1826 1827 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1828 ast::ConstraintType::get(ctx)); 1829 } 1830 1831 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1832 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1833 if (failed(decl)) 1834 return failure(); 1835 1836 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1837 ast::RewriteType::get(ctx)); 1838 } 1839 1840 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1841 SMRange loc = curToken.getLoc(); 1842 consumeToken(Token::dot); 1843 1844 // Check for code completion of the member name. 1845 if (curToken.is(Token::code_complete)) 1846 return codeCompleteMemberAccess(parentExpr); 1847 1848 // Parse the member name. 1849 Token memberNameTok = curToken; 1850 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1851 !memberNameTok.isKeyword()) 1852 return emitError(loc, "expected identifier or numeric member name"); 1853 StringRef memberName = memberNameTok.getSpelling(); 1854 consumeToken(); 1855 1856 return createMemberAccessExpr(parentExpr, memberName, loc); 1857 } 1858 1859 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1860 SMRange loc = curToken.getLoc(); 1861 1862 // Check for code completion for the dialect name. 1863 if (curToken.is(Token::code_complete)) 1864 return codeCompleteDialectName(); 1865 1866 // Handle the case of an no operation name. 1867 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1868 if (allowEmptyName) 1869 return ast::OpNameDecl::create(ctx, SMRange()); 1870 return emitError("expected dialect namespace"); 1871 } 1872 StringRef name = curToken.getSpelling(); 1873 consumeToken(); 1874 1875 // Otherwise, this is a literal operation name. 1876 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1877 return failure(); 1878 1879 // Check for code completion for the operation name. 1880 if (curToken.is(Token::code_complete)) 1881 return codeCompleteOperationName(name); 1882 1883 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1884 return emitError("expected operation name after dialect namespace"); 1885 1886 name = StringRef(name.data(), name.size() + 1); 1887 do { 1888 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1889 loc.End = curToken.getEndLoc(); 1890 consumeToken(); 1891 } while (curToken.isAny(Token::identifier, Token::dot) || 1892 curToken.isKeyword()); 1893 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1894 } 1895 1896 FailureOr<ast::OpNameDecl *> 1897 Parser::parseWrappedOperationName(bool allowEmptyName) { 1898 if (!consumeIf(Token::less)) 1899 return ast::OpNameDecl::create(ctx, SMRange()); 1900 1901 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1902 if (failed(opNameDecl)) 1903 return failure(); 1904 1905 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1906 return failure(); 1907 return opNameDecl; 1908 } 1909 1910 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1911 SMRange loc = curToken.getLoc(); 1912 consumeToken(Token::kw_op); 1913 1914 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1915 // identifier. 1916 if (curToken.isNot(Token::less)) { 1917 resetToken(loc); 1918 return parseIdentifierExpr(); 1919 } 1920 1921 // Parse the operation name. The name may be elided, in which case the 1922 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1923 // `Operation*`). Operation names within a rewrite context must be named. 1924 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1925 FailureOr<ast::OpNameDecl *> opNameDecl = 1926 parseWrappedOperationName(allowEmptyName); 1927 if (failed(opNameDecl)) 1928 return failure(); 1929 Optional<StringRef> opName = (*opNameDecl)->getName(); 1930 1931 // Functor used to create an implicit range variable, used for implicit "all" 1932 // operand or results variables. 1933 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1934 FailureOr<ast::VariableDecl *> rangeVar = 1935 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1936 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1937 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1938 }; 1939 1940 // Check for the optional list of operands. 1941 SmallVector<ast::Expr *> operands; 1942 if (!consumeIf(Token::l_paren)) { 1943 // If the operand list isn't specified and we are in a match context, define 1944 // an inplace unconstrained operand range corresponding to all of the 1945 // operands of the operation. This avoids treating zero operands the same 1946 // way as "unconstrained operands". 1947 if (parserContext != ParserContext::Rewrite) { 1948 operands.push_back(createImplicitRangeVar( 1949 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1950 } 1951 } else if (!consumeIf(Token::r_paren)) { 1952 // Check for operand signature code completion. 1953 if (curToken.is(Token::code_complete)) { 1954 codeCompleteOperationOperandsSignature(opName, operands.size()); 1955 return failure(); 1956 } 1957 1958 // If the operand list was specified and non-empty, parse the operands. 1959 do { 1960 FailureOr<ast::Expr *> operand = parseExpr(); 1961 if (failed(operand)) 1962 return failure(); 1963 operands.push_back(*operand); 1964 } while (consumeIf(Token::comma)); 1965 1966 if (failed(parseToken(Token::r_paren, 1967 "expected `)` after operation operand list"))) 1968 return failure(); 1969 } 1970 1971 // Check for the optional list of attributes. 1972 SmallVector<ast::NamedAttributeDecl *> attributes; 1973 if (consumeIf(Token::l_brace)) { 1974 do { 1975 FailureOr<ast::NamedAttributeDecl *> decl = 1976 parseNamedAttributeDecl(opName); 1977 if (failed(decl)) 1978 return failure(); 1979 attributes.emplace_back(*decl); 1980 } while (consumeIf(Token::comma)); 1981 1982 if (failed(parseToken(Token::r_brace, 1983 "expected `}` after operation attribute list"))) 1984 return failure(); 1985 } 1986 1987 // Check for the optional list of result types. 1988 SmallVector<ast::Expr *> resultTypes; 1989 if (consumeIf(Token::arrow)) { 1990 if (failed(parseToken(Token::l_paren, 1991 "expected `(` before operation result type list"))) 1992 return failure(); 1993 1994 // Handle the case of an empty result list. 1995 if (!consumeIf(Token::r_paren)) { 1996 do { 1997 // Check for result signature code completion. 1998 if (curToken.is(Token::code_complete)) { 1999 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2000 return failure(); 2001 } 2002 2003 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2004 if (failed(resultTypeExpr)) 2005 return failure(); 2006 resultTypes.push_back(*resultTypeExpr); 2007 } while (consumeIf(Token::comma)); 2008 2009 if (failed(parseToken(Token::r_paren, 2010 "expected `)` after operation result type list"))) 2011 return failure(); 2012 } 2013 } else if (parserContext != ParserContext::Rewrite) { 2014 // If the result list isn't specified and we are in a match context, define 2015 // an inplace unconstrained result range corresponding to all of the results 2016 // of the operation. This avoids treating zero results the same way as 2017 // "unconstrained results". 2018 resultTypes.push_back(createImplicitRangeVar( 2019 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2020 } 2021 2022 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2023 resultTypes); 2024 } 2025 2026 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2027 SMRange loc = curToken.getLoc(); 2028 consumeToken(Token::l_paren); 2029 2030 DenseMap<StringRef, SMRange> usedNames; 2031 SmallVector<StringRef> elementNames; 2032 SmallVector<ast::Expr *> elements; 2033 if (curToken.isNot(Token::r_paren)) { 2034 do { 2035 // Check for the optional element name assignment before the value. 2036 StringRef elementName; 2037 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2038 Token elementNameTok = curToken; 2039 consumeToken(); 2040 2041 // The element name is only present if followed by an `=`. 2042 if (consumeIf(Token::equal)) { 2043 elementName = elementNameTok.getSpelling(); 2044 2045 // Check to see if this name is already used. 2046 auto elementNameIt = 2047 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2048 if (!elementNameIt.second) { 2049 return emitErrorAndNote( 2050 elementNameTok.getLoc(), 2051 llvm::formatv("duplicate tuple element label `{0}`", 2052 elementName), 2053 elementNameIt.first->getSecond(), 2054 "see previous label use here"); 2055 } 2056 } else { 2057 // Otherwise, we treat this as part of an expression so reset the 2058 // lexer. 2059 resetToken(elementNameTok.getLoc()); 2060 } 2061 } 2062 elementNames.push_back(elementName); 2063 2064 // Parse the tuple element value. 2065 FailureOr<ast::Expr *> element = parseExpr(); 2066 if (failed(element)) 2067 return failure(); 2068 elements.push_back(*element); 2069 } while (consumeIf(Token::comma)); 2070 } 2071 loc.End = curToken.getEndLoc(); 2072 if (failed( 2073 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2074 return failure(); 2075 return createTupleExpr(loc, elements, elementNames); 2076 } 2077 2078 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2079 SMRange loc = curToken.getLoc(); 2080 consumeToken(Token::kw_type); 2081 2082 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2083 // identifier. 2084 if (!consumeIf(Token::less)) { 2085 resetToken(loc); 2086 return parseIdentifierExpr(); 2087 } 2088 2089 if (!curToken.isString()) 2090 return emitError("expected string literal containing MLIR type"); 2091 std::string attrExpr = curToken.getStringValue(); 2092 consumeToken(); 2093 2094 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2095 return failure(); 2096 return ast::TypeExpr::create(ctx, loc, attrExpr); 2097 } 2098 2099 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2100 StringRef name = curToken.getSpelling(); 2101 SMRange nameLoc = curToken.getLoc(); 2102 consumeToken(Token::underscore); 2103 2104 // Underscore expressions require a constraint list. 2105 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2106 return failure(); 2107 2108 // Parse the constraints for the expression. 2109 SmallVector<ast::ConstraintRef> constraints; 2110 if (failed(parseVariableDeclConstraintList(constraints))) 2111 return failure(); 2112 2113 ast::Type type; 2114 if (failed(validateVariableConstraints(constraints, type))) 2115 return failure(); 2116 return createInlineVariableExpr(type, name, nameLoc, constraints); 2117 } 2118 2119 //===----------------------------------------------------------------------===// 2120 // Stmts 2121 2122 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2123 FailureOr<ast::Stmt *> stmt; 2124 switch (curToken.getKind()) { 2125 case Token::kw_erase: 2126 stmt = parseEraseStmt(); 2127 break; 2128 case Token::kw_let: 2129 stmt = parseLetStmt(); 2130 break; 2131 case Token::kw_replace: 2132 stmt = parseReplaceStmt(); 2133 break; 2134 case Token::kw_return: 2135 stmt = parseReturnStmt(); 2136 break; 2137 case Token::kw_rewrite: 2138 stmt = parseRewriteStmt(); 2139 break; 2140 default: 2141 stmt = parseExpr(); 2142 break; 2143 } 2144 if (failed(stmt) || 2145 (expectTerminalSemicolon && 2146 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2147 return failure(); 2148 return stmt; 2149 } 2150 2151 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2152 SMLoc startLoc = curToken.getStartLoc(); 2153 consumeToken(Token::l_brace); 2154 2155 // Push a new block scope and parse any nested statements. 2156 pushDeclScope(); 2157 SmallVector<ast::Stmt *> statements; 2158 while (curToken.isNot(Token::r_brace)) { 2159 FailureOr<ast::Stmt *> statement = parseStmt(); 2160 if (failed(statement)) 2161 return popDeclScope(), failure(); 2162 statements.push_back(*statement); 2163 } 2164 popDeclScope(); 2165 2166 // Consume the end brace. 2167 SMRange location(startLoc, curToken.getEndLoc()); 2168 consumeToken(Token::r_brace); 2169 2170 return ast::CompoundStmt::create(ctx, location, statements); 2171 } 2172 2173 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2174 if (parserContext == ParserContext::Constraint) 2175 return emitError("`erase` cannot be used within a Constraint"); 2176 SMRange loc = curToken.getLoc(); 2177 consumeToken(Token::kw_erase); 2178 2179 // Parse the root operation expression. 2180 FailureOr<ast::Expr *> rootOp = parseExpr(); 2181 if (failed(rootOp)) 2182 return failure(); 2183 2184 return createEraseStmt(loc, *rootOp); 2185 } 2186 2187 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2188 SMRange loc = curToken.getLoc(); 2189 consumeToken(Token::kw_let); 2190 2191 // Parse the name of the new variable. 2192 SMRange varLoc = curToken.getLoc(); 2193 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2194 // `_` is a reserved variable name. 2195 if (curToken.is(Token::underscore)) { 2196 return emitError(varLoc, 2197 "`_` may only be used to define \"inline\" variables"); 2198 } 2199 return emitError(varLoc, 2200 "expected identifier after `let` to name a new variable"); 2201 } 2202 StringRef varName = curToken.getSpelling(); 2203 consumeToken(); 2204 2205 // Parse the optional set of constraints. 2206 SmallVector<ast::ConstraintRef> constraints; 2207 if (consumeIf(Token::colon) && 2208 failed(parseVariableDeclConstraintList(constraints))) 2209 return failure(); 2210 2211 // Parse the optional initializer expression. 2212 ast::Expr *initializer = nullptr; 2213 if (consumeIf(Token::equal)) { 2214 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2215 if (failed(initOrFailure)) 2216 return failure(); 2217 initializer = *initOrFailure; 2218 2219 // Check that the constraints are compatible with having an initializer, 2220 // e.g. type constraints cannot be used with initializers. 2221 for (ast::ConstraintRef constraint : constraints) { 2222 LogicalResult result = 2223 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2224 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2225 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2226 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2227 return this->emitError( 2228 constraint.referenceLoc, 2229 "type constraints are not permitted on variables with " 2230 "initializers"); 2231 } 2232 return success(); 2233 }) 2234 .Default(success()); 2235 if (failed(result)) 2236 return failure(); 2237 } 2238 } 2239 2240 FailureOr<ast::VariableDecl *> varDecl = 2241 createVariableDecl(varName, varLoc, initializer, constraints); 2242 if (failed(varDecl)) 2243 return failure(); 2244 return ast::LetStmt::create(ctx, loc, *varDecl); 2245 } 2246 2247 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2248 if (parserContext == ParserContext::Constraint) 2249 return emitError("`replace` cannot be used within a Constraint"); 2250 SMRange loc = curToken.getLoc(); 2251 consumeToken(Token::kw_replace); 2252 2253 // Parse the root operation expression. 2254 FailureOr<ast::Expr *> rootOp = parseExpr(); 2255 if (failed(rootOp)) 2256 return failure(); 2257 2258 if (failed( 2259 parseToken(Token::kw_with, "expected `with` after root operation"))) 2260 return failure(); 2261 2262 // The replacement portion of this statement is within a rewrite context. 2263 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2264 ParserContext::Rewrite); 2265 2266 // Parse the replacement values. 2267 SmallVector<ast::Expr *> replValues; 2268 if (consumeIf(Token::l_paren)) { 2269 if (consumeIf(Token::r_paren)) { 2270 return emitError( 2271 loc, "expected at least one replacement value, consider using " 2272 "`erase` if no replacement values are desired"); 2273 } 2274 2275 do { 2276 FailureOr<ast::Expr *> replExpr = parseExpr(); 2277 if (failed(replExpr)) 2278 return failure(); 2279 replValues.emplace_back(*replExpr); 2280 } while (consumeIf(Token::comma)); 2281 2282 if (failed(parseToken(Token::r_paren, 2283 "expected `)` after replacement values"))) 2284 return failure(); 2285 } else { 2286 FailureOr<ast::Expr *> replExpr = parseExpr(); 2287 if (failed(replExpr)) 2288 return failure(); 2289 replValues.emplace_back(*replExpr); 2290 } 2291 2292 return createReplaceStmt(loc, *rootOp, replValues); 2293 } 2294 2295 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2296 SMRange loc = curToken.getLoc(); 2297 consumeToken(Token::kw_return); 2298 2299 // Parse the result value. 2300 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2301 if (failed(resultExpr)) 2302 return failure(); 2303 2304 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2305 } 2306 2307 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2308 if (parserContext == ParserContext::Constraint) 2309 return emitError("`rewrite` cannot be used within a Constraint"); 2310 SMRange loc = curToken.getLoc(); 2311 consumeToken(Token::kw_rewrite); 2312 2313 // Parse the root operation. 2314 FailureOr<ast::Expr *> rootOp = parseExpr(); 2315 if (failed(rootOp)) 2316 return failure(); 2317 2318 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2319 return failure(); 2320 2321 if (curToken.isNot(Token::l_brace)) 2322 return emitError("expected `{` to start rewrite body"); 2323 2324 // The rewrite body of this statement is within a rewrite context. 2325 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2326 ParserContext::Rewrite); 2327 2328 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2329 if (failed(rewriteBody)) 2330 return failure(); 2331 2332 // Verify the rewrite body. 2333 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2334 if (isa<ast::ReturnStmt>(stmt)) { 2335 return emitError(stmt->getLoc(), 2336 "`return` statements are only permitted within a " 2337 "`Constraint` or `Rewrite` body"); 2338 } 2339 } 2340 2341 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2342 } 2343 2344 //===----------------------------------------------------------------------===// 2345 // Creation+Analysis 2346 //===----------------------------------------------------------------------===// 2347 2348 //===----------------------------------------------------------------------===// 2349 // Decls 2350 2351 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2352 // Unwrap reference expressions. 2353 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2354 node = init->getDecl(); 2355 return dyn_cast<ast::CallableDecl>(node); 2356 } 2357 2358 FailureOr<ast::PatternDecl *> 2359 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2360 const ParsedPatternMetadata &metadata, 2361 ast::CompoundStmt *body) { 2362 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2363 metadata.hasBoundedRecursion, body); 2364 } 2365 2366 ast::Type Parser::createUserConstraintRewriteResultType( 2367 ArrayRef<ast::VariableDecl *> results) { 2368 // Single result decls use the type of the single result. 2369 if (results.size() == 1) 2370 return results[0]->getType(); 2371 2372 // Multiple results use a tuple type, with the types and names grabbed from 2373 // the result variable decls. 2374 auto resultTypes = llvm::map_range( 2375 results, [&](const auto *result) { return result->getType(); }); 2376 auto resultNames = llvm::map_range( 2377 results, [&](const auto *result) { return result->getName().getName(); }); 2378 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2379 llvm::to_vector(resultNames)); 2380 } 2381 2382 template <typename T> 2383 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2384 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2385 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2386 ast::CompoundStmt *body) { 2387 if (!body->getChildren().empty()) { 2388 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2389 ast::Expr *resultExpr = retStmt->getResultExpr(); 2390 2391 // Process the result of the decl. If no explicit signature results 2392 // were provided, check for return type inference. Otherwise, check that 2393 // the return expression can be converted to the expected type. 2394 if (results.empty()) 2395 resultType = resultExpr->getType(); 2396 else if (failed(convertExpressionTo(resultExpr, resultType))) 2397 return failure(); 2398 else 2399 retStmt->setResultExpr(resultExpr); 2400 } 2401 } 2402 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2403 } 2404 2405 FailureOr<ast::VariableDecl *> 2406 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2407 ArrayRef<ast::ConstraintRef> constraints) { 2408 // The type of the variable, which is expected to be inferred by either a 2409 // constraint or an initializer expression. 2410 ast::Type type; 2411 if (failed(validateVariableConstraints(constraints, type))) 2412 return failure(); 2413 2414 if (initializer) { 2415 // Update the variable type based on the initializer, or try to convert the 2416 // initializer to the existing type. 2417 if (!type) 2418 type = initializer->getType(); 2419 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2420 type = mergedType; 2421 else if (failed(convertExpressionTo(initializer, type))) 2422 return failure(); 2423 2424 // Otherwise, if there is no initializer check that the type has already 2425 // been resolved from the constraint list. 2426 } else if (!type) { 2427 return emitErrorAndNote( 2428 loc, "unable to infer type for variable `" + name + "`", loc, 2429 "the type of a variable must be inferable from the constraint " 2430 "list or the initializer"); 2431 } 2432 2433 // Constraint types cannot be used when defining variables. 2434 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2435 return emitError( 2436 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2437 } 2438 2439 // Try to define a variable with the given name. 2440 FailureOr<ast::VariableDecl *> varDecl = 2441 defineVariableDecl(name, loc, type, initializer, constraints); 2442 if (failed(varDecl)) 2443 return failure(); 2444 2445 return *varDecl; 2446 } 2447 2448 FailureOr<ast::VariableDecl *> 2449 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2450 const ast::ConstraintRef &constraint) { 2451 // Constraint arguments may apply more complex constraints via the arguments. 2452 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2453 ast::Type argType; 2454 if (failed(validateVariableConstraint(constraint, argType, 2455 allowNonCoreConstraints))) 2456 return failure(); 2457 return defineVariableDecl(name, loc, argType, constraint); 2458 } 2459 2460 LogicalResult 2461 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2462 ast::Type &inferredType, 2463 bool allowNonCoreConstraints) { 2464 for (const ast::ConstraintRef &ref : constraints) 2465 if (failed(validateVariableConstraint(ref, inferredType, 2466 allowNonCoreConstraints))) 2467 return failure(); 2468 return success(); 2469 } 2470 2471 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2472 ast::Type &inferredType, 2473 bool allowNonCoreConstraints) { 2474 ast::Type constraintType; 2475 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2476 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2477 if (failed(validateTypeConstraintExpr(typeExpr))) 2478 return failure(); 2479 } 2480 constraintType = ast::AttributeType::get(ctx); 2481 } else if (const auto *cst = 2482 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2483 constraintType = ast::OperationType::get(ctx, cst->getName()); 2484 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2485 constraintType = typeTy; 2486 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2487 constraintType = typeRangeTy; 2488 } else if (const auto *cst = 2489 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2490 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2491 if (failed(validateTypeConstraintExpr(typeExpr))) 2492 return failure(); 2493 } 2494 constraintType = valueTy; 2495 } else if (const auto *cst = 2496 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2497 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2498 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2499 return failure(); 2500 } 2501 constraintType = valueRangeTy; 2502 } else if (const auto *cst = 2503 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2504 if (!allowNonCoreConstraints) { 2505 return emitError(ref.referenceLoc, 2506 "`Rewrite` arguments and results are only permitted to " 2507 "use core constraints, such as `Attr`, `Op`, `Type`, " 2508 "`TypeRange`, `Value`, `ValueRange`"); 2509 } 2510 2511 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2512 if (inputs.size() != 1) { 2513 return emitErrorAndNote(ref.referenceLoc, 2514 "`Constraint`s applied via a variable constraint " 2515 "list must take a single input, but got " + 2516 Twine(inputs.size()), 2517 cst->getLoc(), 2518 "see definition of constraint here"); 2519 } 2520 constraintType = inputs.front()->getType(); 2521 } else { 2522 llvm_unreachable("unknown constraint type"); 2523 } 2524 2525 // Check that the constraint type is compatible with the current inferred 2526 // type. 2527 if (!inferredType) { 2528 inferredType = constraintType; 2529 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2530 inferredType = mergedTy; 2531 } else { 2532 return emitError(ref.referenceLoc, 2533 llvm::formatv("constraint type `{0}` is incompatible " 2534 "with the previously inferred type `{1}`", 2535 constraintType, inferredType)); 2536 } 2537 return success(); 2538 } 2539 2540 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2541 ast::Type typeExprType = typeExpr->getType(); 2542 if (typeExprType != typeTy) { 2543 return emitError(typeExpr->getLoc(), 2544 "expected expression of `Type` in type constraint"); 2545 } 2546 return success(); 2547 } 2548 2549 LogicalResult 2550 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2551 ast::Type typeExprType = typeExpr->getType(); 2552 if (typeExprType != typeRangeTy) { 2553 return emitError(typeExpr->getLoc(), 2554 "expected expression of `TypeRange` in type constraint"); 2555 } 2556 return success(); 2557 } 2558 2559 //===----------------------------------------------------------------------===// 2560 // Exprs 2561 2562 FailureOr<ast::CallExpr *> 2563 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2564 MutableArrayRef<ast::Expr *> arguments) { 2565 ast::Type parentType = parentExpr->getType(); 2566 2567 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2568 if (!callableDecl) { 2569 return emitError(loc, 2570 llvm::formatv("expected a reference to a callable " 2571 "`Constraint` or `Rewrite`, but got: `{0}`", 2572 parentType)); 2573 } 2574 if (parserContext == ParserContext::Rewrite) { 2575 if (isa<ast::UserConstraintDecl>(callableDecl)) 2576 return emitError( 2577 loc, "unable to invoke `Constraint` within a rewrite section"); 2578 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2579 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2580 } 2581 2582 // Verify the arguments of the call. 2583 /// Handle size mismatch. 2584 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2585 if (callArgs.size() != arguments.size()) { 2586 return emitErrorAndNote( 2587 loc, 2588 llvm::formatv("invalid number of arguments for {0} call; expected " 2589 "{1}, but got {2}", 2590 callableDecl->getCallableType(), callArgs.size(), 2591 arguments.size()), 2592 callableDecl->getLoc(), 2593 llvm::formatv("see the definition of {0} here", 2594 callableDecl->getName()->getName())); 2595 } 2596 2597 /// Handle argument type mismatch. 2598 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2599 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2600 callableDecl->getName()->getName()), 2601 callableDecl->getLoc()); 2602 }; 2603 for (auto it : llvm::zip(callArgs, arguments)) { 2604 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2605 attachDiagFn))) 2606 return failure(); 2607 } 2608 2609 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2610 callableDecl->getResultType()); 2611 } 2612 2613 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2614 ast::Decl *decl) { 2615 // Check the type of decl being referenced. 2616 ast::Type declType; 2617 if (isa<ast::ConstraintDecl>(decl)) 2618 declType = ast::ConstraintType::get(ctx); 2619 else if (isa<ast::UserRewriteDecl>(decl)) 2620 declType = ast::RewriteType::get(ctx); 2621 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2622 declType = varDecl->getType(); 2623 else 2624 return emitError(loc, "invalid reference to `" + 2625 decl->getName()->getName() + "`"); 2626 2627 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2628 } 2629 2630 FailureOr<ast::DeclRefExpr *> 2631 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2632 ArrayRef<ast::ConstraintRef> constraints) { 2633 FailureOr<ast::VariableDecl *> decl = 2634 defineVariableDecl(name, loc, type, constraints); 2635 if (failed(decl)) 2636 return failure(); 2637 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2638 } 2639 2640 FailureOr<ast::MemberAccessExpr *> 2641 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2642 SMRange loc) { 2643 // Validate the member name for the given parent expression. 2644 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2645 if (failed(memberType)) 2646 return failure(); 2647 2648 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2649 } 2650 2651 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2652 StringRef name, SMRange loc) { 2653 ast::Type parentType = parentExpr->getType(); 2654 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2655 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2656 return valueRangeTy; 2657 2658 // Verify member access based on the operation type. 2659 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2660 auto results = odsOp->getResults(); 2661 2662 // Handle indexed results. 2663 unsigned index = 0; 2664 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2665 index < results.size()) { 2666 return results[index].isVariadic() ? valueRangeTy : valueTy; 2667 } 2668 2669 // Handle named results. 2670 const auto *it = llvm::find_if(results, [&](const auto &result) { 2671 return result.getName() == name; 2672 }); 2673 if (it != results.end()) 2674 return it->isVariadic() ? valueRangeTy : valueTy; 2675 } 2676 2677 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2678 // Handle indexed results. 2679 unsigned index = 0; 2680 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2681 index < tupleType.size()) { 2682 return tupleType.getElementTypes()[index]; 2683 } 2684 2685 // Handle named results. 2686 auto elementNames = tupleType.getElementNames(); 2687 const auto *it = llvm::find(elementNames, name); 2688 if (it != elementNames.end()) 2689 return tupleType.getElementTypes()[it - elementNames.begin()]; 2690 } 2691 return emitError( 2692 loc, 2693 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2694 name, parentType)); 2695 } 2696 2697 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2698 SMRange loc, const ast::OpNameDecl *name, 2699 MutableArrayRef<ast::Expr *> operands, 2700 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2701 MutableArrayRef<ast::Expr *> results) { 2702 Optional<StringRef> opNameRef = name->getName(); 2703 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2704 2705 // Verify the inputs operands. 2706 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2707 return failure(); 2708 2709 // Verify the attribute list. 2710 for (ast::NamedAttributeDecl *attr : attributes) { 2711 // Check for an attribute type, or a type awaiting resolution. 2712 ast::Type attrType = attr->getValue()->getType(); 2713 if (!attrType.isa<ast::AttributeType>()) { 2714 return emitError( 2715 attr->getValue()->getLoc(), 2716 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2717 } 2718 } 2719 2720 // Verify the result types. 2721 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2722 return failure(); 2723 2724 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2725 attributes); 2726 } 2727 2728 LogicalResult 2729 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2730 const ods::Operation *odsOp, 2731 MutableArrayRef<ast::Expr *> operands) { 2732 return validateOperationOperandsOrResults( 2733 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2734 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2735 valueRangeTy); 2736 } 2737 2738 LogicalResult 2739 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2740 const ods::Operation *odsOp, 2741 MutableArrayRef<ast::Expr *> results) { 2742 return validateOperationOperandsOrResults( 2743 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2744 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2745 } 2746 2747 LogicalResult Parser::validateOperationOperandsOrResults( 2748 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2749 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2750 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2751 ast::Type rangeTy) { 2752 // All operation types accept a single range parameter. 2753 if (values.size() == 1) { 2754 if (failed(convertExpressionTo(values[0], rangeTy))) 2755 return failure(); 2756 return success(); 2757 } 2758 2759 /// If the operation has ODS information, we can more accurately verify the 2760 /// values. 2761 if (odsOpLoc) { 2762 if (odsValues.size() != values.size()) { 2763 return emitErrorAndNote( 2764 loc, 2765 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2766 "{2}, but got {3}", 2767 groupName, *name, odsValues.size(), values.size()), 2768 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2769 } 2770 auto diagFn = [&](ast::Diagnostic &diag) { 2771 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2772 *odsOpLoc); 2773 }; 2774 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2775 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2776 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2777 return failure(); 2778 } 2779 return success(); 2780 } 2781 2782 // Otherwise, accept the value groups as they have been defined and just 2783 // ensure they are one of the expected types. 2784 for (ast::Expr *&valueExpr : values) { 2785 ast::Type valueExprType = valueExpr->getType(); 2786 2787 // Check if this is one of the expected types. 2788 if (valueExprType == rangeTy || valueExprType == singleTy) 2789 continue; 2790 2791 // If the operand is an Operation, allow converting to a Value or 2792 // ValueRange. This situations arises quite often with nested operation 2793 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2794 if (singleTy == valueTy) { 2795 if (valueExprType.isa<ast::OperationType>()) { 2796 valueExpr = convertOpToValue(valueExpr); 2797 continue; 2798 } 2799 } 2800 2801 return emitError( 2802 valueExpr->getLoc(), 2803 llvm::formatv( 2804 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2805 singleTy, rangeTy, valueExprType)); 2806 } 2807 return success(); 2808 } 2809 2810 FailureOr<ast::TupleExpr *> 2811 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2812 ArrayRef<StringRef> elementNames) { 2813 for (const ast::Expr *element : elements) { 2814 ast::Type eleTy = element->getType(); 2815 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2816 return emitError( 2817 element->getLoc(), 2818 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2819 } 2820 } 2821 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2822 } 2823 2824 //===----------------------------------------------------------------------===// 2825 // Stmts 2826 2827 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2828 ast::Expr *rootOp) { 2829 // Check that root is an Operation. 2830 ast::Type rootType = rootOp->getType(); 2831 if (!rootType.isa<ast::OperationType>()) 2832 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2833 2834 return ast::EraseStmt::create(ctx, loc, rootOp); 2835 } 2836 2837 FailureOr<ast::ReplaceStmt *> 2838 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2839 MutableArrayRef<ast::Expr *> replValues) { 2840 // Check that root is an Operation. 2841 ast::Type rootType = rootOp->getType(); 2842 if (!rootType.isa<ast::OperationType>()) { 2843 return emitError( 2844 rootOp->getLoc(), 2845 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2846 } 2847 2848 // If there are multiple replacement values, we implicitly convert any Op 2849 // expressions to the value form. 2850 bool shouldConvertOpToValues = replValues.size() > 1; 2851 for (ast::Expr *&replExpr : replValues) { 2852 ast::Type replType = replExpr->getType(); 2853 2854 // Check that replExpr is an Operation, Value, or ValueRange. 2855 if (replType.isa<ast::OperationType>()) { 2856 if (shouldConvertOpToValues) 2857 replExpr = convertOpToValue(replExpr); 2858 continue; 2859 } 2860 2861 if (replType != valueTy && replType != valueRangeTy) { 2862 return emitError(replExpr->getLoc(), 2863 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2864 "expression, but got `{0}`", 2865 replType)); 2866 } 2867 } 2868 2869 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2870 } 2871 2872 FailureOr<ast::RewriteStmt *> 2873 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2874 ast::CompoundStmt *rewriteBody) { 2875 // Check that root is an Operation. 2876 ast::Type rootType = rootOp->getType(); 2877 if (!rootType.isa<ast::OperationType>()) { 2878 return emitError( 2879 rootOp->getLoc(), 2880 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2881 } 2882 2883 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2884 } 2885 2886 //===----------------------------------------------------------------------===// 2887 // Code Completion 2888 //===----------------------------------------------------------------------===// 2889 2890 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2891 ast::Type parentType = parentExpr->getType(); 2892 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2893 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2894 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2895 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2896 return failure(); 2897 } 2898 2899 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2900 if (opName) 2901 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2902 return failure(); 2903 } 2904 2905 LogicalResult 2906 Parser::codeCompleteConstraintName(ast::Type inferredType, 2907 bool allowNonCoreConstraints, 2908 bool allowInlineTypeConstraints) { 2909 codeCompleteContext->codeCompleteConstraintName( 2910 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2911 curDeclScope); 2912 return failure(); 2913 } 2914 2915 LogicalResult Parser::codeCompleteDialectName() { 2916 codeCompleteContext->codeCompleteDialectName(); 2917 return failure(); 2918 } 2919 2920 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2921 codeCompleteContext->codeCompleteOperationName(dialectName); 2922 return failure(); 2923 } 2924 2925 LogicalResult Parser::codeCompletePatternMetadata() { 2926 codeCompleteContext->codeCompletePatternMetadata(); 2927 return failure(); 2928 } 2929 2930 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 2931 codeCompleteContext->codeCompleteIncludeFilename(curPath); 2932 return failure(); 2933 } 2934 2935 void Parser::codeCompleteCallSignature(ast::Node *parent, 2936 unsigned currentNumArgs) { 2937 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2938 if (!callableDecl) 2939 return; 2940 2941 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2942 } 2943 2944 void Parser::codeCompleteOperationOperandsSignature( 2945 Optional<StringRef> opName, unsigned currentNumOperands) { 2946 codeCompleteContext->codeCompleteOperationOperandsSignature( 2947 opName, currentNumOperands); 2948 } 2949 2950 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2951 unsigned currentNumResults) { 2952 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2953 currentNumResults); 2954 } 2955 2956 //===----------------------------------------------------------------------===// 2957 // Parser 2958 //===----------------------------------------------------------------------===// 2959 2960 FailureOr<ast::Module *> 2961 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2962 CodeCompleteContext *codeCompleteContext) { 2963 Parser parser(ctx, sourceMgr, codeCompleteContext); 2964 return parser.parseModule(); 2965 } 2966