1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. `typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. `inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 428 429 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 430 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 431 unsigned currentNumOperands); 432 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 433 unsigned currentNumResults); 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 /// If the current token has the specified kind, consume it and return true. 440 /// If not, return false. 441 bool consumeIf(Token::Kind kind) { 442 if (curToken.isNot(kind)) 443 return false; 444 consumeToken(kind); 445 return true; 446 } 447 448 /// Advance the current lexer onto the next token. 449 void consumeToken() { 450 assert(curToken.isNot(Token::eof, Token::error) && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 /// Advance the current lexer onto the next token, asserting what the expected 456 /// current token is. This is preferred to the above method because it leads 457 /// to more self-documenting code with better checking. 458 void consumeToken(Token::Kind kind) { 459 assert(curToken.is(kind) && "consumed an unexpected token"); 460 consumeToken(); 461 } 462 463 /// Reset the lexer to the location at the given position. 464 void resetToken(SMRange tokLoc) { 465 lexer.resetPointer(tokLoc.Start.getPointer()); 466 curToken = lexer.lexToken(); 467 } 468 469 /// Consume the specified token if present and return success. On failure, 470 /// output a diagnostic and return failure. 471 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 472 if (curToken.getKind() != kind) 473 return emitError(curToken.getLoc(), msg); 474 consumeToken(); 475 return success(); 476 } 477 LogicalResult emitError(SMRange loc, const Twine &msg) { 478 lexer.emitError(loc, msg); 479 return failure(); 480 } 481 LogicalResult emitError(const Twine &msg) { 482 return emitError(curToken.getLoc(), msg); 483 } 484 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 485 const Twine ¬e) { 486 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 487 return failure(); 488 } 489 490 //===--------------------------------------------------------------------===// 491 // Fields 492 //===--------------------------------------------------------------------===// 493 494 /// The owning AST context. 495 ast::Context &ctx; 496 497 /// The lexer of this parser. 498 Lexer lexer; 499 500 /// The current token within the lexer. 501 Token curToken; 502 503 /// The most recently defined decl scope. 504 ast::DeclScope *curDeclScope = nullptr; 505 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 506 507 /// The current context of the parser. 508 ParserContext parserContext = ParserContext::Global; 509 510 /// Cached types to simplify verification and expression creation. 511 ast::Type valueTy, valueRangeTy; 512 ast::Type typeTy, typeRangeTy; 513 ast::Type attrTy; 514 515 /// A counter used when naming anonymous constraints and rewrites. 516 unsigned anonymousDeclNameCounter = 0; 517 518 /// The optional code completion context. 519 CodeCompleteContext *codeCompleteContext; 520 }; 521 } // namespace 522 523 FailureOr<ast::Module *> Parser::parseModule() { 524 SMLoc moduleLoc = curToken.getStartLoc(); 525 pushDeclScope(); 526 527 // Parse the top-level decls of the module. 528 SmallVector<ast::Decl *> decls; 529 if (failed(parseModuleBody(decls))) 530 return popDeclScope(), failure(); 531 532 popDeclScope(); 533 return ast::Module::create(ctx, moduleLoc, decls); 534 } 535 536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 537 while (curToken.isNot(Token::eof)) { 538 if (curToken.is(Token::directive)) { 539 if (failed(parseDirective(decls))) 540 return failure(); 541 continue; 542 } 543 544 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 545 if (failed(decl)) 546 return failure(); 547 decls.push_back(*decl); 548 } 549 return success(); 550 } 551 552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 553 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 554 valueRangeTy); 555 } 556 557 LogicalResult Parser::convertExpressionTo( 558 ast::Expr *&expr, ast::Type type, 559 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 560 ast::Type exprType = expr->getType(); 561 if (exprType == type) 562 return success(); 563 564 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 565 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 566 expr->getLoc(), llvm::formatv("unable to convert expression of type " 567 "`{0}` to the expected type of " 568 "`{1}`", 569 exprType, type)); 570 if (noteAttachFn) 571 noteAttachFn(*diag); 572 return diag; 573 }; 574 575 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 576 // Two operation types are compatible if they have the same name, or if the 577 // expected type is more general. 578 if (auto opType = type.dyn_cast<ast::OperationType>()) { 579 if (opType.getName()) 580 return emitConvertError(); 581 return success(); 582 } 583 584 // An operation can always convert to a ValueRange. 585 if (type == valueRangeTy) { 586 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 587 valueRangeTy); 588 return success(); 589 } 590 591 // Allow conversion to a single value by constraining the result range. 592 if (type == valueTy) { 593 // If the operation is registered, we can verify if it can ever have a 594 // single result. 595 Optional<StringRef> opName = exprOpType.getName(); 596 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 597 if (odsOp->getResults().empty()) { 598 return emitConvertError()->attachNote( 599 llvm::formatv("see the definition of `{0}`, which was defined " 600 "with zero results", 601 odsOp->getName()), 602 odsOp->getLoc()); 603 } 604 605 unsigned numSingleResults = llvm::count_if( 606 odsOp->getResults(), [](const ods::OperandOrResult &result) { 607 return result.getVariableLengthKind() == 608 ods::VariableLengthKind::Single; 609 }); 610 if (numSingleResults > 1) { 611 return emitConvertError()->attachNote( 612 llvm::formatv("see the definition of `{0}`, which was defined " 613 "with at least {1} results", 614 odsOp->getName(), numSingleResults), 615 odsOp->getLoc()); 616 } 617 } 618 619 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 620 valueTy); 621 return success(); 622 } 623 return emitConvertError(); 624 } 625 626 // FIXME: Decide how to allow/support converting a single result to multiple, 627 // and multiple to a single result. For now, we just allow Single->Range, 628 // but this isn't something really supported in the PDL dialect. We should 629 // figure out some way to support both. 630 if ((exprType == valueTy || exprType == valueRangeTy) && 631 (type == valueTy || type == valueRangeTy)) 632 return success(); 633 if ((exprType == typeTy || exprType == typeRangeTy) && 634 (type == typeTy || type == typeRangeTy)) 635 return success(); 636 637 // Handle tuple types. 638 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 639 auto tupleType = type.dyn_cast<ast::TupleType>(); 640 if (!tupleType || tupleType.size() != exprTupleType.size()) 641 return emitConvertError(); 642 643 // Build a new tuple expression using each of the elements of the current 644 // tuple. 645 SmallVector<ast::Expr *> newExprs; 646 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 647 newExprs.push_back(ast::MemberAccessExpr::create( 648 ctx, expr->getLoc(), expr, llvm::to_string(i), 649 exprTupleType.getElementTypes()[i])); 650 651 auto diagFn = [&](ast::Diagnostic &diag) { 652 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 653 i, exprTupleType)); 654 if (noteAttachFn) 655 noteAttachFn(diag); 656 }; 657 if (failed(convertExpressionTo(newExprs.back(), 658 tupleType.getElementTypes()[i], diagFn))) 659 return failure(); 660 } 661 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 662 tupleType.getElementNames()); 663 return success(); 664 } 665 666 return emitConvertError(); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Directives 671 672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 673 StringRef directive = curToken.getSpelling(); 674 if (directive == "#include") 675 return parseInclude(decls); 676 677 return emitError("unknown directive `" + directive + "`"); 678 } 679 680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 681 SMRange loc = curToken.getLoc(); 682 consumeToken(Token::directive); 683 684 // Handle code completion of the include file path. 685 if (curToken.is(Token::code_complete_string)) 686 return codeCompleteIncludeFilename(curToken.getStringValue()); 687 688 // Parse the file being included. 689 if (!curToken.isString()) 690 return emitError(loc, 691 "expected string file name after `include` directive"); 692 SMRange fileLoc = curToken.getLoc(); 693 std::string filenameStr = curToken.getStringValue(); 694 StringRef filename = filenameStr; 695 consumeToken(); 696 697 // Check the type of include. If ending with `.pdll`, this is another pdl file 698 // to be parsed along with the current module. 699 if (filename.endswith(".pdll")) { 700 if (failed(lexer.pushInclude(filename, fileLoc))) 701 return emitError(fileLoc, 702 "unable to open include file `" + filename + "`"); 703 704 // If we added the include successfully, parse it into the current module. 705 // Make sure to update to the next token after we finish parsing the nested 706 // file. 707 curToken = lexer.lexToken(); 708 LogicalResult result = parseModuleBody(decls); 709 curToken = lexer.lexToken(); 710 return result; 711 } 712 713 // Otherwise, this must be a `.td` include. 714 if (filename.endswith(".td")) 715 return parseTdInclude(filename, fileLoc, decls); 716 717 return emitError(fileLoc, 718 "expected include filename to end with `.pdll` or `.td`"); 719 } 720 721 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 722 SmallVectorImpl<ast::Decl *> &decls) { 723 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 724 725 // Use the source manager to open the file, but don't yet add it. 726 std::string includedFile; 727 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 728 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 729 if (!includeBuffer) 730 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 731 732 // Setup the source manager for parsing the tablegen file. 733 llvm::SourceMgr tdSrcMgr; 734 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); 735 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); 736 737 // This class provides a context argument for the llvm::SourceMgr diagnostic 738 // handler. 739 struct DiagHandlerContext { 740 Parser &parser; 741 StringRef filename; 742 llvm::SMRange loc; 743 } handlerContext{*this, filename, fileLoc}; 744 745 // Set the diagnostic handler for the tablegen source manager. 746 tdSrcMgr.setDiagHandler( 747 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 748 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 749 (void)ctx->parser.emitError( 750 ctx->loc, 751 llvm::formatv("error while processing include file `{0}`: {1}", 752 ctx->filename, diag.getMessage())); 753 }, 754 &handlerContext); 755 756 // Parse the tablegen file. 757 llvm::RecordKeeper tdRecords; 758 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords)) 759 return failure(); 760 761 // Process the parsed records. 762 processTdIncludeRecords(tdRecords, decls); 763 764 // After we are done processing, move all of the tablegen source buffers to 765 // the main parser source mgr. This allows for directly using source locations 766 // from the .td files without needing to remap them. 767 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End); 768 return success(); 769 } 770 771 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 772 SmallVectorImpl<ast::Decl *> &decls) { 773 // Return the length kind of the given value. 774 auto getLengthKind = [](const auto &value) { 775 if (value.isOptional()) 776 return ods::VariableLengthKind::Optional; 777 return value.isVariadic() ? ods::VariableLengthKind::Variadic 778 : ods::VariableLengthKind::Single; 779 }; 780 781 // Insert a type constraint into the ODS context. 782 ods::Context &odsContext = ctx.getODSContext(); 783 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 784 -> const ods::TypeConstraint & { 785 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 786 cst.constraint.getSummary(), 787 cst.constraint.getCPPClassName()); 788 }; 789 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 790 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 791 }; 792 793 // Process the parsed tablegen records to build ODS information. 794 /// Operations. 795 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 796 tblgen::Operator op(def); 797 798 bool inserted = false; 799 ods::Operation *odsOp = nullptr; 800 std::tie(odsOp, inserted) = 801 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 802 op.getDescription(), op.getLoc().front()); 803 804 // Ignore operations that have already been added. 805 if (!inserted) 806 continue; 807 808 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 809 odsOp->appendAttribute( 810 attr.name, attr.attr.isOptional(), 811 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 812 attr.attr.getSummary(), 813 attr.attr.getStorageType())); 814 } 815 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 816 odsOp->appendOperand(operand.name, getLengthKind(operand), 817 addTypeConstraint(operand)); 818 } 819 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 820 odsOp->appendResult(result.name, getLengthKind(result), 821 addTypeConstraint(result)); 822 } 823 } 824 /// Attr constraints. 825 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 826 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 827 decls.push_back( 828 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 829 tblgen::AttrConstraint(def), 830 convertLocToRange(def->getLoc().front()), attrTy)); 831 } 832 } 833 /// Type constraints. 834 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 835 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 836 decls.push_back( 837 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 838 tblgen::TypeConstraint(def), 839 convertLocToRange(def->getLoc().front()), typeTy)); 840 } 841 } 842 /// Interfaces. 843 ast::Type opTy = ast::OperationType::get(ctx); 844 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 845 StringRef name = def->getName(); 846 if (def->isAnonymous() || curDeclScope->lookup(name) || 847 def->isSubClassOf("DeclareInterfaceMethods")) 848 continue; 849 SMRange loc = convertLocToRange(def->getLoc().front()); 850 851 StringRef className = def->getValueAsString("cppClassName"); 852 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 853 std::string codeBlock = 854 llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));", 855 cppNamespace, className) 856 .str(); 857 858 if (def->isSubClassOf("OpInterface")) { 859 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 860 name, codeBlock, loc, opTy)); 861 } else if (def->isSubClassOf("AttrInterface")) { 862 decls.push_back( 863 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 864 name, codeBlock, loc, attrTy)); 865 } else if (def->isSubClassOf("TypeInterface")) { 866 decls.push_back( 867 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 868 name, codeBlock, loc, typeTy)); 869 } 870 } 871 } 872 873 template <typename ConstraintT> 874 ast::Decl * 875 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 876 SMRange loc, ast::Type type) { 877 // Build the single input parameter. 878 ast::DeclScope *argScope = pushDeclScope(); 879 auto *paramVar = ast::VariableDecl::create( 880 ctx, ast::Name::create(ctx, "self", loc), type, 881 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 882 argScope->add(paramVar); 883 popDeclScope(); 884 885 // Build the native constraint. 886 auto *constraintDecl = ast::UserConstraintDecl::createNative( 887 ctx, ast::Name::create(ctx, name, loc), paramVar, 888 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 889 curDeclScope->add(constraintDecl); 890 return constraintDecl; 891 } 892 893 template <typename ConstraintT> 894 ast::Decl * 895 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 896 SMRange loc, ast::Type type) { 897 // Format the condition template. 898 tblgen::FmtContext fmtContext; 899 fmtContext.withSelf("self"); 900 std::string codeBlock = tblgen::tgfmt( 901 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 902 &fmtContext); 903 904 return createODSNativePDLLConstraintDecl<ConstraintT>( 905 constraint.getUniqueDefName(), codeBlock, loc, type); 906 } 907 908 //===----------------------------------------------------------------------===// 909 // Decls 910 911 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 912 FailureOr<ast::Decl *> decl; 913 switch (curToken.getKind()) { 914 case Token::kw_Constraint: 915 decl = parseUserConstraintDecl(); 916 break; 917 case Token::kw_Pattern: 918 decl = parsePatternDecl(); 919 break; 920 case Token::kw_Rewrite: 921 decl = parseUserRewriteDecl(); 922 break; 923 default: 924 return emitError("expected top-level declaration, such as a `Pattern`"); 925 } 926 if (failed(decl)) 927 return failure(); 928 929 // If the decl has a name, add it to the current scope. 930 if (const ast::Name *name = (*decl)->getName()) { 931 if (failed(checkDefineNamedDecl(*name))) 932 return failure(); 933 curDeclScope->add(*decl); 934 } 935 return decl; 936 } 937 938 FailureOr<ast::NamedAttributeDecl *> 939 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 940 // Check for name code completion. 941 if (curToken.is(Token::code_complete)) 942 return codeCompleteAttributeName(parentOpName); 943 944 std::string attrNameStr; 945 if (curToken.isString()) 946 attrNameStr = curToken.getStringValue(); 947 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 948 attrNameStr = curToken.getSpelling().str(); 949 else 950 return emitError("expected identifier or string attribute name"); 951 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 952 consumeToken(); 953 954 // Check for a value of the attribute. 955 ast::Expr *attrValue = nullptr; 956 if (consumeIf(Token::equal)) { 957 FailureOr<ast::Expr *> attrExpr = parseExpr(); 958 if (failed(attrExpr)) 959 return failure(); 960 attrValue = *attrExpr; 961 } else { 962 // If there isn't a concrete value, create an expression representing a 963 // UnitAttr. 964 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 965 } 966 967 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 968 } 969 970 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 971 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 972 bool expectTerminalSemicolon) { 973 consumeToken(Token::equal_arrow); 974 975 // Parse the single statement of the lambda body. 976 SMLoc bodyStartLoc = curToken.getStartLoc(); 977 pushDeclScope(); 978 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 979 bool failedToParse = 980 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 981 popDeclScope(); 982 if (failedToParse) 983 return failure(); 984 985 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 986 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 987 } 988 989 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 990 // Ensure that the argument is named. 991 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 992 return emitError("expected identifier argument name"); 993 994 // Parse the argument similarly to a normal variable. 995 StringRef name = curToken.getSpelling(); 996 SMRange nameLoc = curToken.getLoc(); 997 consumeToken(); 998 999 if (failed( 1000 parseToken(Token::colon, "expected `:` before argument constraint"))) 1001 return failure(); 1002 1003 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1004 if (failed(cst)) 1005 return failure(); 1006 1007 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1008 } 1009 1010 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1011 // Check to see if this result is named. 1012 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1013 // Check to see if this name actually refers to a Constraint. 1014 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1015 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1016 // If yes, and this is a Rewrite, give a nice error message as non-Core 1017 // constraints are not supported on Rewrite results. 1018 if (parserContext == ParserContext::Rewrite) { 1019 return emitError( 1020 "`Rewrite` results are only permitted to use core constraints, " 1021 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1022 } 1023 1024 // Otherwise, parse this as an unnamed result variable. 1025 } else { 1026 // If it wasn't a constraint, parse the result similarly to a variable. If 1027 // there is already an existing decl, we will emit an error when defining 1028 // this variable later. 1029 StringRef name = curToken.getSpelling(); 1030 SMRange nameLoc = curToken.getLoc(); 1031 consumeToken(); 1032 1033 if (failed(parseToken(Token::colon, 1034 "expected `:` before result constraint"))) 1035 return failure(); 1036 1037 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1038 if (failed(cst)) 1039 return failure(); 1040 1041 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1042 } 1043 } 1044 1045 // If it isn't named, we parse the constraint directly and create an unnamed 1046 // result variable. 1047 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1048 if (failed(cst)) 1049 return failure(); 1050 1051 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1052 } 1053 1054 FailureOr<ast::UserConstraintDecl *> 1055 Parser::parseUserConstraintDecl(bool isInline) { 1056 // Constraints and rewrites have very similar formats, dispatch to a shared 1057 // interface for parsing. 1058 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1059 [&](auto &&...args) { 1060 return this->parseUserPDLLConstraintDecl(args...); 1061 }, 1062 ParserContext::Constraint, "constraint", isInline); 1063 } 1064 1065 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1066 FailureOr<ast::UserConstraintDecl *> decl = 1067 parseUserConstraintDecl(/*isInline=*/true); 1068 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1069 return failure(); 1070 1071 curDeclScope->add(*decl); 1072 return decl; 1073 } 1074 1075 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1076 const ast::Name &name, bool isInline, 1077 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1078 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1079 // Push the argument scope back onto the list, so that the body can 1080 // reference arguments. 1081 pushDeclScope(argumentScope); 1082 1083 // Parse the body of the constraint. The body is either defined as a compound 1084 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1085 ast::CompoundStmt *body; 1086 if (curToken.is(Token::equal_arrow)) { 1087 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1088 [&](ast::Stmt *&stmt) -> LogicalResult { 1089 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1090 if (!stmtExpr) { 1091 return emitError(stmt->getLoc(), 1092 "expected `Constraint` lambda body to contain a " 1093 "single expression"); 1094 } 1095 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1096 return success(); 1097 }, 1098 /*expectTerminalSemicolon=*/!isInline); 1099 if (failed(bodyResult)) 1100 return failure(); 1101 body = *bodyResult; 1102 } else { 1103 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1104 if (failed(bodyResult)) 1105 return failure(); 1106 body = *bodyResult; 1107 1108 // Verify the structure of the body. 1109 auto bodyIt = body->begin(), bodyE = body->end(); 1110 for (; bodyIt != bodyE; ++bodyIt) 1111 if (isa<ast::ReturnStmt>(*bodyIt)) 1112 break; 1113 if (failed(validateUserConstraintOrRewriteReturn( 1114 "Constraint", body, bodyIt, bodyE, results, resultType))) 1115 return failure(); 1116 } 1117 popDeclScope(); 1118 1119 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1120 name, arguments, results, resultType, body); 1121 } 1122 1123 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1124 // Constraints and rewrites have very similar formats, dispatch to a shared 1125 // interface for parsing. 1126 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1127 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1128 ParserContext::Rewrite, "rewrite", isInline); 1129 } 1130 1131 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1132 FailureOr<ast::UserRewriteDecl *> decl = 1133 parseUserRewriteDecl(/*isInline=*/true); 1134 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1135 return failure(); 1136 1137 curDeclScope->add(*decl); 1138 return decl; 1139 } 1140 1141 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1142 const ast::Name &name, bool isInline, 1143 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1144 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1145 // Push the argument scope back onto the list, so that the body can 1146 // reference arguments. 1147 curDeclScope = argumentScope; 1148 ast::CompoundStmt *body; 1149 if (curToken.is(Token::equal_arrow)) { 1150 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1151 [&](ast::Stmt *&statement) -> LogicalResult { 1152 if (isa<ast::OpRewriteStmt>(statement)) 1153 return success(); 1154 1155 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1156 if (!statementExpr) { 1157 return emitError( 1158 statement->getLoc(), 1159 "expected `Rewrite` lambda body to contain a single expression " 1160 "or an operation rewrite statement; such as `erase`, " 1161 "`replace`, or `rewrite`"); 1162 } 1163 statement = 1164 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1165 return success(); 1166 }, 1167 /*expectTerminalSemicolon=*/!isInline); 1168 if (failed(bodyResult)) 1169 return failure(); 1170 body = *bodyResult; 1171 } else { 1172 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1173 if (failed(bodyResult)) 1174 return failure(); 1175 body = *bodyResult; 1176 } 1177 popDeclScope(); 1178 1179 // Verify the structure of the body. 1180 auto bodyIt = body->begin(), bodyE = body->end(); 1181 for (; bodyIt != bodyE; ++bodyIt) 1182 if (isa<ast::ReturnStmt>(*bodyIt)) 1183 break; 1184 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1185 bodyE, results, resultType))) 1186 return failure(); 1187 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1188 name, arguments, results, resultType, body); 1189 } 1190 1191 template <typename T, typename ParseUserPDLLDeclFnT> 1192 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1193 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1194 StringRef anonymousNamePrefix, bool isInline) { 1195 SMRange loc = curToken.getLoc(); 1196 consumeToken(); 1197 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1198 1199 // Parse the name of the decl. 1200 const ast::Name *name = nullptr; 1201 if (curToken.isNot(Token::identifier)) { 1202 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1203 // in C++, so being unnamed is fine. 1204 if (!isInline) 1205 return emitError("expected identifier name"); 1206 1207 // Create a unique anonymous name to use, as the name for this decl is not 1208 // important. 1209 std::string anonName = 1210 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1211 anonymousDeclNameCounter++) 1212 .str(); 1213 name = &ast::Name::create(ctx, anonName, loc); 1214 } else { 1215 // If a name was provided, we can use it directly. 1216 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1217 consumeToken(Token::identifier); 1218 } 1219 1220 // Parse the functional signature of the decl. 1221 SmallVector<ast::VariableDecl *> arguments, results; 1222 ast::DeclScope *argumentScope; 1223 ast::Type resultType; 1224 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1225 argumentScope, resultType))) 1226 return failure(); 1227 1228 // Check to see which type of constraint this is. If the constraint contains a 1229 // compound body, this is a PDLL decl. 1230 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1231 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1232 resultType); 1233 1234 // Otherwise, this is a native decl. 1235 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1236 results, resultType); 1237 } 1238 1239 template <typename T> 1240 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1241 const ast::Name &name, bool isInline, 1242 ArrayRef<ast::VariableDecl *> arguments, 1243 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1244 // If followed by a string, the native code body has also been specified. 1245 std::string codeStrStorage; 1246 Optional<StringRef> optCodeStr; 1247 if (curToken.isString()) { 1248 codeStrStorage = curToken.getStringValue(); 1249 optCodeStr = codeStrStorage; 1250 consumeToken(); 1251 } else if (isInline) { 1252 return emitError(name.getLoc(), 1253 "external declarations must be declared in global scope"); 1254 } else if (curToken.is(Token::error)) { 1255 return failure(); 1256 } 1257 if (failed(parseToken(Token::semicolon, 1258 "expected `;` after native declaration"))) 1259 return failure(); 1260 // TODO: PDL should be able to support constraint results in certain 1261 // situations, we should revise this. 1262 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1263 return emitError( 1264 "native Constraints currently do not support returning results"); 1265 } 1266 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1267 } 1268 1269 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1270 SmallVectorImpl<ast::VariableDecl *> &arguments, 1271 SmallVectorImpl<ast::VariableDecl *> &results, 1272 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1273 // Parse the argument list of the decl. 1274 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1275 return failure(); 1276 1277 argumentScope = pushDeclScope(); 1278 if (curToken.isNot(Token::r_paren)) { 1279 do { 1280 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1281 if (failed(argument)) 1282 return failure(); 1283 arguments.emplace_back(*argument); 1284 } while (consumeIf(Token::comma)); 1285 } 1286 popDeclScope(); 1287 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1288 return failure(); 1289 1290 // Parse the results of the decl. 1291 pushDeclScope(); 1292 if (consumeIf(Token::arrow)) { 1293 auto parseResultFn = [&]() -> LogicalResult { 1294 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1295 if (failed(result)) 1296 return failure(); 1297 results.emplace_back(*result); 1298 return success(); 1299 }; 1300 1301 // Check for a list of results. 1302 if (consumeIf(Token::l_paren)) { 1303 do { 1304 if (failed(parseResultFn())) 1305 return failure(); 1306 } while (consumeIf(Token::comma)); 1307 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1308 return failure(); 1309 1310 // Otherwise, there is only one result. 1311 } else if (failed(parseResultFn())) { 1312 return failure(); 1313 } 1314 } 1315 popDeclScope(); 1316 1317 // Compute the result type of the decl. 1318 resultType = createUserConstraintRewriteResultType(results); 1319 1320 // Verify that results are only named if there are more than one. 1321 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1322 return emitError( 1323 results.front()->getLoc(), 1324 "cannot create a single-element tuple with an element label"); 1325 } 1326 return success(); 1327 } 1328 1329 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1330 StringRef declType, ast::CompoundStmt *body, 1331 ArrayRef<ast::Stmt *>::iterator bodyIt, 1332 ArrayRef<ast::Stmt *>::iterator bodyE, 1333 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1334 // Handle if a `return` was provided. 1335 if (bodyIt != bodyE) { 1336 // Emit an error if we have trailing statements after the return. 1337 if (std::next(bodyIt) != bodyE) { 1338 return emitError( 1339 (*std::next(bodyIt))->getLoc(), 1340 llvm::formatv("`return` terminated the `{0}` body, but found " 1341 "trailing statements afterwards", 1342 declType)); 1343 } 1344 1345 // Otherwise if a return wasn't provided, check that no results are 1346 // expected. 1347 } else if (!results.empty()) { 1348 return emitError( 1349 {body->getLoc().End, body->getLoc().End}, 1350 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1351 declType, resultType)); 1352 } 1353 return success(); 1354 } 1355 1356 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1357 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1358 if (isa<ast::OpRewriteStmt>(statement)) 1359 return success(); 1360 return emitError( 1361 statement->getLoc(), 1362 "expected Pattern lambda body to contain a single operation " 1363 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1364 }); 1365 } 1366 1367 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1368 SMRange loc = curToken.getLoc(); 1369 consumeToken(Token::kw_Pattern); 1370 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1371 ParserContext::PatternMatch); 1372 1373 // Check for an optional identifier for the pattern name. 1374 const ast::Name *name = nullptr; 1375 if (curToken.is(Token::identifier)) { 1376 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1377 consumeToken(Token::identifier); 1378 } 1379 1380 // Parse any pattern metadata. 1381 ParsedPatternMetadata metadata; 1382 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1383 return failure(); 1384 1385 // Parse the pattern body. 1386 ast::CompoundStmt *body; 1387 1388 // Handle a lambda body. 1389 if (curToken.is(Token::equal_arrow)) { 1390 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1391 if (failed(bodyResult)) 1392 return failure(); 1393 body = *bodyResult; 1394 } else { 1395 if (curToken.isNot(Token::l_brace)) 1396 return emitError("expected `{` or `=>` to start pattern body"); 1397 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1398 if (failed(bodyResult)) 1399 return failure(); 1400 body = *bodyResult; 1401 1402 // Verify the body of the pattern. 1403 auto bodyIt = body->begin(), bodyE = body->end(); 1404 for (; bodyIt != bodyE; ++bodyIt) { 1405 if (isa<ast::ReturnStmt>(*bodyIt)) { 1406 return emitError((*bodyIt)->getLoc(), 1407 "`return` statements are only permitted within a " 1408 "`Constraint` or `Rewrite` body"); 1409 } 1410 // Break when we've found the rewrite statement. 1411 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1412 break; 1413 } 1414 if (bodyIt == bodyE) { 1415 return emitError(loc, 1416 "expected Pattern body to terminate with an operation " 1417 "rewrite statement, such as `erase`"); 1418 } 1419 if (std::next(bodyIt) != bodyE) { 1420 return emitError((*std::next(bodyIt))->getLoc(), 1421 "Pattern body was terminated by an operation " 1422 "rewrite statement, but found trailing statements"); 1423 } 1424 } 1425 1426 return createPatternDecl(loc, name, metadata, body); 1427 } 1428 1429 LogicalResult 1430 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1431 Optional<SMRange> benefitLoc; 1432 Optional<SMRange> hasBoundedRecursionLoc; 1433 1434 do { 1435 // Handle metadata code completion. 1436 if (curToken.is(Token::code_complete)) 1437 return codeCompletePatternMetadata(); 1438 1439 if (curToken.isNot(Token::identifier)) 1440 return emitError("expected pattern metadata identifier"); 1441 StringRef metadataStr = curToken.getSpelling(); 1442 SMRange metadataLoc = curToken.getLoc(); 1443 consumeToken(Token::identifier); 1444 1445 // Parse the benefit metadata: benefit(<integer-value>) 1446 if (metadataStr == "benefit") { 1447 if (benefitLoc) { 1448 return emitErrorAndNote(metadataLoc, 1449 "pattern benefit has already been specified", 1450 *benefitLoc, "see previous definition here"); 1451 } 1452 if (failed(parseToken(Token::l_paren, 1453 "expected `(` before pattern benefit"))) 1454 return failure(); 1455 1456 uint16_t benefitValue = 0; 1457 if (curToken.isNot(Token::integer)) 1458 return emitError("expected integral pattern benefit"); 1459 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1460 return emitError( 1461 "expected pattern benefit to fit within a 16-bit integer"); 1462 consumeToken(Token::integer); 1463 1464 metadata.benefit = benefitValue; 1465 benefitLoc = metadataLoc; 1466 1467 if (failed( 1468 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1469 return failure(); 1470 continue; 1471 } 1472 1473 // Parse the bounded recursion metadata: recursion 1474 if (metadataStr == "recursion") { 1475 if (hasBoundedRecursionLoc) { 1476 return emitErrorAndNote( 1477 metadataLoc, 1478 "pattern recursion metadata has already been specified", 1479 *hasBoundedRecursionLoc, "see previous definition here"); 1480 } 1481 metadata.hasBoundedRecursion = true; 1482 hasBoundedRecursionLoc = metadataLoc; 1483 continue; 1484 } 1485 1486 return emitError(metadataLoc, "unknown pattern metadata"); 1487 } while (consumeIf(Token::comma)); 1488 1489 return success(); 1490 } 1491 1492 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1493 consumeToken(Token::less); 1494 1495 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1496 if (failed(typeExpr) || 1497 failed(parseToken(Token::greater, 1498 "expected `>` after variable type constraint"))) 1499 return failure(); 1500 return typeExpr; 1501 } 1502 1503 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1504 assert(curDeclScope && "defining decl outside of a decl scope"); 1505 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1506 return emitErrorAndNote( 1507 name.getLoc(), "`" + name.getName() + "` has already been defined", 1508 lastDecl->getName()->getLoc(), "see previous definition here"); 1509 } 1510 return success(); 1511 } 1512 1513 FailureOr<ast::VariableDecl *> 1514 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1515 ast::Expr *initExpr, 1516 ArrayRef<ast::ConstraintRef> constraints) { 1517 assert(curDeclScope && "defining variable outside of decl scope"); 1518 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1519 1520 // If the name of the variable indicates a special variable, we don't add it 1521 // to the scope. This variable is local to the definition point. 1522 if (name.empty() || name == "_") { 1523 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1524 constraints); 1525 } 1526 if (failed(checkDefineNamedDecl(nameDecl))) 1527 return failure(); 1528 1529 auto *varDecl = 1530 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1531 curDeclScope->add(varDecl); 1532 return varDecl; 1533 } 1534 1535 FailureOr<ast::VariableDecl *> 1536 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1537 ArrayRef<ast::ConstraintRef> constraints) { 1538 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1539 constraints); 1540 } 1541 1542 LogicalResult Parser::parseVariableDeclConstraintList( 1543 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1544 Optional<SMRange> typeConstraint; 1545 auto parseSingleConstraint = [&] { 1546 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1547 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1548 /*allowNonCoreConstraints=*/true); 1549 if (failed(constraint)) 1550 return failure(); 1551 constraints.push_back(*constraint); 1552 return success(); 1553 }; 1554 1555 // Check to see if this is a single constraint, or a list. 1556 if (!consumeIf(Token::l_square)) 1557 return parseSingleConstraint(); 1558 1559 do { 1560 if (failed(parseSingleConstraint())) 1561 return failure(); 1562 } while (consumeIf(Token::comma)); 1563 return parseToken(Token::r_square, "expected `]` after constraint list"); 1564 } 1565 1566 FailureOr<ast::ConstraintRef> 1567 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1568 ArrayRef<ast::ConstraintRef> existingConstraints, 1569 bool allowInlineTypeConstraints, 1570 bool allowNonCoreConstraints) { 1571 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1572 if (!allowInlineTypeConstraints) { 1573 return emitError( 1574 curToken.getLoc(), 1575 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1576 "permitted on arguments or results"); 1577 } 1578 if (typeConstraint) 1579 return emitErrorAndNote( 1580 curToken.getLoc(), 1581 "the type of this variable has already been constrained", 1582 *typeConstraint, "see previous constraint location here"); 1583 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1584 if (failed(constraintExpr)) 1585 return failure(); 1586 typeExpr = *constraintExpr; 1587 typeConstraint = typeExpr->getLoc(); 1588 return success(); 1589 }; 1590 1591 SMRange loc = curToken.getLoc(); 1592 switch (curToken.getKind()) { 1593 case Token::kw_Attr: { 1594 consumeToken(Token::kw_Attr); 1595 1596 // Check for a type constraint. 1597 ast::Expr *typeExpr = nullptr; 1598 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1599 return failure(); 1600 return ast::ConstraintRef( 1601 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1602 } 1603 case Token::kw_Op: { 1604 consumeToken(Token::kw_Op); 1605 1606 // Parse an optional operation name. If the name isn't provided, this refers 1607 // to "any" operation. 1608 FailureOr<ast::OpNameDecl *> opName = 1609 parseWrappedOperationName(/*allowEmptyName=*/true); 1610 if (failed(opName)) 1611 return failure(); 1612 1613 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1614 loc); 1615 } 1616 case Token::kw_Type: 1617 consumeToken(Token::kw_Type); 1618 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1619 case Token::kw_TypeRange: 1620 consumeToken(Token::kw_TypeRange); 1621 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1622 loc); 1623 case Token::kw_Value: { 1624 consumeToken(Token::kw_Value); 1625 1626 // Check for a type constraint. 1627 ast::Expr *typeExpr = nullptr; 1628 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1629 return failure(); 1630 1631 return ast::ConstraintRef( 1632 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1633 } 1634 case Token::kw_ValueRange: { 1635 consumeToken(Token::kw_ValueRange); 1636 1637 // Check for a type constraint. 1638 ast::Expr *typeExpr = nullptr; 1639 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1640 return failure(); 1641 1642 return ast::ConstraintRef( 1643 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1644 } 1645 1646 case Token::kw_Constraint: { 1647 // Handle an inline constraint. 1648 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1649 if (failed(decl)) 1650 return failure(); 1651 return ast::ConstraintRef(*decl, loc); 1652 } 1653 case Token::identifier: { 1654 StringRef constraintName = curToken.getSpelling(); 1655 consumeToken(Token::identifier); 1656 1657 // Lookup the referenced constraint. 1658 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1659 if (!cstDecl) { 1660 return emitError(loc, "unknown reference to constraint `" + 1661 constraintName + "`"); 1662 } 1663 1664 // Handle a reference to a proper constraint. 1665 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1666 return ast::ConstraintRef(cst, loc); 1667 1668 return emitErrorAndNote( 1669 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1670 "see the definition of `" + constraintName + "` here"); 1671 } 1672 // Handle single entity constraint code completion. 1673 case Token::code_complete: { 1674 // Try to infer the current type for use by code completion. 1675 ast::Type inferredType; 1676 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1677 allowNonCoreConstraints))) 1678 return failure(); 1679 1680 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1681 allowInlineTypeConstraints); 1682 } 1683 default: 1684 break; 1685 } 1686 return emitError(loc, "expected identifier constraint"); 1687 } 1688 1689 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1690 // Constraint arguments may apply more complex constraints via the arguments. 1691 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1692 1693 Optional<SMRange> typeConstraint; 1694 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1695 /*allowInlineTypeConstraints=*/false, 1696 allowNonCoreConstraints); 1697 } 1698 1699 //===----------------------------------------------------------------------===// 1700 // Exprs 1701 1702 FailureOr<ast::Expr *> Parser::parseExpr() { 1703 if (curToken.is(Token::underscore)) 1704 return parseUnderscoreExpr(); 1705 1706 // Parse the LHS expression. 1707 FailureOr<ast::Expr *> lhsExpr; 1708 switch (curToken.getKind()) { 1709 case Token::kw_attr: 1710 lhsExpr = parseAttributeExpr(); 1711 break; 1712 case Token::kw_Constraint: 1713 lhsExpr = parseInlineConstraintLambdaExpr(); 1714 break; 1715 case Token::identifier: 1716 lhsExpr = parseIdentifierExpr(); 1717 break; 1718 case Token::kw_op: 1719 lhsExpr = parseOperationExpr(); 1720 break; 1721 case Token::kw_Rewrite: 1722 lhsExpr = parseInlineRewriteLambdaExpr(); 1723 break; 1724 case Token::kw_type: 1725 lhsExpr = parseTypeExpr(); 1726 break; 1727 case Token::l_paren: 1728 lhsExpr = parseTupleExpr(); 1729 break; 1730 default: 1731 return emitError("expected expression"); 1732 } 1733 if (failed(lhsExpr)) 1734 return failure(); 1735 1736 // Check for an operator expression. 1737 while (true) { 1738 switch (curToken.getKind()) { 1739 case Token::dot: 1740 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1741 break; 1742 case Token::l_paren: 1743 lhsExpr = parseCallExpr(*lhsExpr); 1744 break; 1745 default: 1746 return lhsExpr; 1747 } 1748 if (failed(lhsExpr)) 1749 return failure(); 1750 } 1751 } 1752 1753 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1754 SMRange loc = curToken.getLoc(); 1755 consumeToken(Token::kw_attr); 1756 1757 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1758 // identifier. 1759 if (!consumeIf(Token::less)) { 1760 resetToken(loc); 1761 return parseIdentifierExpr(); 1762 } 1763 1764 if (!curToken.isString()) 1765 return emitError("expected string literal containing MLIR attribute"); 1766 std::string attrExpr = curToken.getStringValue(); 1767 consumeToken(); 1768 1769 loc.End = curToken.getEndLoc(); 1770 if (failed( 1771 parseToken(Token::greater, "expected `>` after attribute literal"))) 1772 return failure(); 1773 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1774 } 1775 1776 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1777 consumeToken(Token::l_paren); 1778 1779 // Parse the arguments of the call. 1780 SmallVector<ast::Expr *> arguments; 1781 if (curToken.isNot(Token::r_paren)) { 1782 do { 1783 // Handle code completion for the call arguments. 1784 if (curToken.is(Token::code_complete)) { 1785 codeCompleteCallSignature(parentExpr, arguments.size()); 1786 return failure(); 1787 } 1788 1789 FailureOr<ast::Expr *> argument = parseExpr(); 1790 if (failed(argument)) 1791 return failure(); 1792 arguments.push_back(*argument); 1793 } while (consumeIf(Token::comma)); 1794 } 1795 1796 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1797 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1798 return failure(); 1799 1800 return createCallExpr(loc, parentExpr, arguments); 1801 } 1802 1803 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1804 ast::Decl *decl = curDeclScope->lookup(name); 1805 if (!decl) 1806 return emitError(loc, "undefined reference to `" + name + "`"); 1807 1808 return createDeclRefExpr(loc, decl); 1809 } 1810 1811 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1812 StringRef name = curToken.getSpelling(); 1813 SMRange nameLoc = curToken.getLoc(); 1814 consumeToken(); 1815 1816 // Check to see if this is a decl ref expression that defines a variable 1817 // inline. 1818 if (consumeIf(Token::colon)) { 1819 SmallVector<ast::ConstraintRef> constraints; 1820 if (failed(parseVariableDeclConstraintList(constraints))) 1821 return failure(); 1822 ast::Type type; 1823 if (failed(validateVariableConstraints(constraints, type))) 1824 return failure(); 1825 return createInlineVariableExpr(type, name, nameLoc, constraints); 1826 } 1827 1828 return parseDeclRefExpr(name, nameLoc); 1829 } 1830 1831 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1832 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1833 if (failed(decl)) 1834 return failure(); 1835 1836 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1837 ast::ConstraintType::get(ctx)); 1838 } 1839 1840 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1841 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1842 if (failed(decl)) 1843 return failure(); 1844 1845 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1846 ast::RewriteType::get(ctx)); 1847 } 1848 1849 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1850 SMRange dotLoc = curToken.getLoc(); 1851 consumeToken(Token::dot); 1852 1853 // Check for code completion of the member name. 1854 if (curToken.is(Token::code_complete)) 1855 return codeCompleteMemberAccess(parentExpr); 1856 1857 // Parse the member name. 1858 Token memberNameTok = curToken; 1859 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1860 !memberNameTok.isKeyword()) 1861 return emitError(dotLoc, "expected identifier or numeric member name"); 1862 StringRef memberName = memberNameTok.getSpelling(); 1863 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1864 consumeToken(); 1865 1866 return createMemberAccessExpr(parentExpr, memberName, loc); 1867 } 1868 1869 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1870 SMRange loc = curToken.getLoc(); 1871 1872 // Check for code completion for the dialect name. 1873 if (curToken.is(Token::code_complete)) 1874 return codeCompleteDialectName(); 1875 1876 // Handle the case of an no operation name. 1877 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1878 if (allowEmptyName) 1879 return ast::OpNameDecl::create(ctx, SMRange()); 1880 return emitError("expected dialect namespace"); 1881 } 1882 StringRef name = curToken.getSpelling(); 1883 consumeToken(); 1884 1885 // Otherwise, this is a literal operation name. 1886 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1887 return failure(); 1888 1889 // Check for code completion for the operation name. 1890 if (curToken.is(Token::code_complete)) 1891 return codeCompleteOperationName(name); 1892 1893 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1894 return emitError("expected operation name after dialect namespace"); 1895 1896 name = StringRef(name.data(), name.size() + 1); 1897 do { 1898 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1899 loc.End = curToken.getEndLoc(); 1900 consumeToken(); 1901 } while (curToken.isAny(Token::identifier, Token::dot) || 1902 curToken.isKeyword()); 1903 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1904 } 1905 1906 FailureOr<ast::OpNameDecl *> 1907 Parser::parseWrappedOperationName(bool allowEmptyName) { 1908 if (!consumeIf(Token::less)) 1909 return ast::OpNameDecl::create(ctx, SMRange()); 1910 1911 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1912 if (failed(opNameDecl)) 1913 return failure(); 1914 1915 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1916 return failure(); 1917 return opNameDecl; 1918 } 1919 1920 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1921 SMRange loc = curToken.getLoc(); 1922 consumeToken(Token::kw_op); 1923 1924 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1925 // identifier. 1926 if (curToken.isNot(Token::less)) { 1927 resetToken(loc); 1928 return parseIdentifierExpr(); 1929 } 1930 1931 // Parse the operation name. The name may be elided, in which case the 1932 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1933 // `Operation*`). Operation names within a rewrite context must be named. 1934 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1935 FailureOr<ast::OpNameDecl *> opNameDecl = 1936 parseWrappedOperationName(allowEmptyName); 1937 if (failed(opNameDecl)) 1938 return failure(); 1939 Optional<StringRef> opName = (*opNameDecl)->getName(); 1940 1941 // Functor used to create an implicit range variable, used for implicit "all" 1942 // operand or results variables. 1943 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1944 FailureOr<ast::VariableDecl *> rangeVar = 1945 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1946 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1947 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1948 }; 1949 1950 // Check for the optional list of operands. 1951 SmallVector<ast::Expr *> operands; 1952 if (!consumeIf(Token::l_paren)) { 1953 // If the operand list isn't specified and we are in a match context, define 1954 // an inplace unconstrained operand range corresponding to all of the 1955 // operands of the operation. This avoids treating zero operands the same 1956 // way as "unconstrained operands". 1957 if (parserContext != ParserContext::Rewrite) { 1958 operands.push_back(createImplicitRangeVar( 1959 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1960 } 1961 } else if (!consumeIf(Token::r_paren)) { 1962 // Check for operand signature code completion. 1963 if (curToken.is(Token::code_complete)) { 1964 codeCompleteOperationOperandsSignature(opName, operands.size()); 1965 return failure(); 1966 } 1967 1968 // If the operand list was specified and non-empty, parse the operands. 1969 do { 1970 FailureOr<ast::Expr *> operand = parseExpr(); 1971 if (failed(operand)) 1972 return failure(); 1973 operands.push_back(*operand); 1974 } while (consumeIf(Token::comma)); 1975 1976 if (failed(parseToken(Token::r_paren, 1977 "expected `)` after operation operand list"))) 1978 return failure(); 1979 } 1980 1981 // Check for the optional list of attributes. 1982 SmallVector<ast::NamedAttributeDecl *> attributes; 1983 if (consumeIf(Token::l_brace)) { 1984 do { 1985 FailureOr<ast::NamedAttributeDecl *> decl = 1986 parseNamedAttributeDecl(opName); 1987 if (failed(decl)) 1988 return failure(); 1989 attributes.emplace_back(*decl); 1990 } while (consumeIf(Token::comma)); 1991 1992 if (failed(parseToken(Token::r_brace, 1993 "expected `}` after operation attribute list"))) 1994 return failure(); 1995 } 1996 1997 // Check for the optional list of result types. 1998 SmallVector<ast::Expr *> resultTypes; 1999 if (consumeIf(Token::arrow)) { 2000 if (failed(parseToken(Token::l_paren, 2001 "expected `(` before operation result type list"))) 2002 return failure(); 2003 2004 // Handle the case of an empty result list. 2005 if (!consumeIf(Token::r_paren)) { 2006 do { 2007 // Check for result signature code completion. 2008 if (curToken.is(Token::code_complete)) { 2009 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2010 return failure(); 2011 } 2012 2013 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2014 if (failed(resultTypeExpr)) 2015 return failure(); 2016 resultTypes.push_back(*resultTypeExpr); 2017 } while (consumeIf(Token::comma)); 2018 2019 if (failed(parseToken(Token::r_paren, 2020 "expected `)` after operation result type list"))) 2021 return failure(); 2022 } 2023 } else if (parserContext != ParserContext::Rewrite) { 2024 // If the result list isn't specified and we are in a match context, define 2025 // an inplace unconstrained result range corresponding to all of the results 2026 // of the operation. This avoids treating zero results the same way as 2027 // "unconstrained results". 2028 resultTypes.push_back(createImplicitRangeVar( 2029 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2030 } 2031 2032 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2033 resultTypes); 2034 } 2035 2036 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2037 SMRange loc = curToken.getLoc(); 2038 consumeToken(Token::l_paren); 2039 2040 DenseMap<StringRef, SMRange> usedNames; 2041 SmallVector<StringRef> elementNames; 2042 SmallVector<ast::Expr *> elements; 2043 if (curToken.isNot(Token::r_paren)) { 2044 do { 2045 // Check for the optional element name assignment before the value. 2046 StringRef elementName; 2047 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2048 Token elementNameTok = curToken; 2049 consumeToken(); 2050 2051 // The element name is only present if followed by an `=`. 2052 if (consumeIf(Token::equal)) { 2053 elementName = elementNameTok.getSpelling(); 2054 2055 // Check to see if this name is already used. 2056 auto elementNameIt = 2057 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2058 if (!elementNameIt.second) { 2059 return emitErrorAndNote( 2060 elementNameTok.getLoc(), 2061 llvm::formatv("duplicate tuple element label `{0}`", 2062 elementName), 2063 elementNameIt.first->getSecond(), 2064 "see previous label use here"); 2065 } 2066 } else { 2067 // Otherwise, we treat this as part of an expression so reset the 2068 // lexer. 2069 resetToken(elementNameTok.getLoc()); 2070 } 2071 } 2072 elementNames.push_back(elementName); 2073 2074 // Parse the tuple element value. 2075 FailureOr<ast::Expr *> element = parseExpr(); 2076 if (failed(element)) 2077 return failure(); 2078 elements.push_back(*element); 2079 } while (consumeIf(Token::comma)); 2080 } 2081 loc.End = curToken.getEndLoc(); 2082 if (failed( 2083 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2084 return failure(); 2085 return createTupleExpr(loc, elements, elementNames); 2086 } 2087 2088 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2089 SMRange loc = curToken.getLoc(); 2090 consumeToken(Token::kw_type); 2091 2092 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2093 // identifier. 2094 if (!consumeIf(Token::less)) { 2095 resetToken(loc); 2096 return parseIdentifierExpr(); 2097 } 2098 2099 if (!curToken.isString()) 2100 return emitError("expected string literal containing MLIR type"); 2101 std::string attrExpr = curToken.getStringValue(); 2102 consumeToken(); 2103 2104 loc.End = curToken.getEndLoc(); 2105 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2106 return failure(); 2107 return ast::TypeExpr::create(ctx, loc, attrExpr); 2108 } 2109 2110 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2111 StringRef name = curToken.getSpelling(); 2112 SMRange nameLoc = curToken.getLoc(); 2113 consumeToken(Token::underscore); 2114 2115 // Underscore expressions require a constraint list. 2116 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2117 return failure(); 2118 2119 // Parse the constraints for the expression. 2120 SmallVector<ast::ConstraintRef> constraints; 2121 if (failed(parseVariableDeclConstraintList(constraints))) 2122 return failure(); 2123 2124 ast::Type type; 2125 if (failed(validateVariableConstraints(constraints, type))) 2126 return failure(); 2127 return createInlineVariableExpr(type, name, nameLoc, constraints); 2128 } 2129 2130 //===----------------------------------------------------------------------===// 2131 // Stmts 2132 2133 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2134 FailureOr<ast::Stmt *> stmt; 2135 switch (curToken.getKind()) { 2136 case Token::kw_erase: 2137 stmt = parseEraseStmt(); 2138 break; 2139 case Token::kw_let: 2140 stmt = parseLetStmt(); 2141 break; 2142 case Token::kw_replace: 2143 stmt = parseReplaceStmt(); 2144 break; 2145 case Token::kw_return: 2146 stmt = parseReturnStmt(); 2147 break; 2148 case Token::kw_rewrite: 2149 stmt = parseRewriteStmt(); 2150 break; 2151 default: 2152 stmt = parseExpr(); 2153 break; 2154 } 2155 if (failed(stmt) || 2156 (expectTerminalSemicolon && 2157 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2158 return failure(); 2159 return stmt; 2160 } 2161 2162 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2163 SMLoc startLoc = curToken.getStartLoc(); 2164 consumeToken(Token::l_brace); 2165 2166 // Push a new block scope and parse any nested statements. 2167 pushDeclScope(); 2168 SmallVector<ast::Stmt *> statements; 2169 while (curToken.isNot(Token::r_brace)) { 2170 FailureOr<ast::Stmt *> statement = parseStmt(); 2171 if (failed(statement)) 2172 return popDeclScope(), failure(); 2173 statements.push_back(*statement); 2174 } 2175 popDeclScope(); 2176 2177 // Consume the end brace. 2178 SMRange location(startLoc, curToken.getEndLoc()); 2179 consumeToken(Token::r_brace); 2180 2181 return ast::CompoundStmt::create(ctx, location, statements); 2182 } 2183 2184 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2185 if (parserContext == ParserContext::Constraint) 2186 return emitError("`erase` cannot be used within a Constraint"); 2187 SMRange loc = curToken.getLoc(); 2188 consumeToken(Token::kw_erase); 2189 2190 // Parse the root operation expression. 2191 FailureOr<ast::Expr *> rootOp = parseExpr(); 2192 if (failed(rootOp)) 2193 return failure(); 2194 2195 return createEraseStmt(loc, *rootOp); 2196 } 2197 2198 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2199 SMRange loc = curToken.getLoc(); 2200 consumeToken(Token::kw_let); 2201 2202 // Parse the name of the new variable. 2203 SMRange varLoc = curToken.getLoc(); 2204 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2205 // `_` is a reserved variable name. 2206 if (curToken.is(Token::underscore)) { 2207 return emitError(varLoc, 2208 "`_` may only be used to define \"inline\" variables"); 2209 } 2210 return emitError(varLoc, 2211 "expected identifier after `let` to name a new variable"); 2212 } 2213 StringRef varName = curToken.getSpelling(); 2214 consumeToken(); 2215 2216 // Parse the optional set of constraints. 2217 SmallVector<ast::ConstraintRef> constraints; 2218 if (consumeIf(Token::colon) && 2219 failed(parseVariableDeclConstraintList(constraints))) 2220 return failure(); 2221 2222 // Parse the optional initializer expression. 2223 ast::Expr *initializer = nullptr; 2224 if (consumeIf(Token::equal)) { 2225 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2226 if (failed(initOrFailure)) 2227 return failure(); 2228 initializer = *initOrFailure; 2229 2230 // Check that the constraints are compatible with having an initializer, 2231 // e.g. type constraints cannot be used with initializers. 2232 for (ast::ConstraintRef constraint : constraints) { 2233 LogicalResult result = 2234 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2235 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2236 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2237 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2238 return this->emitError( 2239 constraint.referenceLoc, 2240 "type constraints are not permitted on variables with " 2241 "initializers"); 2242 } 2243 return success(); 2244 }) 2245 .Default(success()); 2246 if (failed(result)) 2247 return failure(); 2248 } 2249 } 2250 2251 FailureOr<ast::VariableDecl *> varDecl = 2252 createVariableDecl(varName, varLoc, initializer, constraints); 2253 if (failed(varDecl)) 2254 return failure(); 2255 return ast::LetStmt::create(ctx, loc, *varDecl); 2256 } 2257 2258 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2259 if (parserContext == ParserContext::Constraint) 2260 return emitError("`replace` cannot be used within a Constraint"); 2261 SMRange loc = curToken.getLoc(); 2262 consumeToken(Token::kw_replace); 2263 2264 // Parse the root operation expression. 2265 FailureOr<ast::Expr *> rootOp = parseExpr(); 2266 if (failed(rootOp)) 2267 return failure(); 2268 2269 if (failed( 2270 parseToken(Token::kw_with, "expected `with` after root operation"))) 2271 return failure(); 2272 2273 // The replacement portion of this statement is within a rewrite context. 2274 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2275 ParserContext::Rewrite); 2276 2277 // Parse the replacement values. 2278 SmallVector<ast::Expr *> replValues; 2279 if (consumeIf(Token::l_paren)) { 2280 if (consumeIf(Token::r_paren)) { 2281 return emitError( 2282 loc, "expected at least one replacement value, consider using " 2283 "`erase` if no replacement values are desired"); 2284 } 2285 2286 do { 2287 FailureOr<ast::Expr *> replExpr = parseExpr(); 2288 if (failed(replExpr)) 2289 return failure(); 2290 replValues.emplace_back(*replExpr); 2291 } while (consumeIf(Token::comma)); 2292 2293 if (failed(parseToken(Token::r_paren, 2294 "expected `)` after replacement values"))) 2295 return failure(); 2296 } else { 2297 FailureOr<ast::Expr *> replExpr = parseExpr(); 2298 if (failed(replExpr)) 2299 return failure(); 2300 replValues.emplace_back(*replExpr); 2301 } 2302 2303 return createReplaceStmt(loc, *rootOp, replValues); 2304 } 2305 2306 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2307 SMRange loc = curToken.getLoc(); 2308 consumeToken(Token::kw_return); 2309 2310 // Parse the result value. 2311 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2312 if (failed(resultExpr)) 2313 return failure(); 2314 2315 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2316 } 2317 2318 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2319 if (parserContext == ParserContext::Constraint) 2320 return emitError("`rewrite` cannot be used within a Constraint"); 2321 SMRange loc = curToken.getLoc(); 2322 consumeToken(Token::kw_rewrite); 2323 2324 // Parse the root operation. 2325 FailureOr<ast::Expr *> rootOp = parseExpr(); 2326 if (failed(rootOp)) 2327 return failure(); 2328 2329 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2330 return failure(); 2331 2332 if (curToken.isNot(Token::l_brace)) 2333 return emitError("expected `{` to start rewrite body"); 2334 2335 // The rewrite body of this statement is within a rewrite context. 2336 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2337 ParserContext::Rewrite); 2338 2339 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2340 if (failed(rewriteBody)) 2341 return failure(); 2342 2343 // Verify the rewrite body. 2344 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2345 if (isa<ast::ReturnStmt>(stmt)) { 2346 return emitError(stmt->getLoc(), 2347 "`return` statements are only permitted within a " 2348 "`Constraint` or `Rewrite` body"); 2349 } 2350 } 2351 2352 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2353 } 2354 2355 //===----------------------------------------------------------------------===// 2356 // Creation+Analysis 2357 //===----------------------------------------------------------------------===// 2358 2359 //===----------------------------------------------------------------------===// 2360 // Decls 2361 2362 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2363 // Unwrap reference expressions. 2364 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2365 node = init->getDecl(); 2366 return dyn_cast<ast::CallableDecl>(node); 2367 } 2368 2369 FailureOr<ast::PatternDecl *> 2370 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2371 const ParsedPatternMetadata &metadata, 2372 ast::CompoundStmt *body) { 2373 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2374 metadata.hasBoundedRecursion, body); 2375 } 2376 2377 ast::Type Parser::createUserConstraintRewriteResultType( 2378 ArrayRef<ast::VariableDecl *> results) { 2379 // Single result decls use the type of the single result. 2380 if (results.size() == 1) 2381 return results[0]->getType(); 2382 2383 // Multiple results use a tuple type, with the types and names grabbed from 2384 // the result variable decls. 2385 auto resultTypes = llvm::map_range( 2386 results, [&](const auto *result) { return result->getType(); }); 2387 auto resultNames = llvm::map_range( 2388 results, [&](const auto *result) { return result->getName().getName(); }); 2389 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2390 llvm::to_vector(resultNames)); 2391 } 2392 2393 template <typename T> 2394 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2395 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2396 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2397 ast::CompoundStmt *body) { 2398 if (!body->getChildren().empty()) { 2399 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2400 ast::Expr *resultExpr = retStmt->getResultExpr(); 2401 2402 // Process the result of the decl. If no explicit signature results 2403 // were provided, check for return type inference. Otherwise, check that 2404 // the return expression can be converted to the expected type. 2405 if (results.empty()) 2406 resultType = resultExpr->getType(); 2407 else if (failed(convertExpressionTo(resultExpr, resultType))) 2408 return failure(); 2409 else 2410 retStmt->setResultExpr(resultExpr); 2411 } 2412 } 2413 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2414 } 2415 2416 FailureOr<ast::VariableDecl *> 2417 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2418 ArrayRef<ast::ConstraintRef> constraints) { 2419 // The type of the variable, which is expected to be inferred by either a 2420 // constraint or an initializer expression. 2421 ast::Type type; 2422 if (failed(validateVariableConstraints(constraints, type))) 2423 return failure(); 2424 2425 if (initializer) { 2426 // Update the variable type based on the initializer, or try to convert the 2427 // initializer to the existing type. 2428 if (!type) 2429 type = initializer->getType(); 2430 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2431 type = mergedType; 2432 else if (failed(convertExpressionTo(initializer, type))) 2433 return failure(); 2434 2435 // Otherwise, if there is no initializer check that the type has already 2436 // been resolved from the constraint list. 2437 } else if (!type) { 2438 return emitErrorAndNote( 2439 loc, "unable to infer type for variable `" + name + "`", loc, 2440 "the type of a variable must be inferable from the constraint " 2441 "list or the initializer"); 2442 } 2443 2444 // Constraint types cannot be used when defining variables. 2445 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2446 return emitError( 2447 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2448 } 2449 2450 // Try to define a variable with the given name. 2451 FailureOr<ast::VariableDecl *> varDecl = 2452 defineVariableDecl(name, loc, type, initializer, constraints); 2453 if (failed(varDecl)) 2454 return failure(); 2455 2456 return *varDecl; 2457 } 2458 2459 FailureOr<ast::VariableDecl *> 2460 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2461 const ast::ConstraintRef &constraint) { 2462 // Constraint arguments may apply more complex constraints via the arguments. 2463 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2464 ast::Type argType; 2465 if (failed(validateVariableConstraint(constraint, argType, 2466 allowNonCoreConstraints))) 2467 return failure(); 2468 return defineVariableDecl(name, loc, argType, constraint); 2469 } 2470 2471 LogicalResult 2472 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2473 ast::Type &inferredType, 2474 bool allowNonCoreConstraints) { 2475 for (const ast::ConstraintRef &ref : constraints) 2476 if (failed(validateVariableConstraint(ref, inferredType, 2477 allowNonCoreConstraints))) 2478 return failure(); 2479 return success(); 2480 } 2481 2482 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2483 ast::Type &inferredType, 2484 bool allowNonCoreConstraints) { 2485 ast::Type constraintType; 2486 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2487 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2488 if (failed(validateTypeConstraintExpr(typeExpr))) 2489 return failure(); 2490 } 2491 constraintType = ast::AttributeType::get(ctx); 2492 } else if (const auto *cst = 2493 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2494 constraintType = ast::OperationType::get(ctx, cst->getName()); 2495 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2496 constraintType = typeTy; 2497 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2498 constraintType = typeRangeTy; 2499 } else if (const auto *cst = 2500 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2501 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2502 if (failed(validateTypeConstraintExpr(typeExpr))) 2503 return failure(); 2504 } 2505 constraintType = valueTy; 2506 } else if (const auto *cst = 2507 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2508 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2509 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2510 return failure(); 2511 } 2512 constraintType = valueRangeTy; 2513 } else if (const auto *cst = 2514 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2515 if (!allowNonCoreConstraints) { 2516 return emitError(ref.referenceLoc, 2517 "`Rewrite` arguments and results are only permitted to " 2518 "use core constraints, such as `Attr`, `Op`, `Type`, " 2519 "`TypeRange`, `Value`, `ValueRange`"); 2520 } 2521 2522 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2523 if (inputs.size() != 1) { 2524 return emitErrorAndNote(ref.referenceLoc, 2525 "`Constraint`s applied via a variable constraint " 2526 "list must take a single input, but got " + 2527 Twine(inputs.size()), 2528 cst->getLoc(), 2529 "see definition of constraint here"); 2530 } 2531 constraintType = inputs.front()->getType(); 2532 } else { 2533 llvm_unreachable("unknown constraint type"); 2534 } 2535 2536 // Check that the constraint type is compatible with the current inferred 2537 // type. 2538 if (!inferredType) { 2539 inferredType = constraintType; 2540 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2541 inferredType = mergedTy; 2542 } else { 2543 return emitError(ref.referenceLoc, 2544 llvm::formatv("constraint type `{0}` is incompatible " 2545 "with the previously inferred type `{1}`", 2546 constraintType, inferredType)); 2547 } 2548 return success(); 2549 } 2550 2551 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2552 ast::Type typeExprType = typeExpr->getType(); 2553 if (typeExprType != typeTy) { 2554 return emitError(typeExpr->getLoc(), 2555 "expected expression of `Type` in type constraint"); 2556 } 2557 return success(); 2558 } 2559 2560 LogicalResult 2561 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2562 ast::Type typeExprType = typeExpr->getType(); 2563 if (typeExprType != typeRangeTy) { 2564 return emitError(typeExpr->getLoc(), 2565 "expected expression of `TypeRange` in type constraint"); 2566 } 2567 return success(); 2568 } 2569 2570 //===----------------------------------------------------------------------===// 2571 // Exprs 2572 2573 FailureOr<ast::CallExpr *> 2574 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2575 MutableArrayRef<ast::Expr *> arguments) { 2576 ast::Type parentType = parentExpr->getType(); 2577 2578 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2579 if (!callableDecl) { 2580 return emitError(loc, 2581 llvm::formatv("expected a reference to a callable " 2582 "`Constraint` or `Rewrite`, but got: `{0}`", 2583 parentType)); 2584 } 2585 if (parserContext == ParserContext::Rewrite) { 2586 if (isa<ast::UserConstraintDecl>(callableDecl)) 2587 return emitError( 2588 loc, "unable to invoke `Constraint` within a rewrite section"); 2589 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2590 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2591 } 2592 2593 // Verify the arguments of the call. 2594 /// Handle size mismatch. 2595 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2596 if (callArgs.size() != arguments.size()) { 2597 return emitErrorAndNote( 2598 loc, 2599 llvm::formatv("invalid number of arguments for {0} call; expected " 2600 "{1}, but got {2}", 2601 callableDecl->getCallableType(), callArgs.size(), 2602 arguments.size()), 2603 callableDecl->getLoc(), 2604 llvm::formatv("see the definition of {0} here", 2605 callableDecl->getName()->getName())); 2606 } 2607 2608 /// Handle argument type mismatch. 2609 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2610 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2611 callableDecl->getName()->getName()), 2612 callableDecl->getLoc()); 2613 }; 2614 for (auto it : llvm::zip(callArgs, arguments)) { 2615 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2616 attachDiagFn))) 2617 return failure(); 2618 } 2619 2620 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2621 callableDecl->getResultType()); 2622 } 2623 2624 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2625 ast::Decl *decl) { 2626 // Check the type of decl being referenced. 2627 ast::Type declType; 2628 if (isa<ast::ConstraintDecl>(decl)) 2629 declType = ast::ConstraintType::get(ctx); 2630 else if (isa<ast::UserRewriteDecl>(decl)) 2631 declType = ast::RewriteType::get(ctx); 2632 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2633 declType = varDecl->getType(); 2634 else 2635 return emitError(loc, "invalid reference to `" + 2636 decl->getName()->getName() + "`"); 2637 2638 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2639 } 2640 2641 FailureOr<ast::DeclRefExpr *> 2642 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2643 ArrayRef<ast::ConstraintRef> constraints) { 2644 FailureOr<ast::VariableDecl *> decl = 2645 defineVariableDecl(name, loc, type, constraints); 2646 if (failed(decl)) 2647 return failure(); 2648 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2649 } 2650 2651 FailureOr<ast::MemberAccessExpr *> 2652 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2653 SMRange loc) { 2654 // Validate the member name for the given parent expression. 2655 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2656 if (failed(memberType)) 2657 return failure(); 2658 2659 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2660 } 2661 2662 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2663 StringRef name, SMRange loc) { 2664 ast::Type parentType = parentExpr->getType(); 2665 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2666 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2667 return valueRangeTy; 2668 2669 // Verify member access based on the operation type. 2670 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2671 auto results = odsOp->getResults(); 2672 2673 // Handle indexed results. 2674 unsigned index = 0; 2675 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2676 index < results.size()) { 2677 return results[index].isVariadic() ? valueRangeTy : valueTy; 2678 } 2679 2680 // Handle named results. 2681 const auto *it = llvm::find_if(results, [&](const auto &result) { 2682 return result.getName() == name; 2683 }); 2684 if (it != results.end()) 2685 return it->isVariadic() ? valueRangeTy : valueTy; 2686 } 2687 2688 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2689 // Handle indexed results. 2690 unsigned index = 0; 2691 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2692 index < tupleType.size()) { 2693 return tupleType.getElementTypes()[index]; 2694 } 2695 2696 // Handle named results. 2697 auto elementNames = tupleType.getElementNames(); 2698 const auto *it = llvm::find(elementNames, name); 2699 if (it != elementNames.end()) 2700 return tupleType.getElementTypes()[it - elementNames.begin()]; 2701 } 2702 return emitError( 2703 loc, 2704 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2705 name, parentType)); 2706 } 2707 2708 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2709 SMRange loc, const ast::OpNameDecl *name, 2710 MutableArrayRef<ast::Expr *> operands, 2711 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2712 MutableArrayRef<ast::Expr *> results) { 2713 Optional<StringRef> opNameRef = name->getName(); 2714 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2715 2716 // Verify the inputs operands. 2717 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2718 return failure(); 2719 2720 // Verify the attribute list. 2721 for (ast::NamedAttributeDecl *attr : attributes) { 2722 // Check for an attribute type, or a type awaiting resolution. 2723 ast::Type attrType = attr->getValue()->getType(); 2724 if (!attrType.isa<ast::AttributeType>()) { 2725 return emitError( 2726 attr->getValue()->getLoc(), 2727 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2728 } 2729 } 2730 2731 // Verify the result types. 2732 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2733 return failure(); 2734 2735 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2736 attributes); 2737 } 2738 2739 LogicalResult 2740 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2741 const ods::Operation *odsOp, 2742 MutableArrayRef<ast::Expr *> operands) { 2743 return validateOperationOperandsOrResults( 2744 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2745 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2746 valueRangeTy); 2747 } 2748 2749 LogicalResult 2750 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2751 const ods::Operation *odsOp, 2752 MutableArrayRef<ast::Expr *> results) { 2753 return validateOperationOperandsOrResults( 2754 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2755 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2756 } 2757 2758 LogicalResult Parser::validateOperationOperandsOrResults( 2759 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2760 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2761 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2762 ast::Type rangeTy) { 2763 // All operation types accept a single range parameter. 2764 if (values.size() == 1) { 2765 if (failed(convertExpressionTo(values[0], rangeTy))) 2766 return failure(); 2767 return success(); 2768 } 2769 2770 /// If the operation has ODS information, we can more accurately verify the 2771 /// values. 2772 if (odsOpLoc) { 2773 if (odsValues.size() != values.size()) { 2774 return emitErrorAndNote( 2775 loc, 2776 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2777 "{2}, but got {3}", 2778 groupName, *name, odsValues.size(), values.size()), 2779 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2780 } 2781 auto diagFn = [&](ast::Diagnostic &diag) { 2782 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2783 *odsOpLoc); 2784 }; 2785 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2786 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2787 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2788 return failure(); 2789 } 2790 return success(); 2791 } 2792 2793 // Otherwise, accept the value groups as they have been defined and just 2794 // ensure they are one of the expected types. 2795 for (ast::Expr *&valueExpr : values) { 2796 ast::Type valueExprType = valueExpr->getType(); 2797 2798 // Check if this is one of the expected types. 2799 if (valueExprType == rangeTy || valueExprType == singleTy) 2800 continue; 2801 2802 // If the operand is an Operation, allow converting to a Value or 2803 // ValueRange. This situations arises quite often with nested operation 2804 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2805 if (singleTy == valueTy) { 2806 if (valueExprType.isa<ast::OperationType>()) { 2807 valueExpr = convertOpToValue(valueExpr); 2808 continue; 2809 } 2810 } 2811 2812 return emitError( 2813 valueExpr->getLoc(), 2814 llvm::formatv( 2815 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2816 singleTy, rangeTy, valueExprType)); 2817 } 2818 return success(); 2819 } 2820 2821 FailureOr<ast::TupleExpr *> 2822 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2823 ArrayRef<StringRef> elementNames) { 2824 for (const ast::Expr *element : elements) { 2825 ast::Type eleTy = element->getType(); 2826 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2827 return emitError( 2828 element->getLoc(), 2829 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2830 } 2831 } 2832 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2833 } 2834 2835 //===----------------------------------------------------------------------===// 2836 // Stmts 2837 2838 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2839 ast::Expr *rootOp) { 2840 // Check that root is an Operation. 2841 ast::Type rootType = rootOp->getType(); 2842 if (!rootType.isa<ast::OperationType>()) 2843 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2844 2845 return ast::EraseStmt::create(ctx, loc, rootOp); 2846 } 2847 2848 FailureOr<ast::ReplaceStmt *> 2849 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2850 MutableArrayRef<ast::Expr *> replValues) { 2851 // Check that root is an Operation. 2852 ast::Type rootType = rootOp->getType(); 2853 if (!rootType.isa<ast::OperationType>()) { 2854 return emitError( 2855 rootOp->getLoc(), 2856 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2857 } 2858 2859 // If there are multiple replacement values, we implicitly convert any Op 2860 // expressions to the value form. 2861 bool shouldConvertOpToValues = replValues.size() > 1; 2862 for (ast::Expr *&replExpr : replValues) { 2863 ast::Type replType = replExpr->getType(); 2864 2865 // Check that replExpr is an Operation, Value, or ValueRange. 2866 if (replType.isa<ast::OperationType>()) { 2867 if (shouldConvertOpToValues) 2868 replExpr = convertOpToValue(replExpr); 2869 continue; 2870 } 2871 2872 if (replType != valueTy && replType != valueRangeTy) { 2873 return emitError(replExpr->getLoc(), 2874 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2875 "expression, but got `{0}`", 2876 replType)); 2877 } 2878 } 2879 2880 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2881 } 2882 2883 FailureOr<ast::RewriteStmt *> 2884 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2885 ast::CompoundStmt *rewriteBody) { 2886 // Check that root is an Operation. 2887 ast::Type rootType = rootOp->getType(); 2888 if (!rootType.isa<ast::OperationType>()) { 2889 return emitError( 2890 rootOp->getLoc(), 2891 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2892 } 2893 2894 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2895 } 2896 2897 //===----------------------------------------------------------------------===// 2898 // Code Completion 2899 //===----------------------------------------------------------------------===// 2900 2901 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2902 ast::Type parentType = parentExpr->getType(); 2903 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2904 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2905 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2906 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2907 return failure(); 2908 } 2909 2910 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2911 if (opName) 2912 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2913 return failure(); 2914 } 2915 2916 LogicalResult 2917 Parser::codeCompleteConstraintName(ast::Type inferredType, 2918 bool allowNonCoreConstraints, 2919 bool allowInlineTypeConstraints) { 2920 codeCompleteContext->codeCompleteConstraintName( 2921 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2922 curDeclScope); 2923 return failure(); 2924 } 2925 2926 LogicalResult Parser::codeCompleteDialectName() { 2927 codeCompleteContext->codeCompleteDialectName(); 2928 return failure(); 2929 } 2930 2931 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2932 codeCompleteContext->codeCompleteOperationName(dialectName); 2933 return failure(); 2934 } 2935 2936 LogicalResult Parser::codeCompletePatternMetadata() { 2937 codeCompleteContext->codeCompletePatternMetadata(); 2938 return failure(); 2939 } 2940 2941 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 2942 codeCompleteContext->codeCompleteIncludeFilename(curPath); 2943 return failure(); 2944 } 2945 2946 void Parser::codeCompleteCallSignature(ast::Node *parent, 2947 unsigned currentNumArgs) { 2948 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2949 if (!callableDecl) 2950 return; 2951 2952 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2953 } 2954 2955 void Parser::codeCompleteOperationOperandsSignature( 2956 Optional<StringRef> opName, unsigned currentNumOperands) { 2957 codeCompleteContext->codeCompleteOperationOperandsSignature( 2958 opName, currentNumOperands); 2959 } 2960 2961 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2962 unsigned currentNumResults) { 2963 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2964 currentNumResults); 2965 } 2966 2967 //===----------------------------------------------------------------------===// 2968 // Parser 2969 //===----------------------------------------------------------------------===// 2970 2971 FailureOr<ast::Module *> 2972 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2973 CodeCompleteContext *codeCompleteContext) { 2974 Parser parser(ctx, sourceMgr, codeCompleteContext); 2975 return parser.parseModule(); 2976 } 2977