1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. `typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. `inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 428 429 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 430 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 431 unsigned currentNumOperands); 432 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 433 unsigned currentNumResults); 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 /// If the current token has the specified kind, consume it and return true. 440 /// If not, return false. 441 bool consumeIf(Token::Kind kind) { 442 if (curToken.isNot(kind)) 443 return false; 444 consumeToken(kind); 445 return true; 446 } 447 448 /// Advance the current lexer onto the next token. 449 void consumeToken() { 450 assert(curToken.isNot(Token::eof, Token::error) && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 /// Advance the current lexer onto the next token, asserting what the expected 456 /// current token is. This is preferred to the above method because it leads 457 /// to more self-documenting code with better checking. 458 void consumeToken(Token::Kind kind) { 459 assert(curToken.is(kind) && "consumed an unexpected token"); 460 consumeToken(); 461 } 462 463 /// Reset the lexer to the location at the given position. 464 void resetToken(SMRange tokLoc) { 465 lexer.resetPointer(tokLoc.Start.getPointer()); 466 curToken = lexer.lexToken(); 467 } 468 469 /// Consume the specified token if present and return success. On failure, 470 /// output a diagnostic and return failure. 471 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 472 if (curToken.getKind() != kind) 473 return emitError(curToken.getLoc(), msg); 474 consumeToken(); 475 return success(); 476 } 477 LogicalResult emitError(SMRange loc, const Twine &msg) { 478 lexer.emitError(loc, msg); 479 return failure(); 480 } 481 LogicalResult emitError(const Twine &msg) { 482 return emitError(curToken.getLoc(), msg); 483 } 484 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 485 const Twine ¬e) { 486 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 487 return failure(); 488 } 489 490 //===--------------------------------------------------------------------===// 491 // Fields 492 //===--------------------------------------------------------------------===// 493 494 /// The owning AST context. 495 ast::Context &ctx; 496 497 /// The lexer of this parser. 498 Lexer lexer; 499 500 /// The current token within the lexer. 501 Token curToken; 502 503 /// The most recently defined decl scope. 504 ast::DeclScope *curDeclScope = nullptr; 505 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 506 507 /// The current context of the parser. 508 ParserContext parserContext = ParserContext::Global; 509 510 /// Cached types to simplify verification and expression creation. 511 ast::Type valueTy, valueRangeTy; 512 ast::Type typeTy, typeRangeTy; 513 ast::Type attrTy; 514 515 /// A counter used when naming anonymous constraints and rewrites. 516 unsigned anonymousDeclNameCounter = 0; 517 518 /// The optional code completion context. 519 CodeCompleteContext *codeCompleteContext; 520 }; 521 } // namespace 522 523 FailureOr<ast::Module *> Parser::parseModule() { 524 SMLoc moduleLoc = curToken.getStartLoc(); 525 pushDeclScope(); 526 527 // Parse the top-level decls of the module. 528 SmallVector<ast::Decl *> decls; 529 if (failed(parseModuleBody(decls))) 530 return popDeclScope(), failure(); 531 532 popDeclScope(); 533 return ast::Module::create(ctx, moduleLoc, decls); 534 } 535 536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 537 while (curToken.isNot(Token::eof)) { 538 if (curToken.is(Token::directive)) { 539 if (failed(parseDirective(decls))) 540 return failure(); 541 continue; 542 } 543 544 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 545 if (failed(decl)) 546 return failure(); 547 decls.push_back(*decl); 548 } 549 return success(); 550 } 551 552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 553 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 554 valueRangeTy); 555 } 556 557 LogicalResult Parser::convertExpressionTo( 558 ast::Expr *&expr, ast::Type type, 559 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 560 ast::Type exprType = expr->getType(); 561 if (exprType == type) 562 return success(); 563 564 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 565 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 566 expr->getLoc(), llvm::formatv("unable to convert expression of type " 567 "`{0}` to the expected type of " 568 "`{1}`", 569 exprType, type)); 570 if (noteAttachFn) 571 noteAttachFn(*diag); 572 return diag; 573 }; 574 575 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 576 // Two operation types are compatible if they have the same name, or if the 577 // expected type is more general. 578 if (auto opType = type.dyn_cast<ast::OperationType>()) { 579 if (opType.getName()) 580 return emitConvertError(); 581 return success(); 582 } 583 584 // An operation can always convert to a ValueRange. 585 if (type == valueRangeTy) { 586 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 587 valueRangeTy); 588 return success(); 589 } 590 591 // Allow conversion to a single value by constraining the result range. 592 if (type == valueTy) { 593 // If the operation is registered, we can verify if it can ever have a 594 // single result. 595 Optional<StringRef> opName = exprOpType.getName(); 596 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 597 if (odsOp->getResults().empty()) { 598 return emitConvertError()->attachNote( 599 llvm::formatv("see the definition of `{0}`, which was defined " 600 "with zero results", 601 odsOp->getName()), 602 odsOp->getLoc()); 603 } 604 605 unsigned numSingleResults = llvm::count_if( 606 odsOp->getResults(), [](const ods::OperandOrResult &result) { 607 return result.getVariableLengthKind() == 608 ods::VariableLengthKind::Single; 609 }); 610 if (numSingleResults > 1) { 611 return emitConvertError()->attachNote( 612 llvm::formatv("see the definition of `{0}`, which was defined " 613 "with at least {1} results", 614 odsOp->getName(), numSingleResults), 615 odsOp->getLoc()); 616 } 617 } 618 619 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 620 valueTy); 621 return success(); 622 } 623 return emitConvertError(); 624 } 625 626 // FIXME: Decide how to allow/support converting a single result to multiple, 627 // and multiple to a single result. For now, we just allow Single->Range, 628 // but this isn't something really supported in the PDL dialect. We should 629 // figure out some way to support both. 630 if ((exprType == valueTy || exprType == valueRangeTy) && 631 (type == valueTy || type == valueRangeTy)) 632 return success(); 633 if ((exprType == typeTy || exprType == typeRangeTy) && 634 (type == typeTy || type == typeRangeTy)) 635 return success(); 636 637 // Handle tuple types. 638 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 639 auto tupleType = type.dyn_cast<ast::TupleType>(); 640 if (!tupleType || tupleType.size() != exprTupleType.size()) 641 return emitConvertError(); 642 643 // Build a new tuple expression using each of the elements of the current 644 // tuple. 645 SmallVector<ast::Expr *> newExprs; 646 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 647 newExprs.push_back(ast::MemberAccessExpr::create( 648 ctx, expr->getLoc(), expr, llvm::to_string(i), 649 exprTupleType.getElementTypes()[i])); 650 651 auto diagFn = [&](ast::Diagnostic &diag) { 652 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 653 i, exprTupleType)); 654 if (noteAttachFn) 655 noteAttachFn(diag); 656 }; 657 if (failed(convertExpressionTo(newExprs.back(), 658 tupleType.getElementTypes()[i], diagFn))) 659 return failure(); 660 } 661 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 662 tupleType.getElementNames()); 663 return success(); 664 } 665 666 return emitConvertError(); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Directives 671 672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 673 StringRef directive = curToken.getSpelling(); 674 if (directive == "#include") 675 return parseInclude(decls); 676 677 return emitError("unknown directive `" + directive + "`"); 678 } 679 680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 681 SMRange loc = curToken.getLoc(); 682 consumeToken(Token::directive); 683 684 // Handle code completion of the include file path. 685 if (curToken.is(Token::code_complete_string)) 686 return codeCompleteIncludeFilename(curToken.getStringValue()); 687 688 // Parse the file being included. 689 if (!curToken.isString()) 690 return emitError(loc, 691 "expected string file name after `include` directive"); 692 SMRange fileLoc = curToken.getLoc(); 693 std::string filenameStr = curToken.getStringValue(); 694 StringRef filename = filenameStr; 695 consumeToken(); 696 697 // Check the type of include. If ending with `.pdll`, this is another pdl file 698 // to be parsed along with the current module. 699 if (filename.endswith(".pdll")) { 700 if (failed(lexer.pushInclude(filename, fileLoc))) 701 return emitError(fileLoc, 702 "unable to open include file `" + filename + "`"); 703 704 // If we added the include successfully, parse it into the current module. 705 // Make sure to update to the next token after we finish parsing the nested 706 // file. 707 curToken = lexer.lexToken(); 708 LogicalResult result = parseModuleBody(decls); 709 curToken = lexer.lexToken(); 710 return result; 711 } 712 713 // Otherwise, this must be a `.td` include. 714 if (filename.endswith(".td")) 715 return parseTdInclude(filename, fileLoc, decls); 716 717 return emitError(fileLoc, 718 "expected include filename to end with `.pdll` or `.td`"); 719 } 720 721 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 722 SmallVectorImpl<ast::Decl *> &decls) { 723 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 724 725 // Use the source manager to open the file, but don't yet add it. 726 std::string includedFile; 727 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 728 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 729 if (!includeBuffer) 730 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 731 732 // Setup the source manager for parsing the tablegen file. 733 llvm::SourceMgr tdSrcMgr; 734 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); 735 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); 736 737 // This class provides a context argument for the llvm::SourceMgr diagnostic 738 // handler. 739 struct DiagHandlerContext { 740 Parser &parser; 741 StringRef filename; 742 llvm::SMRange loc; 743 } handlerContext{*this, filename, fileLoc}; 744 745 // Set the diagnostic handler for the tablegen source manager. 746 tdSrcMgr.setDiagHandler( 747 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 748 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 749 (void)ctx->parser.emitError( 750 ctx->loc, 751 llvm::formatv("error while processing include file `{0}`: {1}", 752 ctx->filename, diag.getMessage())); 753 }, 754 &handlerContext); 755 756 // Parse the tablegen file. 757 llvm::RecordKeeper tdRecords; 758 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords)) 759 return failure(); 760 761 // Process the parsed records. 762 processTdIncludeRecords(tdRecords, decls); 763 764 // After we are done processing, move all of the tablegen source buffers to 765 // the main parser source mgr. This allows for directly using source locations 766 // from the .td files without needing to remap them. 767 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End); 768 return success(); 769 } 770 771 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 772 SmallVectorImpl<ast::Decl *> &decls) { 773 // Return the length kind of the given value. 774 auto getLengthKind = [](const auto &value) { 775 if (value.isOptional()) 776 return ods::VariableLengthKind::Optional; 777 return value.isVariadic() ? ods::VariableLengthKind::Variadic 778 : ods::VariableLengthKind::Single; 779 }; 780 781 // Insert a type constraint into the ODS context. 782 ods::Context &odsContext = ctx.getODSContext(); 783 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 784 -> const ods::TypeConstraint & { 785 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 786 cst.constraint.getSummary(), 787 cst.constraint.getCPPClassName()); 788 }; 789 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 790 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 791 }; 792 793 // Process the parsed tablegen records to build ODS information. 794 /// Operations. 795 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 796 tblgen::Operator op(def); 797 798 bool inserted = false; 799 ods::Operation *odsOp = nullptr; 800 std::tie(odsOp, inserted) = 801 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 802 op.getDescription(), op.getLoc().front()); 803 804 // Ignore operations that have already been added. 805 if (!inserted) 806 continue; 807 808 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 809 odsOp->appendAttribute( 810 attr.name, attr.attr.isOptional(), 811 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 812 attr.attr.getSummary(), 813 attr.attr.getStorageType())); 814 } 815 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 816 odsOp->appendOperand(operand.name, getLengthKind(operand), 817 addTypeConstraint(operand)); 818 } 819 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 820 odsOp->appendResult(result.name, getLengthKind(result), 821 addTypeConstraint(result)); 822 } 823 } 824 /// Attr constraints. 825 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 826 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 827 decls.push_back( 828 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 829 tblgen::AttrConstraint(def), 830 convertLocToRange(def->getLoc().front()), attrTy)); 831 } 832 } 833 /// Type constraints. 834 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 835 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 836 decls.push_back( 837 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 838 tblgen::TypeConstraint(def), 839 convertLocToRange(def->getLoc().front()), typeTy)); 840 } 841 } 842 /// Interfaces. 843 ast::Type opTy = ast::OperationType::get(ctx); 844 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 845 StringRef name = def->getName(); 846 if (def->isAnonymous() || curDeclScope->lookup(name) || 847 def->isSubClassOf("DeclareInterfaceMethods")) 848 continue; 849 SMRange loc = convertLocToRange(def->getLoc().front()); 850 851 StringRef className = def->getValueAsString("cppClassName"); 852 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 853 std::string codeBlock = 854 llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));", 855 cppNamespace, className) 856 .str(); 857 858 if (def->isSubClassOf("OpInterface")) { 859 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 860 name, codeBlock, loc, opTy)); 861 } else if (def->isSubClassOf("AttrInterface")) { 862 decls.push_back( 863 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 864 name, codeBlock, loc, attrTy)); 865 } else if (def->isSubClassOf("TypeInterface")) { 866 decls.push_back( 867 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 868 name, codeBlock, loc, typeTy)); 869 } 870 } 871 } 872 873 template <typename ConstraintT> 874 ast::Decl * 875 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 876 SMRange loc, ast::Type type) { 877 // Build the single input parameter. 878 ast::DeclScope *argScope = pushDeclScope(); 879 auto *paramVar = ast::VariableDecl::create( 880 ctx, ast::Name::create(ctx, "self", loc), type, 881 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 882 argScope->add(paramVar); 883 popDeclScope(); 884 885 // Build the native constraint. 886 auto *constraintDecl = ast::UserConstraintDecl::createNative( 887 ctx, ast::Name::create(ctx, name, loc), paramVar, 888 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 889 curDeclScope->add(constraintDecl); 890 return constraintDecl; 891 } 892 893 template <typename ConstraintT> 894 ast::Decl * 895 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 896 SMRange loc, ast::Type type) { 897 // Format the condition template. 898 tblgen::FmtContext fmtContext; 899 fmtContext.withSelf("self"); 900 std::string codeBlock = tblgen::tgfmt( 901 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 902 &fmtContext); 903 904 return createODSNativePDLLConstraintDecl<ConstraintT>( 905 constraint.getUniqueDefName(), codeBlock, loc, type); 906 } 907 908 //===----------------------------------------------------------------------===// 909 // Decls 910 911 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 912 FailureOr<ast::Decl *> decl; 913 switch (curToken.getKind()) { 914 case Token::kw_Constraint: 915 decl = parseUserConstraintDecl(); 916 break; 917 case Token::kw_Pattern: 918 decl = parsePatternDecl(); 919 break; 920 case Token::kw_Rewrite: 921 decl = parseUserRewriteDecl(); 922 break; 923 default: 924 return emitError("expected top-level declaration, such as a `Pattern`"); 925 } 926 if (failed(decl)) 927 return failure(); 928 929 // If the decl has a name, add it to the current scope. 930 if (const ast::Name *name = (*decl)->getName()) { 931 if (failed(checkDefineNamedDecl(*name))) 932 return failure(); 933 curDeclScope->add(*decl); 934 } 935 return decl; 936 } 937 938 FailureOr<ast::NamedAttributeDecl *> 939 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 940 // Check for name code completion. 941 if (curToken.is(Token::code_complete)) 942 return codeCompleteAttributeName(parentOpName); 943 944 std::string attrNameStr; 945 if (curToken.isString()) 946 attrNameStr = curToken.getStringValue(); 947 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 948 attrNameStr = curToken.getSpelling().str(); 949 else 950 return emitError("expected identifier or string attribute name"); 951 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 952 consumeToken(); 953 954 // Check for a value of the attribute. 955 ast::Expr *attrValue = nullptr; 956 if (consumeIf(Token::equal)) { 957 FailureOr<ast::Expr *> attrExpr = parseExpr(); 958 if (failed(attrExpr)) 959 return failure(); 960 attrValue = *attrExpr; 961 } else { 962 // If there isn't a concrete value, create an expression representing a 963 // UnitAttr. 964 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 965 } 966 967 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 968 } 969 970 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 971 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 972 bool expectTerminalSemicolon) { 973 consumeToken(Token::equal_arrow); 974 975 // Parse the single statement of the lambda body. 976 SMLoc bodyStartLoc = curToken.getStartLoc(); 977 pushDeclScope(); 978 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 979 bool failedToParse = 980 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 981 popDeclScope(); 982 if (failedToParse) 983 return failure(); 984 985 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 986 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 987 } 988 989 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 990 // Ensure that the argument is named. 991 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 992 return emitError("expected identifier argument name"); 993 994 // Parse the argument similarly to a normal variable. 995 StringRef name = curToken.getSpelling(); 996 SMRange nameLoc = curToken.getLoc(); 997 consumeToken(); 998 999 if (failed( 1000 parseToken(Token::colon, "expected `:` before argument constraint"))) 1001 return failure(); 1002 1003 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1004 if (failed(cst)) 1005 return failure(); 1006 1007 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1008 } 1009 1010 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1011 // Check to see if this result is named. 1012 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1013 // Check to see if this name actually refers to a Constraint. 1014 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1015 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1016 // If yes, and this is a Rewrite, give a nice error message as non-Core 1017 // constraints are not supported on Rewrite results. 1018 if (parserContext == ParserContext::Rewrite) { 1019 return emitError( 1020 "`Rewrite` results are only permitted to use core constraints, " 1021 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1022 } 1023 1024 // Otherwise, parse this as an unnamed result variable. 1025 } else { 1026 // If it wasn't a constraint, parse the result similarly to a variable. If 1027 // there is already an existing decl, we will emit an error when defining 1028 // this variable later. 1029 StringRef name = curToken.getSpelling(); 1030 SMRange nameLoc = curToken.getLoc(); 1031 consumeToken(); 1032 1033 if (failed(parseToken(Token::colon, 1034 "expected `:` before result constraint"))) 1035 return failure(); 1036 1037 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1038 if (failed(cst)) 1039 return failure(); 1040 1041 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1042 } 1043 } 1044 1045 // If it isn't named, we parse the constraint directly and create an unnamed 1046 // result variable. 1047 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1048 if (failed(cst)) 1049 return failure(); 1050 1051 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1052 } 1053 1054 FailureOr<ast::UserConstraintDecl *> 1055 Parser::parseUserConstraintDecl(bool isInline) { 1056 // Constraints and rewrites have very similar formats, dispatch to a shared 1057 // interface for parsing. 1058 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1059 [&](auto &&...args) { 1060 return this->parseUserPDLLConstraintDecl(args...); 1061 }, 1062 ParserContext::Constraint, "constraint", isInline); 1063 } 1064 1065 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1066 FailureOr<ast::UserConstraintDecl *> decl = 1067 parseUserConstraintDecl(/*isInline=*/true); 1068 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1069 return failure(); 1070 1071 curDeclScope->add(*decl); 1072 return decl; 1073 } 1074 1075 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1076 const ast::Name &name, bool isInline, 1077 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1078 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1079 // Push the argument scope back onto the list, so that the body can 1080 // reference arguments. 1081 pushDeclScope(argumentScope); 1082 1083 // Parse the body of the constraint. The body is either defined as a compound 1084 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1085 ast::CompoundStmt *body; 1086 if (curToken.is(Token::equal_arrow)) { 1087 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1088 [&](ast::Stmt *&stmt) -> LogicalResult { 1089 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1090 if (!stmtExpr) { 1091 return emitError(stmt->getLoc(), 1092 "expected `Constraint` lambda body to contain a " 1093 "single expression"); 1094 } 1095 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1096 return success(); 1097 }, 1098 /*expectTerminalSemicolon=*/!isInline); 1099 if (failed(bodyResult)) 1100 return failure(); 1101 body = *bodyResult; 1102 } else { 1103 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1104 if (failed(bodyResult)) 1105 return failure(); 1106 body = *bodyResult; 1107 1108 // Verify the structure of the body. 1109 auto bodyIt = body->begin(), bodyE = body->end(); 1110 for (; bodyIt != bodyE; ++bodyIt) 1111 if (isa<ast::ReturnStmt>(*bodyIt)) 1112 break; 1113 if (failed(validateUserConstraintOrRewriteReturn( 1114 "Constraint", body, bodyIt, bodyE, results, resultType))) 1115 return failure(); 1116 } 1117 popDeclScope(); 1118 1119 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1120 name, arguments, results, resultType, body); 1121 } 1122 1123 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1124 // Constraints and rewrites have very similar formats, dispatch to a shared 1125 // interface for parsing. 1126 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1127 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1128 ParserContext::Rewrite, "rewrite", isInline); 1129 } 1130 1131 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1132 FailureOr<ast::UserRewriteDecl *> decl = 1133 parseUserRewriteDecl(/*isInline=*/true); 1134 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1135 return failure(); 1136 1137 curDeclScope->add(*decl); 1138 return decl; 1139 } 1140 1141 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1142 const ast::Name &name, bool isInline, 1143 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1144 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1145 // Push the argument scope back onto the list, so that the body can 1146 // reference arguments. 1147 curDeclScope = argumentScope; 1148 ast::CompoundStmt *body; 1149 if (curToken.is(Token::equal_arrow)) { 1150 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1151 [&](ast::Stmt *&statement) -> LogicalResult { 1152 if (isa<ast::OpRewriteStmt>(statement)) 1153 return success(); 1154 1155 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1156 if (!statementExpr) { 1157 return emitError( 1158 statement->getLoc(), 1159 "expected `Rewrite` lambda body to contain a single expression " 1160 "or an operation rewrite statement; such as `erase`, " 1161 "`replace`, or `rewrite`"); 1162 } 1163 statement = 1164 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1165 return success(); 1166 }, 1167 /*expectTerminalSemicolon=*/!isInline); 1168 if (failed(bodyResult)) 1169 return failure(); 1170 body = *bodyResult; 1171 } else { 1172 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1173 if (failed(bodyResult)) 1174 return failure(); 1175 body = *bodyResult; 1176 } 1177 popDeclScope(); 1178 1179 // Verify the structure of the body. 1180 auto bodyIt = body->begin(), bodyE = body->end(); 1181 for (; bodyIt != bodyE; ++bodyIt) 1182 if (isa<ast::ReturnStmt>(*bodyIt)) 1183 break; 1184 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1185 bodyE, results, resultType))) 1186 return failure(); 1187 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1188 name, arguments, results, resultType, body); 1189 } 1190 1191 template <typename T, typename ParseUserPDLLDeclFnT> 1192 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1193 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1194 StringRef anonymousNamePrefix, bool isInline) { 1195 SMRange loc = curToken.getLoc(); 1196 consumeToken(); 1197 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1198 1199 // Parse the name of the decl. 1200 const ast::Name *name = nullptr; 1201 if (curToken.isNot(Token::identifier)) { 1202 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1203 // in C++, so being unnamed is fine. 1204 if (!isInline) 1205 return emitError("expected identifier name"); 1206 1207 // Create a unique anonymous name to use, as the name for this decl is not 1208 // important. 1209 std::string anonName = 1210 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1211 anonymousDeclNameCounter++) 1212 .str(); 1213 name = &ast::Name::create(ctx, anonName, loc); 1214 } else { 1215 // If a name was provided, we can use it directly. 1216 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1217 consumeToken(Token::identifier); 1218 } 1219 1220 // Parse the functional signature of the decl. 1221 SmallVector<ast::VariableDecl *> arguments, results; 1222 ast::DeclScope *argumentScope; 1223 ast::Type resultType; 1224 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1225 argumentScope, resultType))) 1226 return failure(); 1227 1228 // Check to see which type of constraint this is. If the constraint contains a 1229 // compound body, this is a PDLL decl. 1230 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1231 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1232 resultType); 1233 1234 // Otherwise, this is a native decl. 1235 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1236 results, resultType); 1237 } 1238 1239 template <typename T> 1240 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1241 const ast::Name &name, bool isInline, 1242 ArrayRef<ast::VariableDecl *> arguments, 1243 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1244 // If followed by a string, the native code body has also been specified. 1245 std::string codeStrStorage; 1246 Optional<StringRef> optCodeStr; 1247 if (curToken.isString()) { 1248 codeStrStorage = curToken.getStringValue(); 1249 optCodeStr = codeStrStorage; 1250 consumeToken(); 1251 } else if (isInline) { 1252 return emitError(name.getLoc(), 1253 "external declarations must be declared in global scope"); 1254 } else if (curToken.is(Token::error)) { 1255 return failure(); 1256 } 1257 if (failed(parseToken(Token::semicolon, 1258 "expected `;` after native declaration"))) 1259 return failure(); 1260 // TODO: PDL should be able to support constraint results in certain 1261 // situations, we should revise this. 1262 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1263 return emitError( 1264 "native Constraints currently do not support returning results"); 1265 } 1266 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1267 } 1268 1269 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1270 SmallVectorImpl<ast::VariableDecl *> &arguments, 1271 SmallVectorImpl<ast::VariableDecl *> &results, 1272 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1273 // Parse the argument list of the decl. 1274 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1275 return failure(); 1276 1277 argumentScope = pushDeclScope(); 1278 if (curToken.isNot(Token::r_paren)) { 1279 do { 1280 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1281 if (failed(argument)) 1282 return failure(); 1283 arguments.emplace_back(*argument); 1284 } while (consumeIf(Token::comma)); 1285 } 1286 popDeclScope(); 1287 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1288 return failure(); 1289 1290 // Parse the results of the decl. 1291 pushDeclScope(); 1292 if (consumeIf(Token::arrow)) { 1293 auto parseResultFn = [&]() -> LogicalResult { 1294 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1295 if (failed(result)) 1296 return failure(); 1297 results.emplace_back(*result); 1298 return success(); 1299 }; 1300 1301 // Check for a list of results. 1302 if (consumeIf(Token::l_paren)) { 1303 do { 1304 if (failed(parseResultFn())) 1305 return failure(); 1306 } while (consumeIf(Token::comma)); 1307 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1308 return failure(); 1309 1310 // Otherwise, there is only one result. 1311 } else if (failed(parseResultFn())) { 1312 return failure(); 1313 } 1314 } 1315 popDeclScope(); 1316 1317 // Compute the result type of the decl. 1318 resultType = createUserConstraintRewriteResultType(results); 1319 1320 // Verify that results are only named if there are more than one. 1321 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1322 return emitError( 1323 results.front()->getLoc(), 1324 "cannot create a single-element tuple with an element label"); 1325 } 1326 return success(); 1327 } 1328 1329 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1330 StringRef declType, ast::CompoundStmt *body, 1331 ArrayRef<ast::Stmt *>::iterator bodyIt, 1332 ArrayRef<ast::Stmt *>::iterator bodyE, 1333 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1334 // Handle if a `return` was provided. 1335 if (bodyIt != bodyE) { 1336 // Emit an error if we have trailing statements after the return. 1337 if (std::next(bodyIt) != bodyE) { 1338 return emitError( 1339 (*std::next(bodyIt))->getLoc(), 1340 llvm::formatv("`return` terminated the `{0}` body, but found " 1341 "trailing statements afterwards", 1342 declType)); 1343 } 1344 1345 // Otherwise if a return wasn't provided, check that no results are 1346 // expected. 1347 } else if (!results.empty()) { 1348 return emitError( 1349 {body->getLoc().End, body->getLoc().End}, 1350 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1351 declType, resultType)); 1352 } 1353 return success(); 1354 } 1355 1356 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1357 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1358 if (isa<ast::OpRewriteStmt>(statement)) 1359 return success(); 1360 return emitError( 1361 statement->getLoc(), 1362 "expected Pattern lambda body to contain a single operation " 1363 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1364 }); 1365 } 1366 1367 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1368 SMRange loc = curToken.getLoc(); 1369 consumeToken(Token::kw_Pattern); 1370 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1371 ParserContext::PatternMatch); 1372 1373 // Check for an optional identifier for the pattern name. 1374 const ast::Name *name = nullptr; 1375 if (curToken.is(Token::identifier)) { 1376 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1377 consumeToken(Token::identifier); 1378 } 1379 1380 // Parse any pattern metadata. 1381 ParsedPatternMetadata metadata; 1382 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1383 return failure(); 1384 1385 // Parse the pattern body. 1386 ast::CompoundStmt *body; 1387 1388 // Handle a lambda body. 1389 if (curToken.is(Token::equal_arrow)) { 1390 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1391 if (failed(bodyResult)) 1392 return failure(); 1393 body = *bodyResult; 1394 } else { 1395 if (curToken.isNot(Token::l_brace)) 1396 return emitError("expected `{` or `=>` to start pattern body"); 1397 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1398 if (failed(bodyResult)) 1399 return failure(); 1400 body = *bodyResult; 1401 1402 // Verify the body of the pattern. 1403 auto bodyIt = body->begin(), bodyE = body->end(); 1404 for (; bodyIt != bodyE; ++bodyIt) { 1405 if (isa<ast::ReturnStmt>(*bodyIt)) { 1406 return emitError((*bodyIt)->getLoc(), 1407 "`return` statements are only permitted within a " 1408 "`Constraint` or `Rewrite` body"); 1409 } 1410 // Break when we've found the rewrite statement. 1411 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1412 break; 1413 } 1414 if (bodyIt == bodyE) { 1415 return emitError(loc, 1416 "expected Pattern body to terminate with an operation " 1417 "rewrite statement, such as `erase`"); 1418 } 1419 if (std::next(bodyIt) != bodyE) { 1420 return emitError((*std::next(bodyIt))->getLoc(), 1421 "Pattern body was terminated by an operation " 1422 "rewrite statement, but found trailing statements"); 1423 } 1424 } 1425 1426 return createPatternDecl(loc, name, metadata, body); 1427 } 1428 1429 LogicalResult 1430 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1431 Optional<SMRange> benefitLoc; 1432 Optional<SMRange> hasBoundedRecursionLoc; 1433 1434 do { 1435 // Handle metadata code completion. 1436 if (curToken.is(Token::code_complete)) 1437 return codeCompletePatternMetadata(); 1438 1439 if (curToken.isNot(Token::identifier)) 1440 return emitError("expected pattern metadata identifier"); 1441 StringRef metadataStr = curToken.getSpelling(); 1442 SMRange metadataLoc = curToken.getLoc(); 1443 consumeToken(Token::identifier); 1444 1445 // Parse the benefit metadata: benefit(<integer-value>) 1446 if (metadataStr == "benefit") { 1447 if (benefitLoc) { 1448 return emitErrorAndNote(metadataLoc, 1449 "pattern benefit has already been specified", 1450 *benefitLoc, "see previous definition here"); 1451 } 1452 if (failed(parseToken(Token::l_paren, 1453 "expected `(` before pattern benefit"))) 1454 return failure(); 1455 1456 uint16_t benefitValue = 0; 1457 if (curToken.isNot(Token::integer)) 1458 return emitError("expected integral pattern benefit"); 1459 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1460 return emitError( 1461 "expected pattern benefit to fit within a 16-bit integer"); 1462 consumeToken(Token::integer); 1463 1464 metadata.benefit = benefitValue; 1465 benefitLoc = metadataLoc; 1466 1467 if (failed( 1468 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1469 return failure(); 1470 continue; 1471 } 1472 1473 // Parse the bounded recursion metadata: recursion 1474 if (metadataStr == "recursion") { 1475 if (hasBoundedRecursionLoc) { 1476 return emitErrorAndNote( 1477 metadataLoc, 1478 "pattern recursion metadata has already been specified", 1479 *hasBoundedRecursionLoc, "see previous definition here"); 1480 } 1481 metadata.hasBoundedRecursion = true; 1482 hasBoundedRecursionLoc = metadataLoc; 1483 continue; 1484 } 1485 1486 return emitError(metadataLoc, "unknown pattern metadata"); 1487 } while (consumeIf(Token::comma)); 1488 1489 return success(); 1490 } 1491 1492 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1493 consumeToken(Token::less); 1494 1495 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1496 if (failed(typeExpr) || 1497 failed(parseToken(Token::greater, 1498 "expected `>` after variable type constraint"))) 1499 return failure(); 1500 return typeExpr; 1501 } 1502 1503 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1504 assert(curDeclScope && "defining decl outside of a decl scope"); 1505 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1506 return emitErrorAndNote( 1507 name.getLoc(), "`" + name.getName() + "` has already been defined", 1508 lastDecl->getName()->getLoc(), "see previous definition here"); 1509 } 1510 return success(); 1511 } 1512 1513 FailureOr<ast::VariableDecl *> 1514 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1515 ast::Expr *initExpr, 1516 ArrayRef<ast::ConstraintRef> constraints) { 1517 assert(curDeclScope && "defining variable outside of decl scope"); 1518 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1519 1520 // If the name of the variable indicates a special variable, we don't add it 1521 // to the scope. This variable is local to the definition point. 1522 if (name.empty() || name == "_") { 1523 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1524 constraints); 1525 } 1526 if (failed(checkDefineNamedDecl(nameDecl))) 1527 return failure(); 1528 1529 auto *varDecl = 1530 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1531 curDeclScope->add(varDecl); 1532 return varDecl; 1533 } 1534 1535 FailureOr<ast::VariableDecl *> 1536 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1537 ArrayRef<ast::ConstraintRef> constraints) { 1538 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1539 constraints); 1540 } 1541 1542 LogicalResult Parser::parseVariableDeclConstraintList( 1543 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1544 Optional<SMRange> typeConstraint; 1545 auto parseSingleConstraint = [&] { 1546 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1547 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1548 /*allowNonCoreConstraints=*/true); 1549 if (failed(constraint)) 1550 return failure(); 1551 constraints.push_back(*constraint); 1552 return success(); 1553 }; 1554 1555 // Check to see if this is a single constraint, or a list. 1556 if (!consumeIf(Token::l_square)) 1557 return parseSingleConstraint(); 1558 1559 do { 1560 if (failed(parseSingleConstraint())) 1561 return failure(); 1562 } while (consumeIf(Token::comma)); 1563 return parseToken(Token::r_square, "expected `]` after constraint list"); 1564 } 1565 1566 FailureOr<ast::ConstraintRef> 1567 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1568 ArrayRef<ast::ConstraintRef> existingConstraints, 1569 bool allowInlineTypeConstraints, 1570 bool allowNonCoreConstraints) { 1571 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1572 if (!allowInlineTypeConstraints) { 1573 return emitError( 1574 curToken.getLoc(), 1575 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1576 "permitted on arguments or results"); 1577 } 1578 if (typeConstraint) 1579 return emitErrorAndNote( 1580 curToken.getLoc(), 1581 "the type of this variable has already been constrained", 1582 *typeConstraint, "see previous constraint location here"); 1583 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1584 if (failed(constraintExpr)) 1585 return failure(); 1586 typeExpr = *constraintExpr; 1587 typeConstraint = typeExpr->getLoc(); 1588 return success(); 1589 }; 1590 1591 SMRange loc = curToken.getLoc(); 1592 switch (curToken.getKind()) { 1593 case Token::kw_Attr: { 1594 consumeToken(Token::kw_Attr); 1595 1596 // Check for a type constraint. 1597 ast::Expr *typeExpr = nullptr; 1598 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1599 return failure(); 1600 return ast::ConstraintRef( 1601 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1602 } 1603 case Token::kw_Op: { 1604 consumeToken(Token::kw_Op); 1605 1606 // Parse an optional operation name. If the name isn't provided, this refers 1607 // to "any" operation. 1608 FailureOr<ast::OpNameDecl *> opName = 1609 parseWrappedOperationName(/*allowEmptyName=*/true); 1610 if (failed(opName)) 1611 return failure(); 1612 1613 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1614 loc); 1615 } 1616 case Token::kw_Type: 1617 consumeToken(Token::kw_Type); 1618 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1619 case Token::kw_TypeRange: 1620 consumeToken(Token::kw_TypeRange); 1621 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1622 loc); 1623 case Token::kw_Value: { 1624 consumeToken(Token::kw_Value); 1625 1626 // Check for a type constraint. 1627 ast::Expr *typeExpr = nullptr; 1628 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1629 return failure(); 1630 1631 return ast::ConstraintRef( 1632 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1633 } 1634 case Token::kw_ValueRange: { 1635 consumeToken(Token::kw_ValueRange); 1636 1637 // Check for a type constraint. 1638 ast::Expr *typeExpr = nullptr; 1639 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1640 return failure(); 1641 1642 return ast::ConstraintRef( 1643 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1644 } 1645 1646 case Token::kw_Constraint: { 1647 // Handle an inline constraint. 1648 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1649 if (failed(decl)) 1650 return failure(); 1651 return ast::ConstraintRef(*decl, loc); 1652 } 1653 case Token::identifier: { 1654 StringRef constraintName = curToken.getSpelling(); 1655 consumeToken(Token::identifier); 1656 1657 // Lookup the referenced constraint. 1658 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1659 if (!cstDecl) { 1660 return emitError(loc, "unknown reference to constraint `" + 1661 constraintName + "`"); 1662 } 1663 1664 // Handle a reference to a proper constraint. 1665 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1666 return ast::ConstraintRef(cst, loc); 1667 1668 return emitErrorAndNote( 1669 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1670 "see the definition of `" + constraintName + "` here"); 1671 } 1672 // Handle single entity constraint code completion. 1673 case Token::code_complete: { 1674 // Try to infer the current type for use by code completion. 1675 ast::Type inferredType; 1676 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1677 allowNonCoreConstraints))) 1678 return failure(); 1679 1680 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1681 allowInlineTypeConstraints); 1682 } 1683 default: 1684 break; 1685 } 1686 return emitError(loc, "expected identifier constraint"); 1687 } 1688 1689 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1690 // Constraint arguments may apply more complex constraints via the arguments. 1691 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1692 1693 Optional<SMRange> typeConstraint; 1694 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1695 /*allowInlineTypeConstraints=*/false, 1696 allowNonCoreConstraints); 1697 } 1698 1699 //===----------------------------------------------------------------------===// 1700 // Exprs 1701 1702 FailureOr<ast::Expr *> Parser::parseExpr() { 1703 if (curToken.is(Token::underscore)) 1704 return parseUnderscoreExpr(); 1705 1706 // Parse the LHS expression. 1707 FailureOr<ast::Expr *> lhsExpr; 1708 switch (curToken.getKind()) { 1709 case Token::kw_attr: 1710 lhsExpr = parseAttributeExpr(); 1711 break; 1712 case Token::kw_Constraint: 1713 lhsExpr = parseInlineConstraintLambdaExpr(); 1714 break; 1715 case Token::identifier: 1716 lhsExpr = parseIdentifierExpr(); 1717 break; 1718 case Token::kw_op: 1719 lhsExpr = parseOperationExpr(); 1720 break; 1721 case Token::kw_Rewrite: 1722 lhsExpr = parseInlineRewriteLambdaExpr(); 1723 break; 1724 case Token::kw_type: 1725 lhsExpr = parseTypeExpr(); 1726 break; 1727 case Token::l_paren: 1728 lhsExpr = parseTupleExpr(); 1729 break; 1730 default: 1731 return emitError("expected expression"); 1732 } 1733 if (failed(lhsExpr)) 1734 return failure(); 1735 1736 // Check for an operator expression. 1737 while (true) { 1738 switch (curToken.getKind()) { 1739 case Token::dot: 1740 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1741 break; 1742 case Token::l_paren: 1743 lhsExpr = parseCallExpr(*lhsExpr); 1744 break; 1745 default: 1746 return lhsExpr; 1747 } 1748 if (failed(lhsExpr)) 1749 return failure(); 1750 } 1751 } 1752 1753 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1754 SMRange loc = curToken.getLoc(); 1755 consumeToken(Token::kw_attr); 1756 1757 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1758 // identifier. 1759 if (!consumeIf(Token::less)) { 1760 resetToken(loc); 1761 return parseIdentifierExpr(); 1762 } 1763 1764 if (!curToken.isString()) 1765 return emitError("expected string literal containing MLIR attribute"); 1766 std::string attrExpr = curToken.getStringValue(); 1767 consumeToken(); 1768 1769 if (failed( 1770 parseToken(Token::greater, "expected `>` after attribute literal"))) 1771 return failure(); 1772 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1773 } 1774 1775 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1776 SMRange loc = curToken.getLoc(); 1777 consumeToken(Token::l_paren); 1778 1779 // Parse the arguments of the call. 1780 SmallVector<ast::Expr *> arguments; 1781 if (curToken.isNot(Token::r_paren)) { 1782 do { 1783 // Handle code completion for the call arguments. 1784 if (curToken.is(Token::code_complete)) { 1785 codeCompleteCallSignature(parentExpr, arguments.size()); 1786 return failure(); 1787 } 1788 1789 FailureOr<ast::Expr *> argument = parseExpr(); 1790 if (failed(argument)) 1791 return failure(); 1792 arguments.push_back(*argument); 1793 } while (consumeIf(Token::comma)); 1794 } 1795 loc.End = curToken.getEndLoc(); 1796 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1797 return failure(); 1798 1799 return createCallExpr(loc, parentExpr, arguments); 1800 } 1801 1802 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1803 ast::Decl *decl = curDeclScope->lookup(name); 1804 if (!decl) 1805 return emitError(loc, "undefined reference to `" + name + "`"); 1806 1807 return createDeclRefExpr(loc, decl); 1808 } 1809 1810 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1811 StringRef name = curToken.getSpelling(); 1812 SMRange nameLoc = curToken.getLoc(); 1813 consumeToken(); 1814 1815 // Check to see if this is a decl ref expression that defines a variable 1816 // inline. 1817 if (consumeIf(Token::colon)) { 1818 SmallVector<ast::ConstraintRef> constraints; 1819 if (failed(parseVariableDeclConstraintList(constraints))) 1820 return failure(); 1821 ast::Type type; 1822 if (failed(validateVariableConstraints(constraints, type))) 1823 return failure(); 1824 return createInlineVariableExpr(type, name, nameLoc, constraints); 1825 } 1826 1827 return parseDeclRefExpr(name, nameLoc); 1828 } 1829 1830 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1831 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1832 if (failed(decl)) 1833 return failure(); 1834 1835 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1836 ast::ConstraintType::get(ctx)); 1837 } 1838 1839 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1840 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1841 if (failed(decl)) 1842 return failure(); 1843 1844 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1845 ast::RewriteType::get(ctx)); 1846 } 1847 1848 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1849 SMRange loc = curToken.getLoc(); 1850 consumeToken(Token::dot); 1851 1852 // Check for code completion of the member name. 1853 if (curToken.is(Token::code_complete)) 1854 return codeCompleteMemberAccess(parentExpr); 1855 1856 // Parse the member name. 1857 Token memberNameTok = curToken; 1858 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1859 !memberNameTok.isKeyword()) 1860 return emitError(loc, "expected identifier or numeric member name"); 1861 StringRef memberName = memberNameTok.getSpelling(); 1862 consumeToken(); 1863 1864 return createMemberAccessExpr(parentExpr, memberName, loc); 1865 } 1866 1867 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1868 SMRange loc = curToken.getLoc(); 1869 1870 // Check for code completion for the dialect name. 1871 if (curToken.is(Token::code_complete)) 1872 return codeCompleteDialectName(); 1873 1874 // Handle the case of an no operation name. 1875 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1876 if (allowEmptyName) 1877 return ast::OpNameDecl::create(ctx, SMRange()); 1878 return emitError("expected dialect namespace"); 1879 } 1880 StringRef name = curToken.getSpelling(); 1881 consumeToken(); 1882 1883 // Otherwise, this is a literal operation name. 1884 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1885 return failure(); 1886 1887 // Check for code completion for the operation name. 1888 if (curToken.is(Token::code_complete)) 1889 return codeCompleteOperationName(name); 1890 1891 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1892 return emitError("expected operation name after dialect namespace"); 1893 1894 name = StringRef(name.data(), name.size() + 1); 1895 do { 1896 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1897 loc.End = curToken.getEndLoc(); 1898 consumeToken(); 1899 } while (curToken.isAny(Token::identifier, Token::dot) || 1900 curToken.isKeyword()); 1901 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1902 } 1903 1904 FailureOr<ast::OpNameDecl *> 1905 Parser::parseWrappedOperationName(bool allowEmptyName) { 1906 if (!consumeIf(Token::less)) 1907 return ast::OpNameDecl::create(ctx, SMRange()); 1908 1909 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1910 if (failed(opNameDecl)) 1911 return failure(); 1912 1913 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1914 return failure(); 1915 return opNameDecl; 1916 } 1917 1918 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1919 SMRange loc = curToken.getLoc(); 1920 consumeToken(Token::kw_op); 1921 1922 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1923 // identifier. 1924 if (curToken.isNot(Token::less)) { 1925 resetToken(loc); 1926 return parseIdentifierExpr(); 1927 } 1928 1929 // Parse the operation name. The name may be elided, in which case the 1930 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1931 // `Operation*`). Operation names within a rewrite context must be named. 1932 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1933 FailureOr<ast::OpNameDecl *> opNameDecl = 1934 parseWrappedOperationName(allowEmptyName); 1935 if (failed(opNameDecl)) 1936 return failure(); 1937 Optional<StringRef> opName = (*opNameDecl)->getName(); 1938 1939 // Functor used to create an implicit range variable, used for implicit "all" 1940 // operand or results variables. 1941 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1942 FailureOr<ast::VariableDecl *> rangeVar = 1943 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1944 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1945 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1946 }; 1947 1948 // Check for the optional list of operands. 1949 SmallVector<ast::Expr *> operands; 1950 if (!consumeIf(Token::l_paren)) { 1951 // If the operand list isn't specified and we are in a match context, define 1952 // an inplace unconstrained operand range corresponding to all of the 1953 // operands of the operation. This avoids treating zero operands the same 1954 // way as "unconstrained operands". 1955 if (parserContext != ParserContext::Rewrite) { 1956 operands.push_back(createImplicitRangeVar( 1957 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1958 } 1959 } else if (!consumeIf(Token::r_paren)) { 1960 // Check for operand signature code completion. 1961 if (curToken.is(Token::code_complete)) { 1962 codeCompleteOperationOperandsSignature(opName, operands.size()); 1963 return failure(); 1964 } 1965 1966 // If the operand list was specified and non-empty, parse the operands. 1967 do { 1968 FailureOr<ast::Expr *> operand = parseExpr(); 1969 if (failed(operand)) 1970 return failure(); 1971 operands.push_back(*operand); 1972 } while (consumeIf(Token::comma)); 1973 1974 if (failed(parseToken(Token::r_paren, 1975 "expected `)` after operation operand list"))) 1976 return failure(); 1977 } 1978 1979 // Check for the optional list of attributes. 1980 SmallVector<ast::NamedAttributeDecl *> attributes; 1981 if (consumeIf(Token::l_brace)) { 1982 do { 1983 FailureOr<ast::NamedAttributeDecl *> decl = 1984 parseNamedAttributeDecl(opName); 1985 if (failed(decl)) 1986 return failure(); 1987 attributes.emplace_back(*decl); 1988 } while (consumeIf(Token::comma)); 1989 1990 if (failed(parseToken(Token::r_brace, 1991 "expected `}` after operation attribute list"))) 1992 return failure(); 1993 } 1994 1995 // Check for the optional list of result types. 1996 SmallVector<ast::Expr *> resultTypes; 1997 if (consumeIf(Token::arrow)) { 1998 if (failed(parseToken(Token::l_paren, 1999 "expected `(` before operation result type list"))) 2000 return failure(); 2001 2002 // Handle the case of an empty result list. 2003 if (!consumeIf(Token::r_paren)) { 2004 do { 2005 // Check for result signature code completion. 2006 if (curToken.is(Token::code_complete)) { 2007 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2008 return failure(); 2009 } 2010 2011 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2012 if (failed(resultTypeExpr)) 2013 return failure(); 2014 resultTypes.push_back(*resultTypeExpr); 2015 } while (consumeIf(Token::comma)); 2016 2017 if (failed(parseToken(Token::r_paren, 2018 "expected `)` after operation result type list"))) 2019 return failure(); 2020 } 2021 } else if (parserContext != ParserContext::Rewrite) { 2022 // If the result list isn't specified and we are in a match context, define 2023 // an inplace unconstrained result range corresponding to all of the results 2024 // of the operation. This avoids treating zero results the same way as 2025 // "unconstrained results". 2026 resultTypes.push_back(createImplicitRangeVar( 2027 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2028 } 2029 2030 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2031 resultTypes); 2032 } 2033 2034 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2035 SMRange loc = curToken.getLoc(); 2036 consumeToken(Token::l_paren); 2037 2038 DenseMap<StringRef, SMRange> usedNames; 2039 SmallVector<StringRef> elementNames; 2040 SmallVector<ast::Expr *> elements; 2041 if (curToken.isNot(Token::r_paren)) { 2042 do { 2043 // Check for the optional element name assignment before the value. 2044 StringRef elementName; 2045 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2046 Token elementNameTok = curToken; 2047 consumeToken(); 2048 2049 // The element name is only present if followed by an `=`. 2050 if (consumeIf(Token::equal)) { 2051 elementName = elementNameTok.getSpelling(); 2052 2053 // Check to see if this name is already used. 2054 auto elementNameIt = 2055 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2056 if (!elementNameIt.second) { 2057 return emitErrorAndNote( 2058 elementNameTok.getLoc(), 2059 llvm::formatv("duplicate tuple element label `{0}`", 2060 elementName), 2061 elementNameIt.first->getSecond(), 2062 "see previous label use here"); 2063 } 2064 } else { 2065 // Otherwise, we treat this as part of an expression so reset the 2066 // lexer. 2067 resetToken(elementNameTok.getLoc()); 2068 } 2069 } 2070 elementNames.push_back(elementName); 2071 2072 // Parse the tuple element value. 2073 FailureOr<ast::Expr *> element = parseExpr(); 2074 if (failed(element)) 2075 return failure(); 2076 elements.push_back(*element); 2077 } while (consumeIf(Token::comma)); 2078 } 2079 loc.End = curToken.getEndLoc(); 2080 if (failed( 2081 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2082 return failure(); 2083 return createTupleExpr(loc, elements, elementNames); 2084 } 2085 2086 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2087 SMRange loc = curToken.getLoc(); 2088 consumeToken(Token::kw_type); 2089 2090 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2091 // identifier. 2092 if (!consumeIf(Token::less)) { 2093 resetToken(loc); 2094 return parseIdentifierExpr(); 2095 } 2096 2097 if (!curToken.isString()) 2098 return emitError("expected string literal containing MLIR type"); 2099 std::string attrExpr = curToken.getStringValue(); 2100 consumeToken(); 2101 2102 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2103 return failure(); 2104 return ast::TypeExpr::create(ctx, loc, attrExpr); 2105 } 2106 2107 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2108 StringRef name = curToken.getSpelling(); 2109 SMRange nameLoc = curToken.getLoc(); 2110 consumeToken(Token::underscore); 2111 2112 // Underscore expressions require a constraint list. 2113 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2114 return failure(); 2115 2116 // Parse the constraints for the expression. 2117 SmallVector<ast::ConstraintRef> constraints; 2118 if (failed(parseVariableDeclConstraintList(constraints))) 2119 return failure(); 2120 2121 ast::Type type; 2122 if (failed(validateVariableConstraints(constraints, type))) 2123 return failure(); 2124 return createInlineVariableExpr(type, name, nameLoc, constraints); 2125 } 2126 2127 //===----------------------------------------------------------------------===// 2128 // Stmts 2129 2130 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2131 FailureOr<ast::Stmt *> stmt; 2132 switch (curToken.getKind()) { 2133 case Token::kw_erase: 2134 stmt = parseEraseStmt(); 2135 break; 2136 case Token::kw_let: 2137 stmt = parseLetStmt(); 2138 break; 2139 case Token::kw_replace: 2140 stmt = parseReplaceStmt(); 2141 break; 2142 case Token::kw_return: 2143 stmt = parseReturnStmt(); 2144 break; 2145 case Token::kw_rewrite: 2146 stmt = parseRewriteStmt(); 2147 break; 2148 default: 2149 stmt = parseExpr(); 2150 break; 2151 } 2152 if (failed(stmt) || 2153 (expectTerminalSemicolon && 2154 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2155 return failure(); 2156 return stmt; 2157 } 2158 2159 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2160 SMLoc startLoc = curToken.getStartLoc(); 2161 consumeToken(Token::l_brace); 2162 2163 // Push a new block scope and parse any nested statements. 2164 pushDeclScope(); 2165 SmallVector<ast::Stmt *> statements; 2166 while (curToken.isNot(Token::r_brace)) { 2167 FailureOr<ast::Stmt *> statement = parseStmt(); 2168 if (failed(statement)) 2169 return popDeclScope(), failure(); 2170 statements.push_back(*statement); 2171 } 2172 popDeclScope(); 2173 2174 // Consume the end brace. 2175 SMRange location(startLoc, curToken.getEndLoc()); 2176 consumeToken(Token::r_brace); 2177 2178 return ast::CompoundStmt::create(ctx, location, statements); 2179 } 2180 2181 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2182 if (parserContext == ParserContext::Constraint) 2183 return emitError("`erase` cannot be used within a Constraint"); 2184 SMRange loc = curToken.getLoc(); 2185 consumeToken(Token::kw_erase); 2186 2187 // Parse the root operation expression. 2188 FailureOr<ast::Expr *> rootOp = parseExpr(); 2189 if (failed(rootOp)) 2190 return failure(); 2191 2192 return createEraseStmt(loc, *rootOp); 2193 } 2194 2195 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2196 SMRange loc = curToken.getLoc(); 2197 consumeToken(Token::kw_let); 2198 2199 // Parse the name of the new variable. 2200 SMRange varLoc = curToken.getLoc(); 2201 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2202 // `_` is a reserved variable name. 2203 if (curToken.is(Token::underscore)) { 2204 return emitError(varLoc, 2205 "`_` may only be used to define \"inline\" variables"); 2206 } 2207 return emitError(varLoc, 2208 "expected identifier after `let` to name a new variable"); 2209 } 2210 StringRef varName = curToken.getSpelling(); 2211 consumeToken(); 2212 2213 // Parse the optional set of constraints. 2214 SmallVector<ast::ConstraintRef> constraints; 2215 if (consumeIf(Token::colon) && 2216 failed(parseVariableDeclConstraintList(constraints))) 2217 return failure(); 2218 2219 // Parse the optional initializer expression. 2220 ast::Expr *initializer = nullptr; 2221 if (consumeIf(Token::equal)) { 2222 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2223 if (failed(initOrFailure)) 2224 return failure(); 2225 initializer = *initOrFailure; 2226 2227 // Check that the constraints are compatible with having an initializer, 2228 // e.g. type constraints cannot be used with initializers. 2229 for (ast::ConstraintRef constraint : constraints) { 2230 LogicalResult result = 2231 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2232 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2233 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2234 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2235 return this->emitError( 2236 constraint.referenceLoc, 2237 "type constraints are not permitted on variables with " 2238 "initializers"); 2239 } 2240 return success(); 2241 }) 2242 .Default(success()); 2243 if (failed(result)) 2244 return failure(); 2245 } 2246 } 2247 2248 FailureOr<ast::VariableDecl *> varDecl = 2249 createVariableDecl(varName, varLoc, initializer, constraints); 2250 if (failed(varDecl)) 2251 return failure(); 2252 return ast::LetStmt::create(ctx, loc, *varDecl); 2253 } 2254 2255 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2256 if (parserContext == ParserContext::Constraint) 2257 return emitError("`replace` cannot be used within a Constraint"); 2258 SMRange loc = curToken.getLoc(); 2259 consumeToken(Token::kw_replace); 2260 2261 // Parse the root operation expression. 2262 FailureOr<ast::Expr *> rootOp = parseExpr(); 2263 if (failed(rootOp)) 2264 return failure(); 2265 2266 if (failed( 2267 parseToken(Token::kw_with, "expected `with` after root operation"))) 2268 return failure(); 2269 2270 // The replacement portion of this statement is within a rewrite context. 2271 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2272 ParserContext::Rewrite); 2273 2274 // Parse the replacement values. 2275 SmallVector<ast::Expr *> replValues; 2276 if (consumeIf(Token::l_paren)) { 2277 if (consumeIf(Token::r_paren)) { 2278 return emitError( 2279 loc, "expected at least one replacement value, consider using " 2280 "`erase` if no replacement values are desired"); 2281 } 2282 2283 do { 2284 FailureOr<ast::Expr *> replExpr = parseExpr(); 2285 if (failed(replExpr)) 2286 return failure(); 2287 replValues.emplace_back(*replExpr); 2288 } while (consumeIf(Token::comma)); 2289 2290 if (failed(parseToken(Token::r_paren, 2291 "expected `)` after replacement values"))) 2292 return failure(); 2293 } else { 2294 FailureOr<ast::Expr *> replExpr = parseExpr(); 2295 if (failed(replExpr)) 2296 return failure(); 2297 replValues.emplace_back(*replExpr); 2298 } 2299 2300 return createReplaceStmt(loc, *rootOp, replValues); 2301 } 2302 2303 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2304 SMRange loc = curToken.getLoc(); 2305 consumeToken(Token::kw_return); 2306 2307 // Parse the result value. 2308 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2309 if (failed(resultExpr)) 2310 return failure(); 2311 2312 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2313 } 2314 2315 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2316 if (parserContext == ParserContext::Constraint) 2317 return emitError("`rewrite` cannot be used within a Constraint"); 2318 SMRange loc = curToken.getLoc(); 2319 consumeToken(Token::kw_rewrite); 2320 2321 // Parse the root operation. 2322 FailureOr<ast::Expr *> rootOp = parseExpr(); 2323 if (failed(rootOp)) 2324 return failure(); 2325 2326 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2327 return failure(); 2328 2329 if (curToken.isNot(Token::l_brace)) 2330 return emitError("expected `{` to start rewrite body"); 2331 2332 // The rewrite body of this statement is within a rewrite context. 2333 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2334 ParserContext::Rewrite); 2335 2336 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2337 if (failed(rewriteBody)) 2338 return failure(); 2339 2340 // Verify the rewrite body. 2341 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2342 if (isa<ast::ReturnStmt>(stmt)) { 2343 return emitError(stmt->getLoc(), 2344 "`return` statements are only permitted within a " 2345 "`Constraint` or `Rewrite` body"); 2346 } 2347 } 2348 2349 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2350 } 2351 2352 //===----------------------------------------------------------------------===// 2353 // Creation+Analysis 2354 //===----------------------------------------------------------------------===// 2355 2356 //===----------------------------------------------------------------------===// 2357 // Decls 2358 2359 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2360 // Unwrap reference expressions. 2361 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2362 node = init->getDecl(); 2363 return dyn_cast<ast::CallableDecl>(node); 2364 } 2365 2366 FailureOr<ast::PatternDecl *> 2367 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2368 const ParsedPatternMetadata &metadata, 2369 ast::CompoundStmt *body) { 2370 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2371 metadata.hasBoundedRecursion, body); 2372 } 2373 2374 ast::Type Parser::createUserConstraintRewriteResultType( 2375 ArrayRef<ast::VariableDecl *> results) { 2376 // Single result decls use the type of the single result. 2377 if (results.size() == 1) 2378 return results[0]->getType(); 2379 2380 // Multiple results use a tuple type, with the types and names grabbed from 2381 // the result variable decls. 2382 auto resultTypes = llvm::map_range( 2383 results, [&](const auto *result) { return result->getType(); }); 2384 auto resultNames = llvm::map_range( 2385 results, [&](const auto *result) { return result->getName().getName(); }); 2386 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2387 llvm::to_vector(resultNames)); 2388 } 2389 2390 template <typename T> 2391 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2392 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2393 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2394 ast::CompoundStmt *body) { 2395 if (!body->getChildren().empty()) { 2396 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2397 ast::Expr *resultExpr = retStmt->getResultExpr(); 2398 2399 // Process the result of the decl. If no explicit signature results 2400 // were provided, check for return type inference. Otherwise, check that 2401 // the return expression can be converted to the expected type. 2402 if (results.empty()) 2403 resultType = resultExpr->getType(); 2404 else if (failed(convertExpressionTo(resultExpr, resultType))) 2405 return failure(); 2406 else 2407 retStmt->setResultExpr(resultExpr); 2408 } 2409 } 2410 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2411 } 2412 2413 FailureOr<ast::VariableDecl *> 2414 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2415 ArrayRef<ast::ConstraintRef> constraints) { 2416 // The type of the variable, which is expected to be inferred by either a 2417 // constraint or an initializer expression. 2418 ast::Type type; 2419 if (failed(validateVariableConstraints(constraints, type))) 2420 return failure(); 2421 2422 if (initializer) { 2423 // Update the variable type based on the initializer, or try to convert the 2424 // initializer to the existing type. 2425 if (!type) 2426 type = initializer->getType(); 2427 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2428 type = mergedType; 2429 else if (failed(convertExpressionTo(initializer, type))) 2430 return failure(); 2431 2432 // Otherwise, if there is no initializer check that the type has already 2433 // been resolved from the constraint list. 2434 } else if (!type) { 2435 return emitErrorAndNote( 2436 loc, "unable to infer type for variable `" + name + "`", loc, 2437 "the type of a variable must be inferable from the constraint " 2438 "list or the initializer"); 2439 } 2440 2441 // Constraint types cannot be used when defining variables. 2442 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2443 return emitError( 2444 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2445 } 2446 2447 // Try to define a variable with the given name. 2448 FailureOr<ast::VariableDecl *> varDecl = 2449 defineVariableDecl(name, loc, type, initializer, constraints); 2450 if (failed(varDecl)) 2451 return failure(); 2452 2453 return *varDecl; 2454 } 2455 2456 FailureOr<ast::VariableDecl *> 2457 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2458 const ast::ConstraintRef &constraint) { 2459 // Constraint arguments may apply more complex constraints via the arguments. 2460 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2461 ast::Type argType; 2462 if (failed(validateVariableConstraint(constraint, argType, 2463 allowNonCoreConstraints))) 2464 return failure(); 2465 return defineVariableDecl(name, loc, argType, constraint); 2466 } 2467 2468 LogicalResult 2469 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2470 ast::Type &inferredType, 2471 bool allowNonCoreConstraints) { 2472 for (const ast::ConstraintRef &ref : constraints) 2473 if (failed(validateVariableConstraint(ref, inferredType, 2474 allowNonCoreConstraints))) 2475 return failure(); 2476 return success(); 2477 } 2478 2479 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2480 ast::Type &inferredType, 2481 bool allowNonCoreConstraints) { 2482 ast::Type constraintType; 2483 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2484 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2485 if (failed(validateTypeConstraintExpr(typeExpr))) 2486 return failure(); 2487 } 2488 constraintType = ast::AttributeType::get(ctx); 2489 } else if (const auto *cst = 2490 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2491 constraintType = ast::OperationType::get(ctx, cst->getName()); 2492 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2493 constraintType = typeTy; 2494 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2495 constraintType = typeRangeTy; 2496 } else if (const auto *cst = 2497 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2498 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2499 if (failed(validateTypeConstraintExpr(typeExpr))) 2500 return failure(); 2501 } 2502 constraintType = valueTy; 2503 } else if (const auto *cst = 2504 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2505 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2506 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2507 return failure(); 2508 } 2509 constraintType = valueRangeTy; 2510 } else if (const auto *cst = 2511 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2512 if (!allowNonCoreConstraints) { 2513 return emitError(ref.referenceLoc, 2514 "`Rewrite` arguments and results are only permitted to " 2515 "use core constraints, such as `Attr`, `Op`, `Type`, " 2516 "`TypeRange`, `Value`, `ValueRange`"); 2517 } 2518 2519 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2520 if (inputs.size() != 1) { 2521 return emitErrorAndNote(ref.referenceLoc, 2522 "`Constraint`s applied via a variable constraint " 2523 "list must take a single input, but got " + 2524 Twine(inputs.size()), 2525 cst->getLoc(), 2526 "see definition of constraint here"); 2527 } 2528 constraintType = inputs.front()->getType(); 2529 } else { 2530 llvm_unreachable("unknown constraint type"); 2531 } 2532 2533 // Check that the constraint type is compatible with the current inferred 2534 // type. 2535 if (!inferredType) { 2536 inferredType = constraintType; 2537 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2538 inferredType = mergedTy; 2539 } else { 2540 return emitError(ref.referenceLoc, 2541 llvm::formatv("constraint type `{0}` is incompatible " 2542 "with the previously inferred type `{1}`", 2543 constraintType, inferredType)); 2544 } 2545 return success(); 2546 } 2547 2548 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2549 ast::Type typeExprType = typeExpr->getType(); 2550 if (typeExprType != typeTy) { 2551 return emitError(typeExpr->getLoc(), 2552 "expected expression of `Type` in type constraint"); 2553 } 2554 return success(); 2555 } 2556 2557 LogicalResult 2558 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2559 ast::Type typeExprType = typeExpr->getType(); 2560 if (typeExprType != typeRangeTy) { 2561 return emitError(typeExpr->getLoc(), 2562 "expected expression of `TypeRange` in type constraint"); 2563 } 2564 return success(); 2565 } 2566 2567 //===----------------------------------------------------------------------===// 2568 // Exprs 2569 2570 FailureOr<ast::CallExpr *> 2571 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2572 MutableArrayRef<ast::Expr *> arguments) { 2573 ast::Type parentType = parentExpr->getType(); 2574 2575 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2576 if (!callableDecl) { 2577 return emitError(loc, 2578 llvm::formatv("expected a reference to a callable " 2579 "`Constraint` or `Rewrite`, but got: `{0}`", 2580 parentType)); 2581 } 2582 if (parserContext == ParserContext::Rewrite) { 2583 if (isa<ast::UserConstraintDecl>(callableDecl)) 2584 return emitError( 2585 loc, "unable to invoke `Constraint` within a rewrite section"); 2586 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2587 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2588 } 2589 2590 // Verify the arguments of the call. 2591 /// Handle size mismatch. 2592 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2593 if (callArgs.size() != arguments.size()) { 2594 return emitErrorAndNote( 2595 loc, 2596 llvm::formatv("invalid number of arguments for {0} call; expected " 2597 "{1}, but got {2}", 2598 callableDecl->getCallableType(), callArgs.size(), 2599 arguments.size()), 2600 callableDecl->getLoc(), 2601 llvm::formatv("see the definition of {0} here", 2602 callableDecl->getName()->getName())); 2603 } 2604 2605 /// Handle argument type mismatch. 2606 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2607 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2608 callableDecl->getName()->getName()), 2609 callableDecl->getLoc()); 2610 }; 2611 for (auto it : llvm::zip(callArgs, arguments)) { 2612 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2613 attachDiagFn))) 2614 return failure(); 2615 } 2616 2617 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2618 callableDecl->getResultType()); 2619 } 2620 2621 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2622 ast::Decl *decl) { 2623 // Check the type of decl being referenced. 2624 ast::Type declType; 2625 if (isa<ast::ConstraintDecl>(decl)) 2626 declType = ast::ConstraintType::get(ctx); 2627 else if (isa<ast::UserRewriteDecl>(decl)) 2628 declType = ast::RewriteType::get(ctx); 2629 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2630 declType = varDecl->getType(); 2631 else 2632 return emitError(loc, "invalid reference to `" + 2633 decl->getName()->getName() + "`"); 2634 2635 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2636 } 2637 2638 FailureOr<ast::DeclRefExpr *> 2639 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2640 ArrayRef<ast::ConstraintRef> constraints) { 2641 FailureOr<ast::VariableDecl *> decl = 2642 defineVariableDecl(name, loc, type, constraints); 2643 if (failed(decl)) 2644 return failure(); 2645 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2646 } 2647 2648 FailureOr<ast::MemberAccessExpr *> 2649 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2650 SMRange loc) { 2651 // Validate the member name for the given parent expression. 2652 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2653 if (failed(memberType)) 2654 return failure(); 2655 2656 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2657 } 2658 2659 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2660 StringRef name, SMRange loc) { 2661 ast::Type parentType = parentExpr->getType(); 2662 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2663 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2664 return valueRangeTy; 2665 2666 // Verify member access based on the operation type. 2667 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2668 auto results = odsOp->getResults(); 2669 2670 // Handle indexed results. 2671 unsigned index = 0; 2672 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2673 index < results.size()) { 2674 return results[index].isVariadic() ? valueRangeTy : valueTy; 2675 } 2676 2677 // Handle named results. 2678 const auto *it = llvm::find_if(results, [&](const auto &result) { 2679 return result.getName() == name; 2680 }); 2681 if (it != results.end()) 2682 return it->isVariadic() ? valueRangeTy : valueTy; 2683 } 2684 2685 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2686 // Handle indexed results. 2687 unsigned index = 0; 2688 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2689 index < tupleType.size()) { 2690 return tupleType.getElementTypes()[index]; 2691 } 2692 2693 // Handle named results. 2694 auto elementNames = tupleType.getElementNames(); 2695 const auto *it = llvm::find(elementNames, name); 2696 if (it != elementNames.end()) 2697 return tupleType.getElementTypes()[it - elementNames.begin()]; 2698 } 2699 return emitError( 2700 loc, 2701 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2702 name, parentType)); 2703 } 2704 2705 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2706 SMRange loc, const ast::OpNameDecl *name, 2707 MutableArrayRef<ast::Expr *> operands, 2708 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2709 MutableArrayRef<ast::Expr *> results) { 2710 Optional<StringRef> opNameRef = name->getName(); 2711 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2712 2713 // Verify the inputs operands. 2714 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2715 return failure(); 2716 2717 // Verify the attribute list. 2718 for (ast::NamedAttributeDecl *attr : attributes) { 2719 // Check for an attribute type, or a type awaiting resolution. 2720 ast::Type attrType = attr->getValue()->getType(); 2721 if (!attrType.isa<ast::AttributeType>()) { 2722 return emitError( 2723 attr->getValue()->getLoc(), 2724 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2725 } 2726 } 2727 2728 // Verify the result types. 2729 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2730 return failure(); 2731 2732 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2733 attributes); 2734 } 2735 2736 LogicalResult 2737 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2738 const ods::Operation *odsOp, 2739 MutableArrayRef<ast::Expr *> operands) { 2740 return validateOperationOperandsOrResults( 2741 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2742 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2743 valueRangeTy); 2744 } 2745 2746 LogicalResult 2747 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2748 const ods::Operation *odsOp, 2749 MutableArrayRef<ast::Expr *> results) { 2750 return validateOperationOperandsOrResults( 2751 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2752 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2753 } 2754 2755 LogicalResult Parser::validateOperationOperandsOrResults( 2756 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2757 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2758 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2759 ast::Type rangeTy) { 2760 // All operation types accept a single range parameter. 2761 if (values.size() == 1) { 2762 if (failed(convertExpressionTo(values[0], rangeTy))) 2763 return failure(); 2764 return success(); 2765 } 2766 2767 /// If the operation has ODS information, we can more accurately verify the 2768 /// values. 2769 if (odsOpLoc) { 2770 if (odsValues.size() != values.size()) { 2771 return emitErrorAndNote( 2772 loc, 2773 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2774 "{2}, but got {3}", 2775 groupName, *name, odsValues.size(), values.size()), 2776 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2777 } 2778 auto diagFn = [&](ast::Diagnostic &diag) { 2779 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2780 *odsOpLoc); 2781 }; 2782 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2783 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2784 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2785 return failure(); 2786 } 2787 return success(); 2788 } 2789 2790 // Otherwise, accept the value groups as they have been defined and just 2791 // ensure they are one of the expected types. 2792 for (ast::Expr *&valueExpr : values) { 2793 ast::Type valueExprType = valueExpr->getType(); 2794 2795 // Check if this is one of the expected types. 2796 if (valueExprType == rangeTy || valueExprType == singleTy) 2797 continue; 2798 2799 // If the operand is an Operation, allow converting to a Value or 2800 // ValueRange. This situations arises quite often with nested operation 2801 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2802 if (singleTy == valueTy) { 2803 if (valueExprType.isa<ast::OperationType>()) { 2804 valueExpr = convertOpToValue(valueExpr); 2805 continue; 2806 } 2807 } 2808 2809 return emitError( 2810 valueExpr->getLoc(), 2811 llvm::formatv( 2812 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2813 singleTy, rangeTy, valueExprType)); 2814 } 2815 return success(); 2816 } 2817 2818 FailureOr<ast::TupleExpr *> 2819 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2820 ArrayRef<StringRef> elementNames) { 2821 for (const ast::Expr *element : elements) { 2822 ast::Type eleTy = element->getType(); 2823 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2824 return emitError( 2825 element->getLoc(), 2826 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2827 } 2828 } 2829 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2830 } 2831 2832 //===----------------------------------------------------------------------===// 2833 // Stmts 2834 2835 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2836 ast::Expr *rootOp) { 2837 // Check that root is an Operation. 2838 ast::Type rootType = rootOp->getType(); 2839 if (!rootType.isa<ast::OperationType>()) 2840 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2841 2842 return ast::EraseStmt::create(ctx, loc, rootOp); 2843 } 2844 2845 FailureOr<ast::ReplaceStmt *> 2846 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2847 MutableArrayRef<ast::Expr *> replValues) { 2848 // Check that root is an Operation. 2849 ast::Type rootType = rootOp->getType(); 2850 if (!rootType.isa<ast::OperationType>()) { 2851 return emitError( 2852 rootOp->getLoc(), 2853 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2854 } 2855 2856 // If there are multiple replacement values, we implicitly convert any Op 2857 // expressions to the value form. 2858 bool shouldConvertOpToValues = replValues.size() > 1; 2859 for (ast::Expr *&replExpr : replValues) { 2860 ast::Type replType = replExpr->getType(); 2861 2862 // Check that replExpr is an Operation, Value, or ValueRange. 2863 if (replType.isa<ast::OperationType>()) { 2864 if (shouldConvertOpToValues) 2865 replExpr = convertOpToValue(replExpr); 2866 continue; 2867 } 2868 2869 if (replType != valueTy && replType != valueRangeTy) { 2870 return emitError(replExpr->getLoc(), 2871 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2872 "expression, but got `{0}`", 2873 replType)); 2874 } 2875 } 2876 2877 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2878 } 2879 2880 FailureOr<ast::RewriteStmt *> 2881 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2882 ast::CompoundStmt *rewriteBody) { 2883 // Check that root is an Operation. 2884 ast::Type rootType = rootOp->getType(); 2885 if (!rootType.isa<ast::OperationType>()) { 2886 return emitError( 2887 rootOp->getLoc(), 2888 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2889 } 2890 2891 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2892 } 2893 2894 //===----------------------------------------------------------------------===// 2895 // Code Completion 2896 //===----------------------------------------------------------------------===// 2897 2898 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2899 ast::Type parentType = parentExpr->getType(); 2900 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2901 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2902 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2903 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2904 return failure(); 2905 } 2906 2907 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2908 if (opName) 2909 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2910 return failure(); 2911 } 2912 2913 LogicalResult 2914 Parser::codeCompleteConstraintName(ast::Type inferredType, 2915 bool allowNonCoreConstraints, 2916 bool allowInlineTypeConstraints) { 2917 codeCompleteContext->codeCompleteConstraintName( 2918 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2919 curDeclScope); 2920 return failure(); 2921 } 2922 2923 LogicalResult Parser::codeCompleteDialectName() { 2924 codeCompleteContext->codeCompleteDialectName(); 2925 return failure(); 2926 } 2927 2928 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2929 codeCompleteContext->codeCompleteOperationName(dialectName); 2930 return failure(); 2931 } 2932 2933 LogicalResult Parser::codeCompletePatternMetadata() { 2934 codeCompleteContext->codeCompletePatternMetadata(); 2935 return failure(); 2936 } 2937 2938 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 2939 codeCompleteContext->codeCompleteIncludeFilename(curPath); 2940 return failure(); 2941 } 2942 2943 void Parser::codeCompleteCallSignature(ast::Node *parent, 2944 unsigned currentNumArgs) { 2945 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2946 if (!callableDecl) 2947 return; 2948 2949 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2950 } 2951 2952 void Parser::codeCompleteOperationOperandsSignature( 2953 Optional<StringRef> opName, unsigned currentNumOperands) { 2954 codeCompleteContext->codeCompleteOperationOperandsSignature( 2955 opName, currentNumOperands); 2956 } 2957 2958 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2959 unsigned currentNumResults) { 2960 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2961 currentNumResults); 2962 } 2963 2964 //===----------------------------------------------------------------------===// 2965 // Parser 2966 //===----------------------------------------------------------------------===// 2967 2968 FailureOr<ast::Module *> 2969 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2970 CodeCompleteContext *codeCompleteContext) { 2971 Parser parser(ctx, sourceMgr, codeCompleteContext); 2972 return parser.parseModule(); 2973 } 2974