1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/ManagedStatic.h" 29 #include "llvm/Support/SaveAndRestore.h" 30 #include "llvm/Support/ScopedPrinter.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Parser.h" 33 #include <string> 34 35 using namespace mlir; 36 using namespace mlir::pdll; 37 38 //===----------------------------------------------------------------------===// 39 // Parser 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 class Parser { 44 public: 45 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 46 CodeCompleteContext *codeCompleteContext) 47 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 48 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 49 valueRangeTy(ast::ValueRangeType::get(ctx)), 50 typeTy(ast::TypeType::get(ctx)), 51 typeRangeTy(ast::TypeRangeType::get(ctx)), 52 attrTy(ast::AttributeType::get(ctx)), 53 codeCompleteContext(codeCompleteContext) {} 54 55 /// Try to parse a new module. Returns nullptr in the case of failure. 56 FailureOr<ast::Module *> parseModule(); 57 58 private: 59 /// The current context of the parser. It allows for the parser to know a bit 60 /// about the construct it is nested within during parsing. This is used 61 /// specifically to provide additional verification during parsing, e.g. to 62 /// prevent using rewrites within a match context, matcher constraints within 63 /// a rewrite section, etc. 64 enum class ParserContext { 65 /// The parser is in the global context. 66 Global, 67 /// The parser is currently within a Constraint, which disallows all types 68 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 69 Constraint, 70 /// The parser is currently within the matcher portion of a Pattern, which 71 /// is allows a terminal operation rewrite statement but no other rewrite 72 /// transformations. 73 PatternMatch, 74 /// The parser is currently within a Rewrite, which disallows calls to 75 /// constraints, requires operation expressions to have names, etc. 76 Rewrite, 77 }; 78 79 //===--------------------------------------------------------------------===// 80 // Parsing 81 //===--------------------------------------------------------------------===// 82 83 /// Push a new decl scope onto the lexer. 84 ast::DeclScope *pushDeclScope() { 85 ast::DeclScope *newScope = 86 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 87 return (curDeclScope = newScope); 88 } 89 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 90 91 /// Pop the last decl scope from the lexer. 92 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 93 94 /// Parse the body of an AST module. 95 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 96 97 /// Try to convert the given expression to `type`. Returns failure and emits 98 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 99 /// invoked to attach notes to the emitted error diagnostic. On success, 100 /// `expr` is updated to the expression used to convert to `type`. 101 LogicalResult convertExpressionTo( 102 ast::Expr *&expr, ast::Type type, 103 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 104 105 /// Given an operation expression, convert it to a Value or ValueRange 106 /// typed expression. 107 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 108 109 /// Lookup ODS information for the given operation, returns nullptr if no 110 /// information is found. 111 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 112 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 113 } 114 115 //===--------------------------------------------------------------------===// 116 // Directives 117 118 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 119 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 120 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 121 SmallVectorImpl<ast::Decl *> &decls); 122 123 /// Process the records of a parsed tablegen include file. 124 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 125 SmallVectorImpl<ast::Decl *> &decls); 126 127 /// Create a user defined native constraint for a constraint imported from 128 /// ODS. 129 template <typename ConstraintT> 130 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 131 StringRef codeBlock, SMRange loc, 132 ast::Type type); 133 template <typename ConstraintT> 134 ast::Decl * 135 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 136 SMRange loc, ast::Type type); 137 138 //===--------------------------------------------------------------------===// 139 // Decls 140 141 /// This structure contains the set of pattern metadata that may be parsed. 142 struct ParsedPatternMetadata { 143 Optional<uint16_t> benefit; 144 bool hasBoundedRecursion = false; 145 }; 146 147 FailureOr<ast::Decl *> parseTopLevelDecl(); 148 FailureOr<ast::NamedAttributeDecl *> 149 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 150 151 /// Parse an argument variable as part of the signature of a 152 /// UserConstraintDecl or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 154 155 /// Parse a result variable as part of the signature of a UserConstraintDecl 156 /// or UserRewriteDecl. 157 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 158 159 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 160 /// defined in a non-global context. 161 FailureOr<ast::UserConstraintDecl *> 162 parseUserConstraintDecl(bool isInline = false); 163 164 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 165 /// non-global context, such as within a Pattern/Constraint/etc. 166 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 167 168 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 169 /// PDLL constructs. 170 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 171 const ast::Name &name, bool isInline, 172 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 173 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 174 175 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 176 /// defined in a non-global context. 177 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 178 179 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 180 /// non-global context, such as within a Pattern/Rewrite/etc. 181 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 182 183 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 184 /// PDLL constructs. 185 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 186 const ast::Name &name, bool isInline, 187 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 188 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 189 190 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 191 /// effectively the same syntax, and only differ on slight semantics (given 192 /// the different parsing contexts). 193 template <typename T, typename ParseUserPDLLDeclFnT> 194 FailureOr<T *> parseUserConstraintOrRewriteDecl( 195 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 196 StringRef anonymousNamePrefix, bool isInline); 197 198 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 199 /// These decls have effectively the same syntax. 200 template <typename T> 201 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 202 const ast::Name &name, bool isInline, 203 ArrayRef<ast::VariableDecl *> arguments, 204 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 205 206 /// Parse the functional signature (i.e. the arguments and results) of a 207 /// UserConstraintDecl or UserRewriteDecl. 208 LogicalResult parseUserConstraintOrRewriteSignature( 209 SmallVectorImpl<ast::VariableDecl *> &arguments, 210 SmallVectorImpl<ast::VariableDecl *> &results, 211 ast::DeclScope *&argumentScope, ast::Type &resultType); 212 213 /// Validate the return (which if present is specified by bodyIt) of a 214 /// UserConstraintDecl or UserRewriteDecl. 215 LogicalResult validateUserConstraintOrRewriteReturn( 216 StringRef declType, ast::CompoundStmt *body, 217 ArrayRef<ast::Stmt *>::iterator bodyIt, 218 ArrayRef<ast::Stmt *>::iterator bodyE, 219 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 220 221 FailureOr<ast::CompoundStmt *> 222 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 223 bool expectTerminalSemicolon = true); 224 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 225 FailureOr<ast::Decl *> parsePatternDecl(); 226 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 227 228 /// Check to see if a decl has already been defined with the given name, if 229 /// one has emit and error and return failure. Returns success otherwise. 230 LogicalResult checkDefineNamedDecl(const ast::Name &name); 231 232 /// Try to define a variable decl with the given components, returns the 233 /// variable on success. 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ast::Expr *initExpr, 237 ArrayRef<ast::ConstraintRef> constraints); 238 FailureOr<ast::VariableDecl *> 239 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 240 ArrayRef<ast::ConstraintRef> constraints); 241 242 /// Parse the constraint reference list for a variable decl. 243 LogicalResult parseVariableDeclConstraintList( 244 SmallVectorImpl<ast::ConstraintRef> &constraints); 245 246 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 247 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 248 249 /// Try to parse a single reference to a constraint. `typeConstraint` is the 250 /// location of a previously parsed type constraint for the entity that will 251 /// be constrained by the parsed constraint. `existingConstraints` are any 252 /// existing constraints that have already been parsed for the same entity 253 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 254 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 255 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 256 /// constraints) may be used with the variable. 257 FailureOr<ast::ConstraintRef> 258 parseConstraint(Optional<SMRange> &typeConstraint, 259 ArrayRef<ast::ConstraintRef> existingConstraints, 260 bool allowInlineTypeConstraints, 261 bool allowNonCoreConstraints); 262 263 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 264 /// argument or result variable. The constraints for these variables do not 265 /// allow inline type constraints, and only permit a single constraint. 266 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 267 268 //===--------------------------------------------------------------------===// 269 // Exprs 270 271 FailureOr<ast::Expr *> parseExpr(); 272 273 /// Identifier expressions. 274 FailureOr<ast::Expr *> parseAttributeExpr(); 275 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 276 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 277 FailureOr<ast::Expr *> parseIdentifierExpr(); 278 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 279 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 280 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 281 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 282 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 283 FailureOr<ast::Expr *> parseOperationExpr(); 284 FailureOr<ast::Expr *> parseTupleExpr(); 285 FailureOr<ast::Expr *> parseTypeExpr(); 286 FailureOr<ast::Expr *> parseUnderscoreExpr(); 287 288 //===--------------------------------------------------------------------===// 289 // Stmts 290 291 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 292 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 293 FailureOr<ast::EraseStmt *> parseEraseStmt(); 294 FailureOr<ast::LetStmt *> parseLetStmt(); 295 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 296 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 297 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 298 299 //===--------------------------------------------------------------------===// 300 // Creation+Analysis 301 //===--------------------------------------------------------------------===// 302 303 //===--------------------------------------------------------------------===// 304 // Decls 305 306 /// Try to extract a callable from the given AST node. Returns nullptr on 307 /// failure. 308 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 309 310 /// Try to create a pattern decl with the given components, returning the 311 /// Pattern on success. 312 FailureOr<ast::PatternDecl *> 313 createPatternDecl(SMRange loc, const ast::Name *name, 314 const ParsedPatternMetadata &metadata, 315 ast::CompoundStmt *body); 316 317 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 318 /// of results, defined as part of the signature. 319 ast::Type 320 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 321 322 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 323 template <typename T> 324 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 325 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 326 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 327 ast::CompoundStmt *body); 328 329 /// Try to create a variable decl with the given components, returning the 330 /// Variable on success. 331 FailureOr<ast::VariableDecl *> 332 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 333 ArrayRef<ast::ConstraintRef> constraints); 334 335 /// Create a variable for an argument or result defined as part of the 336 /// signature of a UserConstraintDecl/UserRewriteDecl. 337 FailureOr<ast::VariableDecl *> 338 createArgOrResultVariableDecl(StringRef name, SMRange loc, 339 const ast::ConstraintRef &constraint); 340 341 /// Validate the constraints used to constraint a variable decl. 342 /// `inferredType` is the type of the variable inferred by the constraints 343 /// within the list, and is updated to the most refined type as determined by 344 /// the constraints. Returns success if the constraint list is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult 348 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 349 ast::Type &inferredType, 350 bool allowNonCoreConstraints = true); 351 /// Validate a single reference to a constraint. `inferredType` contains the 352 /// currently inferred variabled type and is refined within the type defined 353 /// by the constraint. Returns success if the constraint is valid, failure 354 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 355 /// defined constraints) may be used with the variable. 356 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 357 ast::Type &inferredType, 358 bool allowNonCoreConstraints = true); 359 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 360 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 361 362 //===--------------------------------------------------------------------===// 363 // Exprs 364 365 FailureOr<ast::CallExpr *> 366 createCallExpr(SMRange loc, ast::Expr *parentExpr, 367 MutableArrayRef<ast::Expr *> arguments); 368 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 369 FailureOr<ast::DeclRefExpr *> 370 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 371 ArrayRef<ast::ConstraintRef> constraints); 372 FailureOr<ast::MemberAccessExpr *> 373 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 374 375 /// Validate the member access `name` into the given parent expression. On 376 /// success, this also returns the type of the member accessed. 377 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 378 StringRef name, SMRange loc); 379 FailureOr<ast::OperationExpr *> 380 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 381 MutableArrayRef<ast::Expr *> operands, 382 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 383 MutableArrayRef<ast::Expr *> results); 384 LogicalResult 385 validateOperationOperands(SMRange loc, Optional<StringRef> name, 386 const ods::Operation *odsOp, 387 MutableArrayRef<ast::Expr *> operands); 388 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 389 const ods::Operation *odsOp, 390 MutableArrayRef<ast::Expr *> results); 391 LogicalResult validateOperationOperandsOrResults( 392 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 393 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 394 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 395 ast::Type rangeTy); 396 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 397 ArrayRef<ast::Expr *> elements, 398 ArrayRef<StringRef> elementNames); 399 400 //===--------------------------------------------------------------------===// 401 // Stmts 402 403 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 404 FailureOr<ast::ReplaceStmt *> 405 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 406 MutableArrayRef<ast::Expr *> replValues); 407 FailureOr<ast::RewriteStmt *> 408 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 409 ast::CompoundStmt *rewriteBody); 410 411 //===--------------------------------------------------------------------===// 412 // Code Completion 413 //===--------------------------------------------------------------------===// 414 415 /// The set of various code completion methods. Every completion method 416 /// returns `failure` to stop the parsing process after providing completion 417 /// results. 418 419 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 420 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 421 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 422 bool allowNonCoreConstraints, 423 bool allowInlineTypeConstraints); 424 LogicalResult codeCompleteDialectName(); 425 LogicalResult codeCompleteOperationName(StringRef dialectName); 426 LogicalResult codeCompletePatternMetadata(); 427 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 428 429 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 430 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 431 unsigned currentNumOperands); 432 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 433 unsigned currentNumResults); 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 /// If the current token has the specified kind, consume it and return true. 440 /// If not, return false. 441 bool consumeIf(Token::Kind kind) { 442 if (curToken.isNot(kind)) 443 return false; 444 consumeToken(kind); 445 return true; 446 } 447 448 /// Advance the current lexer onto the next token. 449 void consumeToken() { 450 assert(curToken.isNot(Token::eof, Token::error) && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 /// Advance the current lexer onto the next token, asserting what the expected 456 /// current token is. This is preferred to the above method because it leads 457 /// to more self-documenting code with better checking. 458 void consumeToken(Token::Kind kind) { 459 assert(curToken.is(kind) && "consumed an unexpected token"); 460 consumeToken(); 461 } 462 463 /// Reset the lexer to the location at the given position. 464 void resetToken(SMRange tokLoc) { 465 lexer.resetPointer(tokLoc.Start.getPointer()); 466 curToken = lexer.lexToken(); 467 } 468 469 /// Consume the specified token if present and return success. On failure, 470 /// output a diagnostic and return failure. 471 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 472 if (curToken.getKind() != kind) 473 return emitError(curToken.getLoc(), msg); 474 consumeToken(); 475 return success(); 476 } 477 LogicalResult emitError(SMRange loc, const Twine &msg) { 478 lexer.emitError(loc, msg); 479 return failure(); 480 } 481 LogicalResult emitError(const Twine &msg) { 482 return emitError(curToken.getLoc(), msg); 483 } 484 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 485 const Twine ¬e) { 486 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 487 return failure(); 488 } 489 490 //===--------------------------------------------------------------------===// 491 // Fields 492 //===--------------------------------------------------------------------===// 493 494 /// The owning AST context. 495 ast::Context &ctx; 496 497 /// The lexer of this parser. 498 Lexer lexer; 499 500 /// The current token within the lexer. 501 Token curToken; 502 503 /// The most recently defined decl scope. 504 ast::DeclScope *curDeclScope = nullptr; 505 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 506 507 /// The current context of the parser. 508 ParserContext parserContext = ParserContext::Global; 509 510 /// Cached types to simplify verification and expression creation. 511 ast::Type valueTy, valueRangeTy; 512 ast::Type typeTy, typeRangeTy; 513 ast::Type attrTy; 514 515 /// A counter used when naming anonymous constraints and rewrites. 516 unsigned anonymousDeclNameCounter = 0; 517 518 /// The optional code completion context. 519 CodeCompleteContext *codeCompleteContext; 520 }; 521 } // namespace 522 523 FailureOr<ast::Module *> Parser::parseModule() { 524 SMLoc moduleLoc = curToken.getStartLoc(); 525 pushDeclScope(); 526 527 // Parse the top-level decls of the module. 528 SmallVector<ast::Decl *> decls; 529 if (failed(parseModuleBody(decls))) 530 return popDeclScope(), failure(); 531 532 popDeclScope(); 533 return ast::Module::create(ctx, moduleLoc, decls); 534 } 535 536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 537 while (curToken.isNot(Token::eof)) { 538 if (curToken.is(Token::directive)) { 539 if (failed(parseDirective(decls))) 540 return failure(); 541 continue; 542 } 543 544 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 545 if (failed(decl)) 546 return failure(); 547 decls.push_back(*decl); 548 } 549 return success(); 550 } 551 552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 553 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 554 valueRangeTy); 555 } 556 557 LogicalResult Parser::convertExpressionTo( 558 ast::Expr *&expr, ast::Type type, 559 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 560 ast::Type exprType = expr->getType(); 561 if (exprType == type) 562 return success(); 563 564 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 565 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 566 expr->getLoc(), llvm::formatv("unable to convert expression of type " 567 "`{0}` to the expected type of " 568 "`{1}`", 569 exprType, type)); 570 if (noteAttachFn) 571 noteAttachFn(*diag); 572 return diag; 573 }; 574 575 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 576 // Two operation types are compatible if they have the same name, or if the 577 // expected type is more general. 578 if (auto opType = type.dyn_cast<ast::OperationType>()) { 579 if (opType.getName()) 580 return emitConvertError(); 581 return success(); 582 } 583 584 // An operation can always convert to a ValueRange. 585 if (type == valueRangeTy) { 586 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 587 valueRangeTy); 588 return success(); 589 } 590 591 // Allow conversion to a single value by constraining the result range. 592 if (type == valueTy) { 593 // If the operation is registered, we can verify if it can ever have a 594 // single result. 595 Optional<StringRef> opName = exprOpType.getName(); 596 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 597 if (odsOp->getResults().empty()) { 598 return emitConvertError()->attachNote( 599 llvm::formatv("see the definition of `{0}`, which was defined " 600 "with zero results", 601 odsOp->getName()), 602 odsOp->getLoc()); 603 } 604 605 unsigned numSingleResults = llvm::count_if( 606 odsOp->getResults(), [](const ods::OperandOrResult &result) { 607 return result.getVariableLengthKind() == 608 ods::VariableLengthKind::Single; 609 }); 610 if (numSingleResults > 1) { 611 return emitConvertError()->attachNote( 612 llvm::formatv("see the definition of `{0}`, which was defined " 613 "with at least {1} results", 614 odsOp->getName(), numSingleResults), 615 odsOp->getLoc()); 616 } 617 } 618 619 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 620 valueTy); 621 return success(); 622 } 623 return emitConvertError(); 624 } 625 626 // FIXME: Decide how to allow/support converting a single result to multiple, 627 // and multiple to a single result. For now, we just allow Single->Range, 628 // but this isn't something really supported in the PDL dialect. We should 629 // figure out some way to support both. 630 if ((exprType == valueTy || exprType == valueRangeTy) && 631 (type == valueTy || type == valueRangeTy)) 632 return success(); 633 if ((exprType == typeTy || exprType == typeRangeTy) && 634 (type == typeTy || type == typeRangeTy)) 635 return success(); 636 637 // Handle tuple types. 638 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 639 auto tupleType = type.dyn_cast<ast::TupleType>(); 640 if (!tupleType || tupleType.size() != exprTupleType.size()) 641 return emitConvertError(); 642 643 // Build a new tuple expression using each of the elements of the current 644 // tuple. 645 SmallVector<ast::Expr *> newExprs; 646 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 647 newExprs.push_back(ast::MemberAccessExpr::create( 648 ctx, expr->getLoc(), expr, llvm::to_string(i), 649 exprTupleType.getElementTypes()[i])); 650 651 auto diagFn = [&](ast::Diagnostic &diag) { 652 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 653 i, exprTupleType)); 654 if (noteAttachFn) 655 noteAttachFn(diag); 656 }; 657 if (failed(convertExpressionTo(newExprs.back(), 658 tupleType.getElementTypes()[i], diagFn))) 659 return failure(); 660 } 661 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 662 tupleType.getElementNames()); 663 return success(); 664 } 665 666 return emitConvertError(); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Directives 671 672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 673 StringRef directive = curToken.getSpelling(); 674 if (directive == "#include") 675 return parseInclude(decls); 676 677 return emitError("unknown directive `" + directive + "`"); 678 } 679 680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 681 SMRange loc = curToken.getLoc(); 682 consumeToken(Token::directive); 683 684 // Handle code completion of the include file path. 685 if (curToken.is(Token::code_complete_string)) 686 return codeCompleteIncludeFilename(curToken.getStringValue()); 687 688 // Parse the file being included. 689 if (!curToken.isString()) 690 return emitError(loc, 691 "expected string file name after `include` directive"); 692 SMRange fileLoc = curToken.getLoc(); 693 std::string filenameStr = curToken.getStringValue(); 694 StringRef filename = filenameStr; 695 consumeToken(); 696 697 // Check the type of include. If ending with `.pdll`, this is another pdl file 698 // to be parsed along with the current module. 699 if (filename.endswith(".pdll")) { 700 if (failed(lexer.pushInclude(filename, fileLoc))) 701 return emitError(fileLoc, 702 "unable to open include file `" + filename + "`"); 703 704 // If we added the include successfully, parse it into the current module. 705 // Make sure to update to the next token after we finish parsing the nested 706 // file. 707 curToken = lexer.lexToken(); 708 LogicalResult result = parseModuleBody(decls); 709 curToken = lexer.lexToken(); 710 return result; 711 } 712 713 // Otherwise, this must be a `.td` include. 714 if (filename.endswith(".td")) 715 return parseTdInclude(filename, fileLoc, decls); 716 717 return emitError(fileLoc, 718 "expected include filename to end with `.pdll` or `.td`"); 719 } 720 721 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 722 SmallVectorImpl<ast::Decl *> &decls) { 723 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 724 725 // This class provides a context argument for the llvm::SourceMgr diagnostic 726 // handler. 727 struct DiagHandlerContext { 728 Parser &parser; 729 StringRef filename; 730 llvm::SMRange loc; 731 } handlerContext{*this, filename, fileLoc}; 732 733 // Set the diagnostic handler for the tablegen source manager. 734 llvm::SrcMgr.setDiagHandler( 735 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 736 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 737 (void)ctx->parser.emitError( 738 ctx->loc, 739 llvm::formatv("error while processing include file `{0}`: {1}", 740 ctx->filename, diag.getMessage())); 741 }, 742 &handlerContext); 743 744 // Use the source manager to open the file, but don't yet add it. 745 std::string includedFile; 746 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 747 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 748 if (!includeBuffer) 749 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 750 751 auto processFn = [&](llvm::RecordKeeper &records) { 752 processTdIncludeRecords(records, decls); 753 754 // After we are done processing, move all of the tablegen source buffers to 755 // the main parser source mgr. This allows for directly using source 756 // locations from the .td files without needing to remap them. 757 parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr, fileLoc.End); 758 return false; 759 }; 760 if (llvm::TableGenParseFile(std::move(*includeBuffer), 761 parserSrcMgr.getIncludeDirs(), processFn)) 762 return failure(); 763 764 return success(); 765 } 766 767 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 768 SmallVectorImpl<ast::Decl *> &decls) { 769 // Return the length kind of the given value. 770 auto getLengthKind = [](const auto &value) { 771 if (value.isOptional()) 772 return ods::VariableLengthKind::Optional; 773 return value.isVariadic() ? ods::VariableLengthKind::Variadic 774 : ods::VariableLengthKind::Single; 775 }; 776 777 // Insert a type constraint into the ODS context. 778 ods::Context &odsContext = ctx.getODSContext(); 779 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 780 -> const ods::TypeConstraint & { 781 return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(), 782 cst.constraint.getSummary(), 783 cst.constraint.getCPPClassName()); 784 }; 785 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 786 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 787 }; 788 789 // Process the parsed tablegen records to build ODS information. 790 /// Operations. 791 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 792 tblgen::Operator op(def); 793 794 bool inserted = false; 795 ods::Operation *odsOp = nullptr; 796 std::tie(odsOp, inserted) = 797 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 798 op.getDescription(), op.getLoc().front()); 799 800 // Ignore operations that have already been added. 801 if (!inserted) 802 continue; 803 804 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 805 odsOp->appendAttribute( 806 attr.name, attr.attr.isOptional(), 807 odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(), 808 attr.attr.getSummary(), 809 attr.attr.getStorageType())); 810 } 811 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 812 odsOp->appendOperand(operand.name, getLengthKind(operand), 813 addTypeConstraint(operand)); 814 } 815 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 816 odsOp->appendResult(result.name, getLengthKind(result), 817 addTypeConstraint(result)); 818 } 819 } 820 /// Attr constraints. 821 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 822 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 823 decls.push_back( 824 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 825 tblgen::AttrConstraint(def), 826 convertLocToRange(def->getLoc().front()), attrTy)); 827 } 828 } 829 /// Type constraints. 830 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 831 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 832 decls.push_back( 833 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 834 tblgen::TypeConstraint(def), 835 convertLocToRange(def->getLoc().front()), typeTy)); 836 } 837 } 838 /// Interfaces. 839 ast::Type opTy = ast::OperationType::get(ctx); 840 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 841 StringRef name = def->getName(); 842 if (def->isAnonymous() || curDeclScope->lookup(name) || 843 def->isSubClassOf("DeclareInterfaceMethods")) 844 continue; 845 SMRange loc = convertLocToRange(def->getLoc().front()); 846 847 StringRef className = def->getValueAsString("cppClassName"); 848 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 849 std::string codeBlock = 850 llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) 851 .str(); 852 853 if (def->isSubClassOf("OpInterface")) { 854 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 855 name, codeBlock, loc, opTy)); 856 } else if (def->isSubClassOf("AttrInterface")) { 857 decls.push_back( 858 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 859 name, codeBlock, loc, attrTy)); 860 } else if (def->isSubClassOf("TypeInterface")) { 861 decls.push_back( 862 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 863 name, codeBlock, loc, typeTy)); 864 } 865 } 866 } 867 868 template <typename ConstraintT> 869 ast::Decl * 870 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 871 SMRange loc, ast::Type type) { 872 // Build the single input parameter. 873 ast::DeclScope *argScope = pushDeclScope(); 874 auto *paramVar = ast::VariableDecl::create( 875 ctx, ast::Name::create(ctx, "self", loc), type, 876 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 877 argScope->add(paramVar); 878 popDeclScope(); 879 880 // Build the native constraint. 881 auto *constraintDecl = ast::UserConstraintDecl::createNative( 882 ctx, ast::Name::create(ctx, name, loc), paramVar, 883 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 884 curDeclScope->add(constraintDecl); 885 return constraintDecl; 886 } 887 888 template <typename ConstraintT> 889 ast::Decl * 890 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 891 SMRange loc, ast::Type type) { 892 // Format the condition template. 893 tblgen::FmtContext fmtContext; 894 fmtContext.withSelf("self"); 895 std::string codeBlock = 896 tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); 897 898 return createODSNativePDLLConstraintDecl<ConstraintT>( 899 constraint.getUniqueDefName(), codeBlock, loc, type); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // Decls 904 905 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 906 FailureOr<ast::Decl *> decl; 907 switch (curToken.getKind()) { 908 case Token::kw_Constraint: 909 decl = parseUserConstraintDecl(); 910 break; 911 case Token::kw_Pattern: 912 decl = parsePatternDecl(); 913 break; 914 case Token::kw_Rewrite: 915 decl = parseUserRewriteDecl(); 916 break; 917 default: 918 return emitError("expected top-level declaration, such as a `Pattern`"); 919 } 920 if (failed(decl)) 921 return failure(); 922 923 // If the decl has a name, add it to the current scope. 924 if (const ast::Name *name = (*decl)->getName()) { 925 if (failed(checkDefineNamedDecl(*name))) 926 return failure(); 927 curDeclScope->add(*decl); 928 } 929 return decl; 930 } 931 932 FailureOr<ast::NamedAttributeDecl *> 933 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 934 // Check for name code completion. 935 if (curToken.is(Token::code_complete)) 936 return codeCompleteAttributeName(parentOpName); 937 938 std::string attrNameStr; 939 if (curToken.isString()) 940 attrNameStr = curToken.getStringValue(); 941 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 942 attrNameStr = curToken.getSpelling().str(); 943 else 944 return emitError("expected identifier or string attribute name"); 945 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 946 consumeToken(); 947 948 // Check for a value of the attribute. 949 ast::Expr *attrValue = nullptr; 950 if (consumeIf(Token::equal)) { 951 FailureOr<ast::Expr *> attrExpr = parseExpr(); 952 if (failed(attrExpr)) 953 return failure(); 954 attrValue = *attrExpr; 955 } else { 956 // If there isn't a concrete value, create an expression representing a 957 // UnitAttr. 958 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 959 } 960 961 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 962 } 963 964 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 965 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 966 bool expectTerminalSemicolon) { 967 consumeToken(Token::equal_arrow); 968 969 // Parse the single statement of the lambda body. 970 SMLoc bodyStartLoc = curToken.getStartLoc(); 971 pushDeclScope(); 972 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 973 bool failedToParse = 974 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 975 popDeclScope(); 976 if (failedToParse) 977 return failure(); 978 979 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 980 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 981 } 982 983 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 984 // Ensure that the argument is named. 985 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 986 return emitError("expected identifier argument name"); 987 988 // Parse the argument similarly to a normal variable. 989 StringRef name = curToken.getSpelling(); 990 SMRange nameLoc = curToken.getLoc(); 991 consumeToken(); 992 993 if (failed( 994 parseToken(Token::colon, "expected `:` before argument constraint"))) 995 return failure(); 996 997 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 998 if (failed(cst)) 999 return failure(); 1000 1001 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1002 } 1003 1004 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1005 // Check to see if this result is named. 1006 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1007 // Check to see if this name actually refers to a Constraint. 1008 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1009 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1010 // If yes, and this is a Rewrite, give a nice error message as non-Core 1011 // constraints are not supported on Rewrite results. 1012 if (parserContext == ParserContext::Rewrite) { 1013 return emitError( 1014 "`Rewrite` results are only permitted to use core constraints, " 1015 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1016 } 1017 1018 // Otherwise, parse this as an unnamed result variable. 1019 } else { 1020 // If it wasn't a constraint, parse the result similarly to a variable. If 1021 // there is already an existing decl, we will emit an error when defining 1022 // this variable later. 1023 StringRef name = curToken.getSpelling(); 1024 SMRange nameLoc = curToken.getLoc(); 1025 consumeToken(); 1026 1027 if (failed(parseToken(Token::colon, 1028 "expected `:` before result constraint"))) 1029 return failure(); 1030 1031 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1032 if (failed(cst)) 1033 return failure(); 1034 1035 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1036 } 1037 } 1038 1039 // If it isn't named, we parse the constraint directly and create an unnamed 1040 // result variable. 1041 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1042 if (failed(cst)) 1043 return failure(); 1044 1045 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1046 } 1047 1048 FailureOr<ast::UserConstraintDecl *> 1049 Parser::parseUserConstraintDecl(bool isInline) { 1050 // Constraints and rewrites have very similar formats, dispatch to a shared 1051 // interface for parsing. 1052 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1053 [&](auto &&...args) { 1054 return this->parseUserPDLLConstraintDecl(args...); 1055 }, 1056 ParserContext::Constraint, "constraint", isInline); 1057 } 1058 1059 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1060 FailureOr<ast::UserConstraintDecl *> decl = 1061 parseUserConstraintDecl(/*isInline=*/true); 1062 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1063 return failure(); 1064 1065 curDeclScope->add(*decl); 1066 return decl; 1067 } 1068 1069 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1070 const ast::Name &name, bool isInline, 1071 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1072 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1073 // Push the argument scope back onto the list, so that the body can 1074 // reference arguments. 1075 pushDeclScope(argumentScope); 1076 1077 // Parse the body of the constraint. The body is either defined as a compound 1078 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1079 ast::CompoundStmt *body; 1080 if (curToken.is(Token::equal_arrow)) { 1081 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1082 [&](ast::Stmt *&stmt) -> LogicalResult { 1083 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1084 if (!stmtExpr) { 1085 return emitError(stmt->getLoc(), 1086 "expected `Constraint` lambda body to contain a " 1087 "single expression"); 1088 } 1089 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1090 return success(); 1091 }, 1092 /*expectTerminalSemicolon=*/!isInline); 1093 if (failed(bodyResult)) 1094 return failure(); 1095 body = *bodyResult; 1096 } else { 1097 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1098 if (failed(bodyResult)) 1099 return failure(); 1100 body = *bodyResult; 1101 1102 // Verify the structure of the body. 1103 auto bodyIt = body->begin(), bodyE = body->end(); 1104 for (; bodyIt != bodyE; ++bodyIt) 1105 if (isa<ast::ReturnStmt>(*bodyIt)) 1106 break; 1107 if (failed(validateUserConstraintOrRewriteReturn( 1108 "Constraint", body, bodyIt, bodyE, results, resultType))) 1109 return failure(); 1110 } 1111 popDeclScope(); 1112 1113 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1114 name, arguments, results, resultType, body); 1115 } 1116 1117 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1118 // Constraints and rewrites have very similar formats, dispatch to a shared 1119 // interface for parsing. 1120 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1121 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1122 ParserContext::Rewrite, "rewrite", isInline); 1123 } 1124 1125 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1126 FailureOr<ast::UserRewriteDecl *> decl = 1127 parseUserRewriteDecl(/*isInline=*/true); 1128 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1129 return failure(); 1130 1131 curDeclScope->add(*decl); 1132 return decl; 1133 } 1134 1135 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1136 const ast::Name &name, bool isInline, 1137 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1138 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1139 // Push the argument scope back onto the list, so that the body can 1140 // reference arguments. 1141 curDeclScope = argumentScope; 1142 ast::CompoundStmt *body; 1143 if (curToken.is(Token::equal_arrow)) { 1144 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1145 [&](ast::Stmt *&statement) -> LogicalResult { 1146 if (isa<ast::OpRewriteStmt>(statement)) 1147 return success(); 1148 1149 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1150 if (!statementExpr) { 1151 return emitError( 1152 statement->getLoc(), 1153 "expected `Rewrite` lambda body to contain a single expression " 1154 "or an operation rewrite statement; such as `erase`, " 1155 "`replace`, or `rewrite`"); 1156 } 1157 statement = 1158 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1159 return success(); 1160 }, 1161 /*expectTerminalSemicolon=*/!isInline); 1162 if (failed(bodyResult)) 1163 return failure(); 1164 body = *bodyResult; 1165 } else { 1166 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1167 if (failed(bodyResult)) 1168 return failure(); 1169 body = *bodyResult; 1170 } 1171 popDeclScope(); 1172 1173 // Verify the structure of the body. 1174 auto bodyIt = body->begin(), bodyE = body->end(); 1175 for (; bodyIt != bodyE; ++bodyIt) 1176 if (isa<ast::ReturnStmt>(*bodyIt)) 1177 break; 1178 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1179 bodyE, results, resultType))) 1180 return failure(); 1181 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1182 name, arguments, results, resultType, body); 1183 } 1184 1185 template <typename T, typename ParseUserPDLLDeclFnT> 1186 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1187 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1188 StringRef anonymousNamePrefix, bool isInline) { 1189 SMRange loc = curToken.getLoc(); 1190 consumeToken(); 1191 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1192 1193 // Parse the name of the decl. 1194 const ast::Name *name = nullptr; 1195 if (curToken.isNot(Token::identifier)) { 1196 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1197 // in C++, so being unnamed is fine. 1198 if (!isInline) 1199 return emitError("expected identifier name"); 1200 1201 // Create a unique anonymous name to use, as the name for this decl is not 1202 // important. 1203 std::string anonName = 1204 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1205 anonymousDeclNameCounter++) 1206 .str(); 1207 name = &ast::Name::create(ctx, anonName, loc); 1208 } else { 1209 // If a name was provided, we can use it directly. 1210 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1211 consumeToken(Token::identifier); 1212 } 1213 1214 // Parse the functional signature of the decl. 1215 SmallVector<ast::VariableDecl *> arguments, results; 1216 ast::DeclScope *argumentScope; 1217 ast::Type resultType; 1218 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1219 argumentScope, resultType))) 1220 return failure(); 1221 1222 // Check to see which type of constraint this is. If the constraint contains a 1223 // compound body, this is a PDLL decl. 1224 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1225 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1226 resultType); 1227 1228 // Otherwise, this is a native decl. 1229 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1230 results, resultType); 1231 } 1232 1233 template <typename T> 1234 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1235 const ast::Name &name, bool isInline, 1236 ArrayRef<ast::VariableDecl *> arguments, 1237 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1238 // If followed by a string, the native code body has also been specified. 1239 std::string codeStrStorage; 1240 Optional<StringRef> optCodeStr; 1241 if (curToken.isString()) { 1242 codeStrStorage = curToken.getStringValue(); 1243 optCodeStr = codeStrStorage; 1244 consumeToken(); 1245 } else if (isInline) { 1246 return emitError(name.getLoc(), 1247 "external declarations must be declared in global scope"); 1248 } else if (curToken.is(Token::error)) { 1249 return failure(); 1250 } 1251 if (failed(parseToken(Token::semicolon, 1252 "expected `;` after native declaration"))) 1253 return failure(); 1254 // TODO: PDL should be able to support constraint results in certain 1255 // situations, we should revise this. 1256 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1257 return emitError( 1258 "native Constraints currently do not support returning results"); 1259 } 1260 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1261 } 1262 1263 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1264 SmallVectorImpl<ast::VariableDecl *> &arguments, 1265 SmallVectorImpl<ast::VariableDecl *> &results, 1266 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1267 // Parse the argument list of the decl. 1268 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1269 return failure(); 1270 1271 argumentScope = pushDeclScope(); 1272 if (curToken.isNot(Token::r_paren)) { 1273 do { 1274 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1275 if (failed(argument)) 1276 return failure(); 1277 arguments.emplace_back(*argument); 1278 } while (consumeIf(Token::comma)); 1279 } 1280 popDeclScope(); 1281 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1282 return failure(); 1283 1284 // Parse the results of the decl. 1285 pushDeclScope(); 1286 if (consumeIf(Token::arrow)) { 1287 auto parseResultFn = [&]() -> LogicalResult { 1288 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1289 if (failed(result)) 1290 return failure(); 1291 results.emplace_back(*result); 1292 return success(); 1293 }; 1294 1295 // Check for a list of results. 1296 if (consumeIf(Token::l_paren)) { 1297 do { 1298 if (failed(parseResultFn())) 1299 return failure(); 1300 } while (consumeIf(Token::comma)); 1301 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1302 return failure(); 1303 1304 // Otherwise, there is only one result. 1305 } else if (failed(parseResultFn())) { 1306 return failure(); 1307 } 1308 } 1309 popDeclScope(); 1310 1311 // Compute the result type of the decl. 1312 resultType = createUserConstraintRewriteResultType(results); 1313 1314 // Verify that results are only named if there are more than one. 1315 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1316 return emitError( 1317 results.front()->getLoc(), 1318 "cannot create a single-element tuple with an element label"); 1319 } 1320 return success(); 1321 } 1322 1323 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1324 StringRef declType, ast::CompoundStmt *body, 1325 ArrayRef<ast::Stmt *>::iterator bodyIt, 1326 ArrayRef<ast::Stmt *>::iterator bodyE, 1327 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1328 // Handle if a `return` was provided. 1329 if (bodyIt != bodyE) { 1330 // Emit an error if we have trailing statements after the return. 1331 if (std::next(bodyIt) != bodyE) { 1332 return emitError( 1333 (*std::next(bodyIt))->getLoc(), 1334 llvm::formatv("`return` terminated the `{0}` body, but found " 1335 "trailing statements afterwards", 1336 declType)); 1337 } 1338 1339 // Otherwise if a return wasn't provided, check that no results are 1340 // expected. 1341 } else if (!results.empty()) { 1342 return emitError( 1343 {body->getLoc().End, body->getLoc().End}, 1344 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1345 declType, resultType)); 1346 } 1347 return success(); 1348 } 1349 1350 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1351 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1352 if (isa<ast::OpRewriteStmt>(statement)) 1353 return success(); 1354 return emitError( 1355 statement->getLoc(), 1356 "expected Pattern lambda body to contain a single operation " 1357 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1358 }); 1359 } 1360 1361 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1362 SMRange loc = curToken.getLoc(); 1363 consumeToken(Token::kw_Pattern); 1364 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1365 ParserContext::PatternMatch); 1366 1367 // Check for an optional identifier for the pattern name. 1368 const ast::Name *name = nullptr; 1369 if (curToken.is(Token::identifier)) { 1370 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1371 consumeToken(Token::identifier); 1372 } 1373 1374 // Parse any pattern metadata. 1375 ParsedPatternMetadata metadata; 1376 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1377 return failure(); 1378 1379 // Parse the pattern body. 1380 ast::CompoundStmt *body; 1381 1382 // Handle a lambda body. 1383 if (curToken.is(Token::equal_arrow)) { 1384 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1385 if (failed(bodyResult)) 1386 return failure(); 1387 body = *bodyResult; 1388 } else { 1389 if (curToken.isNot(Token::l_brace)) 1390 return emitError("expected `{` or `=>` to start pattern body"); 1391 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1392 if (failed(bodyResult)) 1393 return failure(); 1394 body = *bodyResult; 1395 1396 // Verify the body of the pattern. 1397 auto bodyIt = body->begin(), bodyE = body->end(); 1398 for (; bodyIt != bodyE; ++bodyIt) { 1399 if (isa<ast::ReturnStmt>(*bodyIt)) { 1400 return emitError((*bodyIt)->getLoc(), 1401 "`return` statements are only permitted within a " 1402 "`Constraint` or `Rewrite` body"); 1403 } 1404 // Break when we've found the rewrite statement. 1405 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1406 break; 1407 } 1408 if (bodyIt == bodyE) { 1409 return emitError(loc, 1410 "expected Pattern body to terminate with an operation " 1411 "rewrite statement, such as `erase`"); 1412 } 1413 if (std::next(bodyIt) != bodyE) { 1414 return emitError((*std::next(bodyIt))->getLoc(), 1415 "Pattern body was terminated by an operation " 1416 "rewrite statement, but found trailing statements"); 1417 } 1418 } 1419 1420 return createPatternDecl(loc, name, metadata, body); 1421 } 1422 1423 LogicalResult 1424 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1425 Optional<SMRange> benefitLoc; 1426 Optional<SMRange> hasBoundedRecursionLoc; 1427 1428 do { 1429 // Handle metadata code completion. 1430 if (curToken.is(Token::code_complete)) 1431 return codeCompletePatternMetadata(); 1432 1433 if (curToken.isNot(Token::identifier)) 1434 return emitError("expected pattern metadata identifier"); 1435 StringRef metadataStr = curToken.getSpelling(); 1436 SMRange metadataLoc = curToken.getLoc(); 1437 consumeToken(Token::identifier); 1438 1439 // Parse the benefit metadata: benefit(<integer-value>) 1440 if (metadataStr == "benefit") { 1441 if (benefitLoc) { 1442 return emitErrorAndNote(metadataLoc, 1443 "pattern benefit has already been specified", 1444 *benefitLoc, "see previous definition here"); 1445 } 1446 if (failed(parseToken(Token::l_paren, 1447 "expected `(` before pattern benefit"))) 1448 return failure(); 1449 1450 uint16_t benefitValue = 0; 1451 if (curToken.isNot(Token::integer)) 1452 return emitError("expected integral pattern benefit"); 1453 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1454 return emitError( 1455 "expected pattern benefit to fit within a 16-bit integer"); 1456 consumeToken(Token::integer); 1457 1458 metadata.benefit = benefitValue; 1459 benefitLoc = metadataLoc; 1460 1461 if (failed( 1462 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1463 return failure(); 1464 continue; 1465 } 1466 1467 // Parse the bounded recursion metadata: recursion 1468 if (metadataStr == "recursion") { 1469 if (hasBoundedRecursionLoc) { 1470 return emitErrorAndNote( 1471 metadataLoc, 1472 "pattern recursion metadata has already been specified", 1473 *hasBoundedRecursionLoc, "see previous definition here"); 1474 } 1475 metadata.hasBoundedRecursion = true; 1476 hasBoundedRecursionLoc = metadataLoc; 1477 continue; 1478 } 1479 1480 return emitError(metadataLoc, "unknown pattern metadata"); 1481 } while (consumeIf(Token::comma)); 1482 1483 return success(); 1484 } 1485 1486 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1487 consumeToken(Token::less); 1488 1489 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1490 if (failed(typeExpr) || 1491 failed(parseToken(Token::greater, 1492 "expected `>` after variable type constraint"))) 1493 return failure(); 1494 return typeExpr; 1495 } 1496 1497 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1498 assert(curDeclScope && "defining decl outside of a decl scope"); 1499 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1500 return emitErrorAndNote( 1501 name.getLoc(), "`" + name.getName() + "` has already been defined", 1502 lastDecl->getName()->getLoc(), "see previous definition here"); 1503 } 1504 return success(); 1505 } 1506 1507 FailureOr<ast::VariableDecl *> 1508 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1509 ast::Expr *initExpr, 1510 ArrayRef<ast::ConstraintRef> constraints) { 1511 assert(curDeclScope && "defining variable outside of decl scope"); 1512 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1513 1514 // If the name of the variable indicates a special variable, we don't add it 1515 // to the scope. This variable is local to the definition point. 1516 if (name.empty() || name == "_") { 1517 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1518 constraints); 1519 } 1520 if (failed(checkDefineNamedDecl(nameDecl))) 1521 return failure(); 1522 1523 auto *varDecl = 1524 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1525 curDeclScope->add(varDecl); 1526 return varDecl; 1527 } 1528 1529 FailureOr<ast::VariableDecl *> 1530 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1531 ArrayRef<ast::ConstraintRef> constraints) { 1532 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1533 constraints); 1534 } 1535 1536 LogicalResult Parser::parseVariableDeclConstraintList( 1537 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1538 Optional<SMRange> typeConstraint; 1539 auto parseSingleConstraint = [&] { 1540 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1541 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1542 /*allowNonCoreConstraints=*/true); 1543 if (failed(constraint)) 1544 return failure(); 1545 constraints.push_back(*constraint); 1546 return success(); 1547 }; 1548 1549 // Check to see if this is a single constraint, or a list. 1550 if (!consumeIf(Token::l_square)) 1551 return parseSingleConstraint(); 1552 1553 do { 1554 if (failed(parseSingleConstraint())) 1555 return failure(); 1556 } while (consumeIf(Token::comma)); 1557 return parseToken(Token::r_square, "expected `]` after constraint list"); 1558 } 1559 1560 FailureOr<ast::ConstraintRef> 1561 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1562 ArrayRef<ast::ConstraintRef> existingConstraints, 1563 bool allowInlineTypeConstraints, 1564 bool allowNonCoreConstraints) { 1565 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1566 if (!allowInlineTypeConstraints) { 1567 return emitError( 1568 curToken.getLoc(), 1569 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1570 "permitted on arguments or results"); 1571 } 1572 if (typeConstraint) 1573 return emitErrorAndNote( 1574 curToken.getLoc(), 1575 "the type of this variable has already been constrained", 1576 *typeConstraint, "see previous constraint location here"); 1577 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1578 if (failed(constraintExpr)) 1579 return failure(); 1580 typeExpr = *constraintExpr; 1581 typeConstraint = typeExpr->getLoc(); 1582 return success(); 1583 }; 1584 1585 SMRange loc = curToken.getLoc(); 1586 switch (curToken.getKind()) { 1587 case Token::kw_Attr: { 1588 consumeToken(Token::kw_Attr); 1589 1590 // Check for a type constraint. 1591 ast::Expr *typeExpr = nullptr; 1592 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1593 return failure(); 1594 return ast::ConstraintRef( 1595 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1596 } 1597 case Token::kw_Op: { 1598 consumeToken(Token::kw_Op); 1599 1600 // Parse an optional operation name. If the name isn't provided, this refers 1601 // to "any" operation. 1602 FailureOr<ast::OpNameDecl *> opName = 1603 parseWrappedOperationName(/*allowEmptyName=*/true); 1604 if (failed(opName)) 1605 return failure(); 1606 1607 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1608 loc); 1609 } 1610 case Token::kw_Type: 1611 consumeToken(Token::kw_Type); 1612 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1613 case Token::kw_TypeRange: 1614 consumeToken(Token::kw_TypeRange); 1615 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1616 loc); 1617 case Token::kw_Value: { 1618 consumeToken(Token::kw_Value); 1619 1620 // Check for a type constraint. 1621 ast::Expr *typeExpr = nullptr; 1622 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1623 return failure(); 1624 1625 return ast::ConstraintRef( 1626 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1627 } 1628 case Token::kw_ValueRange: { 1629 consumeToken(Token::kw_ValueRange); 1630 1631 // Check for a type constraint. 1632 ast::Expr *typeExpr = nullptr; 1633 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1634 return failure(); 1635 1636 return ast::ConstraintRef( 1637 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1638 } 1639 1640 case Token::kw_Constraint: { 1641 // Handle an inline constraint. 1642 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1643 if (failed(decl)) 1644 return failure(); 1645 return ast::ConstraintRef(*decl, loc); 1646 } 1647 case Token::identifier: { 1648 StringRef constraintName = curToken.getSpelling(); 1649 consumeToken(Token::identifier); 1650 1651 // Lookup the referenced constraint. 1652 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1653 if (!cstDecl) { 1654 return emitError(loc, "unknown reference to constraint `" + 1655 constraintName + "`"); 1656 } 1657 1658 // Handle a reference to a proper constraint. 1659 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1660 return ast::ConstraintRef(cst, loc); 1661 1662 return emitErrorAndNote( 1663 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1664 "see the definition of `" + constraintName + "` here"); 1665 } 1666 // Handle single entity constraint code completion. 1667 case Token::code_complete: { 1668 // Try to infer the current type for use by code completion. 1669 ast::Type inferredType; 1670 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1671 allowNonCoreConstraints))) 1672 return failure(); 1673 1674 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1675 allowInlineTypeConstraints); 1676 } 1677 default: 1678 break; 1679 } 1680 return emitError(loc, "expected identifier constraint"); 1681 } 1682 1683 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1684 // Constraint arguments may apply more complex constraints via the arguments. 1685 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1686 1687 Optional<SMRange> typeConstraint; 1688 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1689 /*allowInlineTypeConstraints=*/false, 1690 allowNonCoreConstraints); 1691 } 1692 1693 //===----------------------------------------------------------------------===// 1694 // Exprs 1695 1696 FailureOr<ast::Expr *> Parser::parseExpr() { 1697 if (curToken.is(Token::underscore)) 1698 return parseUnderscoreExpr(); 1699 1700 // Parse the LHS expression. 1701 FailureOr<ast::Expr *> lhsExpr; 1702 switch (curToken.getKind()) { 1703 case Token::kw_attr: 1704 lhsExpr = parseAttributeExpr(); 1705 break; 1706 case Token::kw_Constraint: 1707 lhsExpr = parseInlineConstraintLambdaExpr(); 1708 break; 1709 case Token::identifier: 1710 lhsExpr = parseIdentifierExpr(); 1711 break; 1712 case Token::kw_op: 1713 lhsExpr = parseOperationExpr(); 1714 break; 1715 case Token::kw_Rewrite: 1716 lhsExpr = parseInlineRewriteLambdaExpr(); 1717 break; 1718 case Token::kw_type: 1719 lhsExpr = parseTypeExpr(); 1720 break; 1721 case Token::l_paren: 1722 lhsExpr = parseTupleExpr(); 1723 break; 1724 default: 1725 return emitError("expected expression"); 1726 } 1727 if (failed(lhsExpr)) 1728 return failure(); 1729 1730 // Check for an operator expression. 1731 while (true) { 1732 switch (curToken.getKind()) { 1733 case Token::dot: 1734 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1735 break; 1736 case Token::l_paren: 1737 lhsExpr = parseCallExpr(*lhsExpr); 1738 break; 1739 default: 1740 return lhsExpr; 1741 } 1742 if (failed(lhsExpr)) 1743 return failure(); 1744 } 1745 } 1746 1747 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1748 SMRange loc = curToken.getLoc(); 1749 consumeToken(Token::kw_attr); 1750 1751 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1752 // identifier. 1753 if (!consumeIf(Token::less)) { 1754 resetToken(loc); 1755 return parseIdentifierExpr(); 1756 } 1757 1758 if (!curToken.isString()) 1759 return emitError("expected string literal containing MLIR attribute"); 1760 std::string attrExpr = curToken.getStringValue(); 1761 consumeToken(); 1762 1763 if (failed( 1764 parseToken(Token::greater, "expected `>` after attribute literal"))) 1765 return failure(); 1766 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1767 } 1768 1769 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1770 SMRange loc = curToken.getLoc(); 1771 consumeToken(Token::l_paren); 1772 1773 // Parse the arguments of the call. 1774 SmallVector<ast::Expr *> arguments; 1775 if (curToken.isNot(Token::r_paren)) { 1776 do { 1777 // Handle code completion for the call arguments. 1778 if (curToken.is(Token::code_complete)) { 1779 codeCompleteCallSignature(parentExpr, arguments.size()); 1780 return failure(); 1781 } 1782 1783 FailureOr<ast::Expr *> argument = parseExpr(); 1784 if (failed(argument)) 1785 return failure(); 1786 arguments.push_back(*argument); 1787 } while (consumeIf(Token::comma)); 1788 } 1789 loc.End = curToken.getEndLoc(); 1790 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1791 return failure(); 1792 1793 return createCallExpr(loc, parentExpr, arguments); 1794 } 1795 1796 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1797 ast::Decl *decl = curDeclScope->lookup(name); 1798 if (!decl) 1799 return emitError(loc, "undefined reference to `" + name + "`"); 1800 1801 return createDeclRefExpr(loc, decl); 1802 } 1803 1804 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1805 StringRef name = curToken.getSpelling(); 1806 SMRange nameLoc = curToken.getLoc(); 1807 consumeToken(); 1808 1809 // Check to see if this is a decl ref expression that defines a variable 1810 // inline. 1811 if (consumeIf(Token::colon)) { 1812 SmallVector<ast::ConstraintRef> constraints; 1813 if (failed(parseVariableDeclConstraintList(constraints))) 1814 return failure(); 1815 ast::Type type; 1816 if (failed(validateVariableConstraints(constraints, type))) 1817 return failure(); 1818 return createInlineVariableExpr(type, name, nameLoc, constraints); 1819 } 1820 1821 return parseDeclRefExpr(name, nameLoc); 1822 } 1823 1824 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1825 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1826 if (failed(decl)) 1827 return failure(); 1828 1829 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1830 ast::ConstraintType::get(ctx)); 1831 } 1832 1833 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1834 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1835 if (failed(decl)) 1836 return failure(); 1837 1838 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1839 ast::RewriteType::get(ctx)); 1840 } 1841 1842 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1843 SMRange loc = curToken.getLoc(); 1844 consumeToken(Token::dot); 1845 1846 // Check for code completion of the member name. 1847 if (curToken.is(Token::code_complete)) 1848 return codeCompleteMemberAccess(parentExpr); 1849 1850 // Parse the member name. 1851 Token memberNameTok = curToken; 1852 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1853 !memberNameTok.isKeyword()) 1854 return emitError(loc, "expected identifier or numeric member name"); 1855 StringRef memberName = memberNameTok.getSpelling(); 1856 consumeToken(); 1857 1858 return createMemberAccessExpr(parentExpr, memberName, loc); 1859 } 1860 1861 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1862 SMRange loc = curToken.getLoc(); 1863 1864 // Check for code completion for the dialect name. 1865 if (curToken.is(Token::code_complete)) 1866 return codeCompleteDialectName(); 1867 1868 // Handle the case of an no operation name. 1869 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1870 if (allowEmptyName) 1871 return ast::OpNameDecl::create(ctx, SMRange()); 1872 return emitError("expected dialect namespace"); 1873 } 1874 StringRef name = curToken.getSpelling(); 1875 consumeToken(); 1876 1877 // Otherwise, this is a literal operation name. 1878 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1879 return failure(); 1880 1881 // Check for code completion for the operation name. 1882 if (curToken.is(Token::code_complete)) 1883 return codeCompleteOperationName(name); 1884 1885 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1886 return emitError("expected operation name after dialect namespace"); 1887 1888 name = StringRef(name.data(), name.size() + 1); 1889 do { 1890 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1891 loc.End = curToken.getEndLoc(); 1892 consumeToken(); 1893 } while (curToken.isAny(Token::identifier, Token::dot) || 1894 curToken.isKeyword()); 1895 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1896 } 1897 1898 FailureOr<ast::OpNameDecl *> 1899 Parser::parseWrappedOperationName(bool allowEmptyName) { 1900 if (!consumeIf(Token::less)) 1901 return ast::OpNameDecl::create(ctx, SMRange()); 1902 1903 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1904 if (failed(opNameDecl)) 1905 return failure(); 1906 1907 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1908 return failure(); 1909 return opNameDecl; 1910 } 1911 1912 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1913 SMRange loc = curToken.getLoc(); 1914 consumeToken(Token::kw_op); 1915 1916 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1917 // identifier. 1918 if (curToken.isNot(Token::less)) { 1919 resetToken(loc); 1920 return parseIdentifierExpr(); 1921 } 1922 1923 // Parse the operation name. The name may be elided, in which case the 1924 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1925 // `Operation*`). Operation names within a rewrite context must be named. 1926 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1927 FailureOr<ast::OpNameDecl *> opNameDecl = 1928 parseWrappedOperationName(allowEmptyName); 1929 if (failed(opNameDecl)) 1930 return failure(); 1931 Optional<StringRef> opName = (*opNameDecl)->getName(); 1932 1933 // Functor used to create an implicit range variable, used for implicit "all" 1934 // operand or results variables. 1935 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1936 FailureOr<ast::VariableDecl *> rangeVar = 1937 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1938 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1939 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1940 }; 1941 1942 // Check for the optional list of operands. 1943 SmallVector<ast::Expr *> operands; 1944 if (!consumeIf(Token::l_paren)) { 1945 // If the operand list isn't specified and we are in a match context, define 1946 // an inplace unconstrained operand range corresponding to all of the 1947 // operands of the operation. This avoids treating zero operands the same 1948 // way as "unconstrained operands". 1949 if (parserContext != ParserContext::Rewrite) { 1950 operands.push_back(createImplicitRangeVar( 1951 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1952 } 1953 } else if (!consumeIf(Token::r_paren)) { 1954 // Check for operand signature code completion. 1955 if (curToken.is(Token::code_complete)) { 1956 codeCompleteOperationOperandsSignature(opName, operands.size()); 1957 return failure(); 1958 } 1959 1960 // If the operand list was specified and non-empty, parse the operands. 1961 do { 1962 FailureOr<ast::Expr *> operand = parseExpr(); 1963 if (failed(operand)) 1964 return failure(); 1965 operands.push_back(*operand); 1966 } while (consumeIf(Token::comma)); 1967 1968 if (failed(parseToken(Token::r_paren, 1969 "expected `)` after operation operand list"))) 1970 return failure(); 1971 } 1972 1973 // Check for the optional list of attributes. 1974 SmallVector<ast::NamedAttributeDecl *> attributes; 1975 if (consumeIf(Token::l_brace)) { 1976 do { 1977 FailureOr<ast::NamedAttributeDecl *> decl = 1978 parseNamedAttributeDecl(opName); 1979 if (failed(decl)) 1980 return failure(); 1981 attributes.emplace_back(*decl); 1982 } while (consumeIf(Token::comma)); 1983 1984 if (failed(parseToken(Token::r_brace, 1985 "expected `}` after operation attribute list"))) 1986 return failure(); 1987 } 1988 1989 // Check for the optional list of result types. 1990 SmallVector<ast::Expr *> resultTypes; 1991 if (consumeIf(Token::arrow)) { 1992 if (failed(parseToken(Token::l_paren, 1993 "expected `(` before operation result type list"))) 1994 return failure(); 1995 1996 // Handle the case of an empty result list. 1997 if (!consumeIf(Token::r_paren)) { 1998 do { 1999 // Check for result signature code completion. 2000 if (curToken.is(Token::code_complete)) { 2001 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2002 return failure(); 2003 } 2004 2005 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2006 if (failed(resultTypeExpr)) 2007 return failure(); 2008 resultTypes.push_back(*resultTypeExpr); 2009 } while (consumeIf(Token::comma)); 2010 2011 if (failed(parseToken(Token::r_paren, 2012 "expected `)` after operation result type list"))) 2013 return failure(); 2014 } 2015 } else if (parserContext != ParserContext::Rewrite) { 2016 // If the result list isn't specified and we are in a match context, define 2017 // an inplace unconstrained result range corresponding to all of the results 2018 // of the operation. This avoids treating zero results the same way as 2019 // "unconstrained results". 2020 resultTypes.push_back(createImplicitRangeVar( 2021 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2022 } 2023 2024 return createOperationExpr(loc, *opNameDecl, operands, attributes, 2025 resultTypes); 2026 } 2027 2028 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2029 SMRange loc = curToken.getLoc(); 2030 consumeToken(Token::l_paren); 2031 2032 DenseMap<StringRef, SMRange> usedNames; 2033 SmallVector<StringRef> elementNames; 2034 SmallVector<ast::Expr *> elements; 2035 if (curToken.isNot(Token::r_paren)) { 2036 do { 2037 // Check for the optional element name assignment before the value. 2038 StringRef elementName; 2039 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2040 Token elementNameTok = curToken; 2041 consumeToken(); 2042 2043 // The element name is only present if followed by an `=`. 2044 if (consumeIf(Token::equal)) { 2045 elementName = elementNameTok.getSpelling(); 2046 2047 // Check to see if this name is already used. 2048 auto elementNameIt = 2049 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2050 if (!elementNameIt.second) { 2051 return emitErrorAndNote( 2052 elementNameTok.getLoc(), 2053 llvm::formatv("duplicate tuple element label `{0}`", 2054 elementName), 2055 elementNameIt.first->getSecond(), 2056 "see previous label use here"); 2057 } 2058 } else { 2059 // Otherwise, we treat this as part of an expression so reset the 2060 // lexer. 2061 resetToken(elementNameTok.getLoc()); 2062 } 2063 } 2064 elementNames.push_back(elementName); 2065 2066 // Parse the tuple element value. 2067 FailureOr<ast::Expr *> element = parseExpr(); 2068 if (failed(element)) 2069 return failure(); 2070 elements.push_back(*element); 2071 } while (consumeIf(Token::comma)); 2072 } 2073 loc.End = curToken.getEndLoc(); 2074 if (failed( 2075 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2076 return failure(); 2077 return createTupleExpr(loc, elements, elementNames); 2078 } 2079 2080 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2081 SMRange loc = curToken.getLoc(); 2082 consumeToken(Token::kw_type); 2083 2084 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2085 // identifier. 2086 if (!consumeIf(Token::less)) { 2087 resetToken(loc); 2088 return parseIdentifierExpr(); 2089 } 2090 2091 if (!curToken.isString()) 2092 return emitError("expected string literal containing MLIR type"); 2093 std::string attrExpr = curToken.getStringValue(); 2094 consumeToken(); 2095 2096 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2097 return failure(); 2098 return ast::TypeExpr::create(ctx, loc, attrExpr); 2099 } 2100 2101 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2102 StringRef name = curToken.getSpelling(); 2103 SMRange nameLoc = curToken.getLoc(); 2104 consumeToken(Token::underscore); 2105 2106 // Underscore expressions require a constraint list. 2107 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2108 return failure(); 2109 2110 // Parse the constraints for the expression. 2111 SmallVector<ast::ConstraintRef> constraints; 2112 if (failed(parseVariableDeclConstraintList(constraints))) 2113 return failure(); 2114 2115 ast::Type type; 2116 if (failed(validateVariableConstraints(constraints, type))) 2117 return failure(); 2118 return createInlineVariableExpr(type, name, nameLoc, constraints); 2119 } 2120 2121 //===----------------------------------------------------------------------===// 2122 // Stmts 2123 2124 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2125 FailureOr<ast::Stmt *> stmt; 2126 switch (curToken.getKind()) { 2127 case Token::kw_erase: 2128 stmt = parseEraseStmt(); 2129 break; 2130 case Token::kw_let: 2131 stmt = parseLetStmt(); 2132 break; 2133 case Token::kw_replace: 2134 stmt = parseReplaceStmt(); 2135 break; 2136 case Token::kw_return: 2137 stmt = parseReturnStmt(); 2138 break; 2139 case Token::kw_rewrite: 2140 stmt = parseRewriteStmt(); 2141 break; 2142 default: 2143 stmt = parseExpr(); 2144 break; 2145 } 2146 if (failed(stmt) || 2147 (expectTerminalSemicolon && 2148 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2149 return failure(); 2150 return stmt; 2151 } 2152 2153 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2154 SMLoc startLoc = curToken.getStartLoc(); 2155 consumeToken(Token::l_brace); 2156 2157 // Push a new block scope and parse any nested statements. 2158 pushDeclScope(); 2159 SmallVector<ast::Stmt *> statements; 2160 while (curToken.isNot(Token::r_brace)) { 2161 FailureOr<ast::Stmt *> statement = parseStmt(); 2162 if (failed(statement)) 2163 return popDeclScope(), failure(); 2164 statements.push_back(*statement); 2165 } 2166 popDeclScope(); 2167 2168 // Consume the end brace. 2169 SMRange location(startLoc, curToken.getEndLoc()); 2170 consumeToken(Token::r_brace); 2171 2172 return ast::CompoundStmt::create(ctx, location, statements); 2173 } 2174 2175 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2176 if (parserContext == ParserContext::Constraint) 2177 return emitError("`erase` cannot be used within a Constraint"); 2178 SMRange loc = curToken.getLoc(); 2179 consumeToken(Token::kw_erase); 2180 2181 // Parse the root operation expression. 2182 FailureOr<ast::Expr *> rootOp = parseExpr(); 2183 if (failed(rootOp)) 2184 return failure(); 2185 2186 return createEraseStmt(loc, *rootOp); 2187 } 2188 2189 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2190 SMRange loc = curToken.getLoc(); 2191 consumeToken(Token::kw_let); 2192 2193 // Parse the name of the new variable. 2194 SMRange varLoc = curToken.getLoc(); 2195 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2196 // `_` is a reserved variable name. 2197 if (curToken.is(Token::underscore)) { 2198 return emitError(varLoc, 2199 "`_` may only be used to define \"inline\" variables"); 2200 } 2201 return emitError(varLoc, 2202 "expected identifier after `let` to name a new variable"); 2203 } 2204 StringRef varName = curToken.getSpelling(); 2205 consumeToken(); 2206 2207 // Parse the optional set of constraints. 2208 SmallVector<ast::ConstraintRef> constraints; 2209 if (consumeIf(Token::colon) && 2210 failed(parseVariableDeclConstraintList(constraints))) 2211 return failure(); 2212 2213 // Parse the optional initializer expression. 2214 ast::Expr *initializer = nullptr; 2215 if (consumeIf(Token::equal)) { 2216 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2217 if (failed(initOrFailure)) 2218 return failure(); 2219 initializer = *initOrFailure; 2220 2221 // Check that the constraints are compatible with having an initializer, 2222 // e.g. type constraints cannot be used with initializers. 2223 for (ast::ConstraintRef constraint : constraints) { 2224 LogicalResult result = 2225 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2226 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2227 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2228 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2229 return this->emitError( 2230 constraint.referenceLoc, 2231 "type constraints are not permitted on variables with " 2232 "initializers"); 2233 } 2234 return success(); 2235 }) 2236 .Default(success()); 2237 if (failed(result)) 2238 return failure(); 2239 } 2240 } 2241 2242 FailureOr<ast::VariableDecl *> varDecl = 2243 createVariableDecl(varName, varLoc, initializer, constraints); 2244 if (failed(varDecl)) 2245 return failure(); 2246 return ast::LetStmt::create(ctx, loc, *varDecl); 2247 } 2248 2249 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2250 if (parserContext == ParserContext::Constraint) 2251 return emitError("`replace` cannot be used within a Constraint"); 2252 SMRange loc = curToken.getLoc(); 2253 consumeToken(Token::kw_replace); 2254 2255 // Parse the root operation expression. 2256 FailureOr<ast::Expr *> rootOp = parseExpr(); 2257 if (failed(rootOp)) 2258 return failure(); 2259 2260 if (failed( 2261 parseToken(Token::kw_with, "expected `with` after root operation"))) 2262 return failure(); 2263 2264 // The replacement portion of this statement is within a rewrite context. 2265 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2266 ParserContext::Rewrite); 2267 2268 // Parse the replacement values. 2269 SmallVector<ast::Expr *> replValues; 2270 if (consumeIf(Token::l_paren)) { 2271 if (consumeIf(Token::r_paren)) { 2272 return emitError( 2273 loc, "expected at least one replacement value, consider using " 2274 "`erase` if no replacement values are desired"); 2275 } 2276 2277 do { 2278 FailureOr<ast::Expr *> replExpr = parseExpr(); 2279 if (failed(replExpr)) 2280 return failure(); 2281 replValues.emplace_back(*replExpr); 2282 } while (consumeIf(Token::comma)); 2283 2284 if (failed(parseToken(Token::r_paren, 2285 "expected `)` after replacement values"))) 2286 return failure(); 2287 } else { 2288 FailureOr<ast::Expr *> replExpr = parseExpr(); 2289 if (failed(replExpr)) 2290 return failure(); 2291 replValues.emplace_back(*replExpr); 2292 } 2293 2294 return createReplaceStmt(loc, *rootOp, replValues); 2295 } 2296 2297 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2298 SMRange loc = curToken.getLoc(); 2299 consumeToken(Token::kw_return); 2300 2301 // Parse the result value. 2302 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2303 if (failed(resultExpr)) 2304 return failure(); 2305 2306 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2307 } 2308 2309 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2310 if (parserContext == ParserContext::Constraint) 2311 return emitError("`rewrite` cannot be used within a Constraint"); 2312 SMRange loc = curToken.getLoc(); 2313 consumeToken(Token::kw_rewrite); 2314 2315 // Parse the root operation. 2316 FailureOr<ast::Expr *> rootOp = parseExpr(); 2317 if (failed(rootOp)) 2318 return failure(); 2319 2320 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2321 return failure(); 2322 2323 if (curToken.isNot(Token::l_brace)) 2324 return emitError("expected `{` to start rewrite body"); 2325 2326 // The rewrite body of this statement is within a rewrite context. 2327 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2328 ParserContext::Rewrite); 2329 2330 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2331 if (failed(rewriteBody)) 2332 return failure(); 2333 2334 // Verify the rewrite body. 2335 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2336 if (isa<ast::ReturnStmt>(stmt)) { 2337 return emitError(stmt->getLoc(), 2338 "`return` statements are only permitted within a " 2339 "`Constraint` or `Rewrite` body"); 2340 } 2341 } 2342 2343 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2344 } 2345 2346 //===----------------------------------------------------------------------===// 2347 // Creation+Analysis 2348 //===----------------------------------------------------------------------===// 2349 2350 //===----------------------------------------------------------------------===// 2351 // Decls 2352 2353 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2354 // Unwrap reference expressions. 2355 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2356 node = init->getDecl(); 2357 return dyn_cast<ast::CallableDecl>(node); 2358 } 2359 2360 FailureOr<ast::PatternDecl *> 2361 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2362 const ParsedPatternMetadata &metadata, 2363 ast::CompoundStmt *body) { 2364 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2365 metadata.hasBoundedRecursion, body); 2366 } 2367 2368 ast::Type Parser::createUserConstraintRewriteResultType( 2369 ArrayRef<ast::VariableDecl *> results) { 2370 // Single result decls use the type of the single result. 2371 if (results.size() == 1) 2372 return results[0]->getType(); 2373 2374 // Multiple results use a tuple type, with the types and names grabbed from 2375 // the result variable decls. 2376 auto resultTypes = llvm::map_range( 2377 results, [&](const auto *result) { return result->getType(); }); 2378 auto resultNames = llvm::map_range( 2379 results, [&](const auto *result) { return result->getName().getName(); }); 2380 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2381 llvm::to_vector(resultNames)); 2382 } 2383 2384 template <typename T> 2385 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2386 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2387 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2388 ast::CompoundStmt *body) { 2389 if (!body->getChildren().empty()) { 2390 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2391 ast::Expr *resultExpr = retStmt->getResultExpr(); 2392 2393 // Process the result of the decl. If no explicit signature results 2394 // were provided, check for return type inference. Otherwise, check that 2395 // the return expression can be converted to the expected type. 2396 if (results.empty()) 2397 resultType = resultExpr->getType(); 2398 else if (failed(convertExpressionTo(resultExpr, resultType))) 2399 return failure(); 2400 else 2401 retStmt->setResultExpr(resultExpr); 2402 } 2403 } 2404 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2405 } 2406 2407 FailureOr<ast::VariableDecl *> 2408 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2409 ArrayRef<ast::ConstraintRef> constraints) { 2410 // The type of the variable, which is expected to be inferred by either a 2411 // constraint or an initializer expression. 2412 ast::Type type; 2413 if (failed(validateVariableConstraints(constraints, type))) 2414 return failure(); 2415 2416 if (initializer) { 2417 // Update the variable type based on the initializer, or try to convert the 2418 // initializer to the existing type. 2419 if (!type) 2420 type = initializer->getType(); 2421 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2422 type = mergedType; 2423 else if (failed(convertExpressionTo(initializer, type))) 2424 return failure(); 2425 2426 // Otherwise, if there is no initializer check that the type has already 2427 // been resolved from the constraint list. 2428 } else if (!type) { 2429 return emitErrorAndNote( 2430 loc, "unable to infer type for variable `" + name + "`", loc, 2431 "the type of a variable must be inferable from the constraint " 2432 "list or the initializer"); 2433 } 2434 2435 // Constraint types cannot be used when defining variables. 2436 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2437 return emitError( 2438 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2439 } 2440 2441 // Try to define a variable with the given name. 2442 FailureOr<ast::VariableDecl *> varDecl = 2443 defineVariableDecl(name, loc, type, initializer, constraints); 2444 if (failed(varDecl)) 2445 return failure(); 2446 2447 return *varDecl; 2448 } 2449 2450 FailureOr<ast::VariableDecl *> 2451 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2452 const ast::ConstraintRef &constraint) { 2453 // Constraint arguments may apply more complex constraints via the arguments. 2454 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2455 ast::Type argType; 2456 if (failed(validateVariableConstraint(constraint, argType, 2457 allowNonCoreConstraints))) 2458 return failure(); 2459 return defineVariableDecl(name, loc, argType, constraint); 2460 } 2461 2462 LogicalResult 2463 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2464 ast::Type &inferredType, 2465 bool allowNonCoreConstraints) { 2466 for (const ast::ConstraintRef &ref : constraints) 2467 if (failed(validateVariableConstraint(ref, inferredType, 2468 allowNonCoreConstraints))) 2469 return failure(); 2470 return success(); 2471 } 2472 2473 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2474 ast::Type &inferredType, 2475 bool allowNonCoreConstraints) { 2476 ast::Type constraintType; 2477 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2478 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2479 if (failed(validateTypeConstraintExpr(typeExpr))) 2480 return failure(); 2481 } 2482 constraintType = ast::AttributeType::get(ctx); 2483 } else if (const auto *cst = 2484 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2485 constraintType = ast::OperationType::get(ctx, cst->getName()); 2486 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2487 constraintType = typeTy; 2488 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2489 constraintType = typeRangeTy; 2490 } else if (const auto *cst = 2491 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2492 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2493 if (failed(validateTypeConstraintExpr(typeExpr))) 2494 return failure(); 2495 } 2496 constraintType = valueTy; 2497 } else if (const auto *cst = 2498 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2499 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2500 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2501 return failure(); 2502 } 2503 constraintType = valueRangeTy; 2504 } else if (const auto *cst = 2505 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2506 if (!allowNonCoreConstraints) { 2507 return emitError(ref.referenceLoc, 2508 "`Rewrite` arguments and results are only permitted to " 2509 "use core constraints, such as `Attr`, `Op`, `Type`, " 2510 "`TypeRange`, `Value`, `ValueRange`"); 2511 } 2512 2513 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2514 if (inputs.size() != 1) { 2515 return emitErrorAndNote(ref.referenceLoc, 2516 "`Constraint`s applied via a variable constraint " 2517 "list must take a single input, but got " + 2518 Twine(inputs.size()), 2519 cst->getLoc(), 2520 "see definition of constraint here"); 2521 } 2522 constraintType = inputs.front()->getType(); 2523 } else { 2524 llvm_unreachable("unknown constraint type"); 2525 } 2526 2527 // Check that the constraint type is compatible with the current inferred 2528 // type. 2529 if (!inferredType) { 2530 inferredType = constraintType; 2531 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2532 inferredType = mergedTy; 2533 } else { 2534 return emitError(ref.referenceLoc, 2535 llvm::formatv("constraint type `{0}` is incompatible " 2536 "with the previously inferred type `{1}`", 2537 constraintType, inferredType)); 2538 } 2539 return success(); 2540 } 2541 2542 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2543 ast::Type typeExprType = typeExpr->getType(); 2544 if (typeExprType != typeTy) { 2545 return emitError(typeExpr->getLoc(), 2546 "expected expression of `Type` in type constraint"); 2547 } 2548 return success(); 2549 } 2550 2551 LogicalResult 2552 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2553 ast::Type typeExprType = typeExpr->getType(); 2554 if (typeExprType != typeRangeTy) { 2555 return emitError(typeExpr->getLoc(), 2556 "expected expression of `TypeRange` in type constraint"); 2557 } 2558 return success(); 2559 } 2560 2561 //===----------------------------------------------------------------------===// 2562 // Exprs 2563 2564 FailureOr<ast::CallExpr *> 2565 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2566 MutableArrayRef<ast::Expr *> arguments) { 2567 ast::Type parentType = parentExpr->getType(); 2568 2569 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2570 if (!callableDecl) { 2571 return emitError(loc, 2572 llvm::formatv("expected a reference to a callable " 2573 "`Constraint` or `Rewrite`, but got: `{0}`", 2574 parentType)); 2575 } 2576 if (parserContext == ParserContext::Rewrite) { 2577 if (isa<ast::UserConstraintDecl>(callableDecl)) 2578 return emitError( 2579 loc, "unable to invoke `Constraint` within a rewrite section"); 2580 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2581 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2582 } 2583 2584 // Verify the arguments of the call. 2585 /// Handle size mismatch. 2586 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2587 if (callArgs.size() != arguments.size()) { 2588 return emitErrorAndNote( 2589 loc, 2590 llvm::formatv("invalid number of arguments for {0} call; expected " 2591 "{1}, but got {2}", 2592 callableDecl->getCallableType(), callArgs.size(), 2593 arguments.size()), 2594 callableDecl->getLoc(), 2595 llvm::formatv("see the definition of {0} here", 2596 callableDecl->getName()->getName())); 2597 } 2598 2599 /// Handle argument type mismatch. 2600 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2601 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2602 callableDecl->getName()->getName()), 2603 callableDecl->getLoc()); 2604 }; 2605 for (auto it : llvm::zip(callArgs, arguments)) { 2606 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2607 attachDiagFn))) 2608 return failure(); 2609 } 2610 2611 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2612 callableDecl->getResultType()); 2613 } 2614 2615 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2616 ast::Decl *decl) { 2617 // Check the type of decl being referenced. 2618 ast::Type declType; 2619 if (isa<ast::ConstraintDecl>(decl)) 2620 declType = ast::ConstraintType::get(ctx); 2621 else if (isa<ast::UserRewriteDecl>(decl)) 2622 declType = ast::RewriteType::get(ctx); 2623 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2624 declType = varDecl->getType(); 2625 else 2626 return emitError(loc, "invalid reference to `" + 2627 decl->getName()->getName() + "`"); 2628 2629 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2630 } 2631 2632 FailureOr<ast::DeclRefExpr *> 2633 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2634 ArrayRef<ast::ConstraintRef> constraints) { 2635 FailureOr<ast::VariableDecl *> decl = 2636 defineVariableDecl(name, loc, type, constraints); 2637 if (failed(decl)) 2638 return failure(); 2639 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2640 } 2641 2642 FailureOr<ast::MemberAccessExpr *> 2643 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2644 SMRange loc) { 2645 // Validate the member name for the given parent expression. 2646 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2647 if (failed(memberType)) 2648 return failure(); 2649 2650 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2651 } 2652 2653 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2654 StringRef name, SMRange loc) { 2655 ast::Type parentType = parentExpr->getType(); 2656 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2657 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2658 return valueRangeTy; 2659 2660 // Verify member access based on the operation type. 2661 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2662 auto results = odsOp->getResults(); 2663 2664 // Handle indexed results. 2665 unsigned index = 0; 2666 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2667 index < results.size()) { 2668 return results[index].isVariadic() ? valueRangeTy : valueTy; 2669 } 2670 2671 // Handle named results. 2672 const auto *it = llvm::find_if(results, [&](const auto &result) { 2673 return result.getName() == name; 2674 }); 2675 if (it != results.end()) 2676 return it->isVariadic() ? valueRangeTy : valueTy; 2677 } 2678 2679 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2680 // Handle indexed results. 2681 unsigned index = 0; 2682 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2683 index < tupleType.size()) { 2684 return tupleType.getElementTypes()[index]; 2685 } 2686 2687 // Handle named results. 2688 auto elementNames = tupleType.getElementNames(); 2689 const auto *it = llvm::find(elementNames, name); 2690 if (it != elementNames.end()) 2691 return tupleType.getElementTypes()[it - elementNames.begin()]; 2692 } 2693 return emitError( 2694 loc, 2695 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2696 name, parentType)); 2697 } 2698 2699 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2700 SMRange loc, const ast::OpNameDecl *name, 2701 MutableArrayRef<ast::Expr *> operands, 2702 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2703 MutableArrayRef<ast::Expr *> results) { 2704 Optional<StringRef> opNameRef = name->getName(); 2705 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2706 2707 // Verify the inputs operands. 2708 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2709 return failure(); 2710 2711 // Verify the attribute list. 2712 for (ast::NamedAttributeDecl *attr : attributes) { 2713 // Check for an attribute type, or a type awaiting resolution. 2714 ast::Type attrType = attr->getValue()->getType(); 2715 if (!attrType.isa<ast::AttributeType>()) { 2716 return emitError( 2717 attr->getValue()->getLoc(), 2718 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2719 } 2720 } 2721 2722 // Verify the result types. 2723 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2724 return failure(); 2725 2726 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2727 attributes); 2728 } 2729 2730 LogicalResult 2731 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2732 const ods::Operation *odsOp, 2733 MutableArrayRef<ast::Expr *> operands) { 2734 return validateOperationOperandsOrResults( 2735 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2736 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2737 valueRangeTy); 2738 } 2739 2740 LogicalResult 2741 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2742 const ods::Operation *odsOp, 2743 MutableArrayRef<ast::Expr *> results) { 2744 return validateOperationOperandsOrResults( 2745 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2746 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2747 } 2748 2749 LogicalResult Parser::validateOperationOperandsOrResults( 2750 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2751 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2752 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2753 ast::Type rangeTy) { 2754 // All operation types accept a single range parameter. 2755 if (values.size() == 1) { 2756 if (failed(convertExpressionTo(values[0], rangeTy))) 2757 return failure(); 2758 return success(); 2759 } 2760 2761 /// If the operation has ODS information, we can more accurately verify the 2762 /// values. 2763 if (odsOpLoc) { 2764 if (odsValues.size() != values.size()) { 2765 return emitErrorAndNote( 2766 loc, 2767 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2768 "{2}, but got {3}", 2769 groupName, *name, odsValues.size(), values.size()), 2770 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2771 } 2772 auto diagFn = [&](ast::Diagnostic &diag) { 2773 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2774 *odsOpLoc); 2775 }; 2776 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2777 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2778 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2779 return failure(); 2780 } 2781 return success(); 2782 } 2783 2784 // Otherwise, accept the value groups as they have been defined and just 2785 // ensure they are one of the expected types. 2786 for (ast::Expr *&valueExpr : values) { 2787 ast::Type valueExprType = valueExpr->getType(); 2788 2789 // Check if this is one of the expected types. 2790 if (valueExprType == rangeTy || valueExprType == singleTy) 2791 continue; 2792 2793 // If the operand is an Operation, allow converting to a Value or 2794 // ValueRange. This situations arises quite often with nested operation 2795 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2796 if (singleTy == valueTy) { 2797 if (valueExprType.isa<ast::OperationType>()) { 2798 valueExpr = convertOpToValue(valueExpr); 2799 continue; 2800 } 2801 } 2802 2803 return emitError( 2804 valueExpr->getLoc(), 2805 llvm::formatv( 2806 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2807 singleTy, rangeTy, valueExprType)); 2808 } 2809 return success(); 2810 } 2811 2812 FailureOr<ast::TupleExpr *> 2813 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2814 ArrayRef<StringRef> elementNames) { 2815 for (const ast::Expr *element : elements) { 2816 ast::Type eleTy = element->getType(); 2817 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2818 return emitError( 2819 element->getLoc(), 2820 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2821 } 2822 } 2823 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2824 } 2825 2826 //===----------------------------------------------------------------------===// 2827 // Stmts 2828 2829 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2830 ast::Expr *rootOp) { 2831 // Check that root is an Operation. 2832 ast::Type rootType = rootOp->getType(); 2833 if (!rootType.isa<ast::OperationType>()) 2834 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2835 2836 return ast::EraseStmt::create(ctx, loc, rootOp); 2837 } 2838 2839 FailureOr<ast::ReplaceStmt *> 2840 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2841 MutableArrayRef<ast::Expr *> replValues) { 2842 // Check that root is an Operation. 2843 ast::Type rootType = rootOp->getType(); 2844 if (!rootType.isa<ast::OperationType>()) { 2845 return emitError( 2846 rootOp->getLoc(), 2847 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2848 } 2849 2850 // If there are multiple replacement values, we implicitly convert any Op 2851 // expressions to the value form. 2852 bool shouldConvertOpToValues = replValues.size() > 1; 2853 for (ast::Expr *&replExpr : replValues) { 2854 ast::Type replType = replExpr->getType(); 2855 2856 // Check that replExpr is an Operation, Value, or ValueRange. 2857 if (replType.isa<ast::OperationType>()) { 2858 if (shouldConvertOpToValues) 2859 replExpr = convertOpToValue(replExpr); 2860 continue; 2861 } 2862 2863 if (replType != valueTy && replType != valueRangeTy) { 2864 return emitError(replExpr->getLoc(), 2865 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2866 "expression, but got `{0}`", 2867 replType)); 2868 } 2869 } 2870 2871 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2872 } 2873 2874 FailureOr<ast::RewriteStmt *> 2875 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2876 ast::CompoundStmt *rewriteBody) { 2877 // Check that root is an Operation. 2878 ast::Type rootType = rootOp->getType(); 2879 if (!rootType.isa<ast::OperationType>()) { 2880 return emitError( 2881 rootOp->getLoc(), 2882 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2883 } 2884 2885 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2886 } 2887 2888 //===----------------------------------------------------------------------===// 2889 // Code Completion 2890 //===----------------------------------------------------------------------===// 2891 2892 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 2893 ast::Type parentType = parentExpr->getType(); 2894 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 2895 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 2896 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 2897 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 2898 return failure(); 2899 } 2900 2901 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 2902 if (opName) 2903 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 2904 return failure(); 2905 } 2906 2907 LogicalResult 2908 Parser::codeCompleteConstraintName(ast::Type inferredType, 2909 bool allowNonCoreConstraints, 2910 bool allowInlineTypeConstraints) { 2911 codeCompleteContext->codeCompleteConstraintName( 2912 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 2913 curDeclScope); 2914 return failure(); 2915 } 2916 2917 LogicalResult Parser::codeCompleteDialectName() { 2918 codeCompleteContext->codeCompleteDialectName(); 2919 return failure(); 2920 } 2921 2922 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 2923 codeCompleteContext->codeCompleteOperationName(dialectName); 2924 return failure(); 2925 } 2926 2927 LogicalResult Parser::codeCompletePatternMetadata() { 2928 codeCompleteContext->codeCompletePatternMetadata(); 2929 return failure(); 2930 } 2931 2932 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 2933 codeCompleteContext->codeCompleteIncludeFilename(curPath); 2934 return failure(); 2935 } 2936 2937 void Parser::codeCompleteCallSignature(ast::Node *parent, 2938 unsigned currentNumArgs) { 2939 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 2940 if (!callableDecl) 2941 return; 2942 2943 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 2944 } 2945 2946 void Parser::codeCompleteOperationOperandsSignature( 2947 Optional<StringRef> opName, unsigned currentNumOperands) { 2948 codeCompleteContext->codeCompleteOperationOperandsSignature( 2949 opName, currentNumOperands); 2950 } 2951 2952 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 2953 unsigned currentNumResults) { 2954 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 2955 currentNumResults); 2956 } 2957 2958 //===----------------------------------------------------------------------===// 2959 // Parser 2960 //===----------------------------------------------------------------------===// 2961 2962 FailureOr<ast::Module *> 2963 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 2964 CodeCompleteContext *codeCompleteContext) { 2965 Parser parser(ctx, sourceMgr, codeCompleteContext); 2966 return parser.parseModule(); 2967 } 2968