1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/IndentedOstream.h" 12 #include "mlir/Support/LogicalResult.h" 13 #include "mlir/TableGen/Argument.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Constraint.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/Tools/PDLL/AST/Context.h" 19 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 20 #include "mlir/Tools/PDLL/AST/Nodes.h" 21 #include "mlir/Tools/PDLL/AST/Types.h" 22 #include "mlir/Tools/PDLL/ODS/Constraint.h" 23 #include "mlir/Tools/PDLL/ODS/Context.h" 24 #include "mlir/Tools/PDLL/ODS/Operation.h" 25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/ManagedStatic.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/ScopedPrinter.h" 32 #include "llvm/TableGen/Error.h" 33 #include "llvm/TableGen/Parser.h" 34 #include <string> 35 36 using namespace mlir; 37 using namespace mlir::pdll; 38 39 //===----------------------------------------------------------------------===// 40 // Parser 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 class Parser { 45 public: 46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext) 48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 49 curToken(lexer.lexToken()), 
enableDocumentation(enableDocumentation), 50 valueTy(ast::ValueType::get(ctx)), 51 valueRangeTy(ast::ValueRangeType::get(ctx)), 52 typeTy(ast::TypeType::get(ctx)), 53 typeRangeTy(ast::TypeRangeType::get(ctx)), 54 attrTy(ast::AttributeType::get(ctx)), 55 codeCompleteContext(codeCompleteContext) {} 56 57 /// Try to parse a new module. Returns nullptr in the case of failure. 58 FailureOr<ast::Module *> parseModule(); 59 60 private: 61 /// The current context of the parser. It allows for the parser to know a bit 62 /// about the construct it is nested within during parsing. This is used 63 /// specifically to provide additional verification during parsing, e.g. to 64 /// prevent using rewrites within a match context, matcher constraints within 65 /// a rewrite section, etc. 66 enum class ParserContext { 67 /// The parser is in the global context. 68 Global, 69 /// The parser is currently within a Constraint, which disallows all types 70 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 71 Constraint, 72 /// The parser is currently within the matcher portion of a Pattern, which 73 /// is allows a terminal operation rewrite statement but no other rewrite 74 /// transformations. 75 PatternMatch, 76 /// The parser is currently within a Rewrite, which disallows calls to 77 /// constraints, requires operation expressions to have names, etc. 78 Rewrite, 79 }; 80 81 /// The current specification context of an operations result type. This 82 /// indicates how the result types of an operation may be inferred. 83 enum class OpResultTypeContext { 84 /// The result types of the operation are not known to be inferred. 85 Explicit, 86 /// The result types of the operation are inferred from the root input of a 87 /// `replace` statement. 88 Replacement, 89 /// The result types of the operation are inferred by using the 90 /// `InferTypeOpInterface` interface provided by the operation. 
91 Interface, 92 }; 93 94 //===--------------------------------------------------------------------===// 95 // Parsing 96 //===--------------------------------------------------------------------===// 97 98 /// Push a new decl scope onto the lexer. 99 ast::DeclScope *pushDeclScope() { 100 ast::DeclScope *newScope = 101 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 102 return (curDeclScope = newScope); 103 } 104 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 105 106 /// Pop the last decl scope from the lexer. 107 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 108 109 /// Parse the body of an AST module. 110 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 111 112 /// Try to convert the given expression to `type`. Returns failure and emits 113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 114 /// invoked to attach notes to the emitted error diagnostic. On success, 115 /// `expr` is updated to the expression used to convert to `type`. 116 LogicalResult convertExpressionTo( 117 ast::Expr *&expr, ast::Type type, 118 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 119 120 /// Given an operation expression, convert it to a Value or ValueRange 121 /// typed expression. 122 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 123 124 /// Lookup ODS information for the given operation, returns nullptr if no 125 /// information is found. 126 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 127 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 128 } 129 130 /// Process the given documentation string, or return an empty string if 131 /// documentation isn't enabled. 132 StringRef processDoc(StringRef doc) { 133 return enableDocumentation ? doc : StringRef(); 134 } 135 136 /// Process the given documentation string and format it, or return an empty 137 /// string if documentation isn't enabled. 
138 std::string processAndFormatDoc(const Twine &doc) { 139 if (!enableDocumentation) 140 return ""; 141 std::string docStr; 142 { 143 llvm::raw_string_ostream docOS(docStr); 144 std::string tmpDocStr = doc.str(); 145 raw_indented_ostream(docOS).printReindented( 146 StringRef(tmpDocStr).rtrim(" \t")); 147 } 148 return docStr; 149 } 150 151 //===--------------------------------------------------------------------===// 152 // Directives 153 154 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 155 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 156 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 157 SmallVectorImpl<ast::Decl *> &decls); 158 159 /// Process the records of a parsed tablegen include file. 160 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 161 SmallVectorImpl<ast::Decl *> &decls); 162 163 /// Create a user defined native constraint for a constraint imported from 164 /// ODS. 165 template <typename ConstraintT> 166 ast::Decl * 167 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 168 SMRange loc, ast::Type type, 169 StringRef nativeType, StringRef docString); 170 template <typename ConstraintT> 171 ast::Decl * 172 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 173 SMRange loc, ast::Type type, 174 StringRef nativeType); 175 176 //===--------------------------------------------------------------------===// 177 // Decls 178 179 /// This structure contains the set of pattern metadata that may be parsed. 180 struct ParsedPatternMetadata { 181 Optional<uint16_t> benefit; 182 bool hasBoundedRecursion = false; 183 }; 184 185 FailureOr<ast::Decl *> parseTopLevelDecl(); 186 FailureOr<ast::NamedAttributeDecl *> 187 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 188 189 /// Parse an argument variable as part of the signature of a 190 /// UserConstraintDecl or UserRewriteDecl. 
191 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 192 193 /// Parse a result variable as part of the signature of a UserConstraintDecl 194 /// or UserRewriteDecl. 195 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 196 197 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 198 /// defined in a non-global context. 199 FailureOr<ast::UserConstraintDecl *> 200 parseUserConstraintDecl(bool isInline = false); 201 202 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 203 /// non-global context, such as within a Pattern/Constraint/etc. 204 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 205 206 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 207 /// PDLL constructs. 208 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 209 const ast::Name &name, bool isInline, 210 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 211 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 212 213 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 214 /// defined in a non-global context. 215 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 216 217 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 218 /// non-global context, such as within a Pattern/Rewrite/etc. 219 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 220 221 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 222 /// PDLL constructs. 223 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 224 const ast::Name &name, bool isInline, 225 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 226 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 227 228 /// Parse either a UserConstraintDecl or UserRewriteDecl. 
These decls have 229 /// effectively the same syntax, and only differ on slight semantics (given 230 /// the different parsing contexts). 231 template <typename T, typename ParseUserPDLLDeclFnT> 232 FailureOr<T *> parseUserConstraintOrRewriteDecl( 233 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 234 StringRef anonymousNamePrefix, bool isInline); 235 236 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 237 /// These decls have effectively the same syntax. 238 template <typename T> 239 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 240 const ast::Name &name, bool isInline, 241 ArrayRef<ast::VariableDecl *> arguments, 242 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 243 244 /// Parse the functional signature (i.e. the arguments and results) of a 245 /// UserConstraintDecl or UserRewriteDecl. 246 LogicalResult parseUserConstraintOrRewriteSignature( 247 SmallVectorImpl<ast::VariableDecl *> &arguments, 248 SmallVectorImpl<ast::VariableDecl *> &results, 249 ast::DeclScope *&argumentScope, ast::Type &resultType); 250 251 /// Validate the return (which if present is specified by bodyIt) of a 252 /// UserConstraintDecl or UserRewriteDecl. 253 LogicalResult validateUserConstraintOrRewriteReturn( 254 StringRef declType, ast::CompoundStmt *body, 255 ArrayRef<ast::Stmt *>::iterator bodyIt, 256 ArrayRef<ast::Stmt *>::iterator bodyE, 257 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 258 259 FailureOr<ast::CompoundStmt *> 260 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 261 bool expectTerminalSemicolon = true); 262 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 263 FailureOr<ast::Decl *> parsePatternDecl(); 264 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 265 266 /// Check to see if a decl has already been defined with the given name, if 267 /// one has emit and error and return failure. Returns success otherwise. 
268 LogicalResult checkDefineNamedDecl(const ast::Name &name); 269 270 /// Try to define a variable decl with the given components, returns the 271 /// variable on success. 272 FailureOr<ast::VariableDecl *> 273 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 274 ast::Expr *initExpr, 275 ArrayRef<ast::ConstraintRef> constraints); 276 FailureOr<ast::VariableDecl *> 277 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 278 ArrayRef<ast::ConstraintRef> constraints); 279 280 /// Parse the constraint reference list for a variable decl. 281 LogicalResult parseVariableDeclConstraintList( 282 SmallVectorImpl<ast::ConstraintRef> &constraints); 283 284 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 285 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 286 287 /// Try to parse a single reference to a constraint. `typeConstraint` is the 288 /// location of a previously parsed type constraint for the entity that will 289 /// be constrained by the parsed constraint. `existingConstraints` are any 290 /// existing constraints that have already been parsed for the same entity 291 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 292 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 293 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 294 /// constraints) may be used with the variable. 295 FailureOr<ast::ConstraintRef> 296 parseConstraint(Optional<SMRange> &typeConstraint, 297 ArrayRef<ast::ConstraintRef> existingConstraints, 298 bool allowInlineTypeConstraints, 299 bool allowNonCoreConstraints); 300 301 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 302 /// argument or result variable. The constraints for these variables do not 303 /// allow inline type constraints, and only permit a single constraint. 
304 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 305 306 //===--------------------------------------------------------------------===// 307 // Exprs 308 309 FailureOr<ast::Expr *> parseExpr(); 310 311 /// Identifier expressions. 312 FailureOr<ast::Expr *> parseAttributeExpr(); 313 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 314 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 315 FailureOr<ast::Expr *> parseIdentifierExpr(); 316 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 317 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 318 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 319 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 320 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 321 FailureOr<ast::Expr *> 322 parseOperationExpr(OpResultTypeContext inputResultTypeContext = 323 OpResultTypeContext::Explicit); 324 FailureOr<ast::Expr *> parseTupleExpr(); 325 FailureOr<ast::Expr *> parseTypeExpr(); 326 FailureOr<ast::Expr *> parseUnderscoreExpr(); 327 328 //===--------------------------------------------------------------------===// 329 // Stmts 330 331 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 332 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 333 FailureOr<ast::EraseStmt *> parseEraseStmt(); 334 FailureOr<ast::LetStmt *> parseLetStmt(); 335 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 336 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 337 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 338 339 //===--------------------------------------------------------------------===// 340 // Creation+Analysis 341 //===--------------------------------------------------------------------===// 342 343 //===--------------------------------------------------------------------===// 344 // Decls 345 346 /// Try to extract a callable from the given AST node. Returns nullptr on 347 /// failure. 
348 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 349 350 /// Try to create a pattern decl with the given components, returning the 351 /// Pattern on success. 352 FailureOr<ast::PatternDecl *> 353 createPatternDecl(SMRange loc, const ast::Name *name, 354 const ParsedPatternMetadata &metadata, 355 ast::CompoundStmt *body); 356 357 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 358 /// of results, defined as part of the signature. 359 ast::Type 360 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 361 362 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 363 template <typename T> 364 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 365 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 366 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 367 ast::CompoundStmt *body); 368 369 /// Try to create a variable decl with the given components, returning the 370 /// Variable on success. 371 FailureOr<ast::VariableDecl *> 372 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 373 ArrayRef<ast::ConstraintRef> constraints); 374 375 /// Create a variable for an argument or result defined as part of the 376 /// signature of a UserConstraintDecl/UserRewriteDecl. 377 FailureOr<ast::VariableDecl *> 378 createArgOrResultVariableDecl(StringRef name, SMRange loc, 379 const ast::ConstraintRef &constraint); 380 381 /// Validate the constraints used to constraint a variable decl. 382 /// `inferredType` is the type of the variable inferred by the constraints 383 /// within the list, and is updated to the most refined type as determined by 384 /// the constraints. Returns success if the constraint list is valid, failure 385 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 386 /// defined constraints) may be used with the variable. 
387 LogicalResult 388 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 389 ast::Type &inferredType, 390 bool allowNonCoreConstraints = true); 391 /// Validate a single reference to a constraint. `inferredType` contains the 392 /// currently inferred variabled type and is refined within the type defined 393 /// by the constraint. Returns success if the constraint is valid, failure 394 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 395 /// defined constraints) may be used with the variable. 396 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 397 ast::Type &inferredType, 398 bool allowNonCoreConstraints = true); 399 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 400 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 401 402 //===--------------------------------------------------------------------===// 403 // Exprs 404 405 FailureOr<ast::CallExpr *> 406 createCallExpr(SMRange loc, ast::Expr *parentExpr, 407 MutableArrayRef<ast::Expr *> arguments); 408 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 409 FailureOr<ast::DeclRefExpr *> 410 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 411 ArrayRef<ast::ConstraintRef> constraints); 412 FailureOr<ast::MemberAccessExpr *> 413 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 414 415 /// Validate the member access `name` into the given parent expression. On 416 /// success, this also returns the type of the member accessed. 
417 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 418 StringRef name, SMRange loc); 419 FailureOr<ast::OperationExpr *> 420 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 421 OpResultTypeContext resultTypeContext, 422 MutableArrayRef<ast::Expr *> operands, 423 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 424 MutableArrayRef<ast::Expr *> results); 425 LogicalResult 426 validateOperationOperands(SMRange loc, Optional<StringRef> name, 427 const ods::Operation *odsOp, 428 MutableArrayRef<ast::Expr *> operands); 429 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 430 const ods::Operation *odsOp, 431 MutableArrayRef<ast::Expr *> results); 432 void checkOperationResultTypeInferrence(SMRange loc, StringRef name, 433 const ods::Operation *odsOp); 434 LogicalResult validateOperationOperandsOrResults( 435 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 436 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 437 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 438 ast::Type rangeTy); 439 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 440 ArrayRef<ast::Expr *> elements, 441 ArrayRef<StringRef> elementNames); 442 443 //===--------------------------------------------------------------------===// 444 // Stmts 445 446 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 447 FailureOr<ast::ReplaceStmt *> 448 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 449 MutableArrayRef<ast::Expr *> replValues); 450 FailureOr<ast::RewriteStmt *> 451 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 452 ast::CompoundStmt *rewriteBody); 453 454 //===--------------------------------------------------------------------===// 455 // Code Completion 456 //===--------------------------------------------------------------------===// 457 458 /// The set of various code completion methods. 
Every completion method 459 /// returns `failure` to stop the parsing process after providing completion 460 /// results. 461 462 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 463 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 464 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 465 bool allowNonCoreConstraints, 466 bool allowInlineTypeConstraints); 467 LogicalResult codeCompleteDialectName(); 468 LogicalResult codeCompleteOperationName(StringRef dialectName); 469 LogicalResult codeCompletePatternMetadata(); 470 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 471 472 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 473 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 474 unsigned currentNumOperands); 475 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 476 unsigned currentNumResults); 477 478 //===--------------------------------------------------------------------===// 479 // Lexer Utilities 480 //===--------------------------------------------------------------------===// 481 482 /// If the current token has the specified kind, consume it and return true. 483 /// If not, return false. 484 bool consumeIf(Token::Kind kind) { 485 if (curToken.isNot(kind)) 486 return false; 487 consumeToken(kind); 488 return true; 489 } 490 491 /// Advance the current lexer onto the next token. 492 void consumeToken() { 493 assert(curToken.isNot(Token::eof, Token::error) && 494 "shouldn't advance past EOF or errors"); 495 curToken = lexer.lexToken(); 496 } 497 498 /// Advance the current lexer onto the next token, asserting what the expected 499 /// current token is. This is preferred to the above method because it leads 500 /// to more self-documenting code with better checking. 
501 void consumeToken(Token::Kind kind) { 502 assert(curToken.is(kind) && "consumed an unexpected token"); 503 consumeToken(); 504 } 505 506 /// Reset the lexer to the location at the given position. 507 void resetToken(SMRange tokLoc) { 508 lexer.resetPointer(tokLoc.Start.getPointer()); 509 curToken = lexer.lexToken(); 510 } 511 512 /// Consume the specified token if present and return success. On failure, 513 /// output a diagnostic and return failure. 514 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 515 if (curToken.getKind() != kind) 516 return emitError(curToken.getLoc(), msg); 517 consumeToken(); 518 return success(); 519 } 520 LogicalResult emitError(SMRange loc, const Twine &msg) { 521 lexer.emitError(loc, msg); 522 return failure(); 523 } 524 LogicalResult emitError(const Twine &msg) { 525 return emitError(curToken.getLoc(), msg); 526 } 527 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 528 const Twine ¬e) { 529 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 530 return failure(); 531 } 532 533 //===--------------------------------------------------------------------===// 534 // Fields 535 //===--------------------------------------------------------------------===// 536 537 /// The owning AST context. 538 ast::Context &ctx; 539 540 /// The lexer of this parser. 541 Lexer lexer; 542 543 /// The current token within the lexer. 544 Token curToken; 545 546 /// A flag indicating if the parser should add documentation to AST nodes when 547 /// viable. 548 bool enableDocumentation; 549 550 /// The most recently defined decl scope. 551 ast::DeclScope *curDeclScope = nullptr; 552 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 553 554 /// The current context of the parser. 555 ParserContext parserContext = ParserContext::Global; 556 557 /// Cached types to simplify verification and expression creation. 
558 ast::Type valueTy, valueRangeTy; 559 ast::Type typeTy, typeRangeTy; 560 ast::Type attrTy; 561 562 /// A counter used when naming anonymous constraints and rewrites. 563 unsigned anonymousDeclNameCounter = 0; 564 565 /// The optional code completion context. 566 CodeCompleteContext *codeCompleteContext; 567 }; 568 } // namespace 569 570 FailureOr<ast::Module *> Parser::parseModule() { 571 SMLoc moduleLoc = curToken.getStartLoc(); 572 pushDeclScope(); 573 574 // Parse the top-level decls of the module. 575 SmallVector<ast::Decl *> decls; 576 if (failed(parseModuleBody(decls))) 577 return popDeclScope(), failure(); 578 579 popDeclScope(); 580 return ast::Module::create(ctx, moduleLoc, decls); 581 } 582 583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 584 while (curToken.isNot(Token::eof)) { 585 if (curToken.is(Token::directive)) { 586 if (failed(parseDirective(decls))) 587 return failure(); 588 continue; 589 } 590 591 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 592 if (failed(decl)) 593 return failure(); 594 decls.push_back(*decl); 595 } 596 return success(); 597 } 598 599 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 600 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 601 valueRangeTy); 602 } 603 604 LogicalResult Parser::convertExpressionTo( 605 ast::Expr *&expr, ast::Type type, 606 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 607 ast::Type exprType = expr->getType(); 608 if (exprType == type) 609 return success(); 610 611 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 612 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 613 expr->getLoc(), llvm::formatv("unable to convert expression of type " 614 "`{0}` to the expected type of " 615 "`{1}`", 616 exprType, type)); 617 if (noteAttachFn) 618 noteAttachFn(*diag); 619 return diag; 620 }; 621 622 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 623 // Two operation types are 
compatible if they have the same name, or if the 624 // expected type is more general. 625 if (auto opType = type.dyn_cast<ast::OperationType>()) { 626 if (opType.getName()) 627 return emitConvertError(); 628 return success(); 629 } 630 631 // An operation can always convert to a ValueRange. 632 if (type == valueRangeTy) { 633 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 634 valueRangeTy); 635 return success(); 636 } 637 638 // Allow conversion to a single value by constraining the result range. 639 if (type == valueTy) { 640 // If the operation is registered, we can verify if it can ever have a 641 // single result. 642 if (const ods::Operation *odsOp = exprOpType.getODSOperation()) { 643 if (odsOp->getResults().empty()) { 644 return emitConvertError()->attachNote( 645 llvm::formatv("see the definition of `{0}`, which was defined " 646 "with zero results", 647 odsOp->getName()), 648 odsOp->getLoc()); 649 } 650 651 unsigned numSingleResults = llvm::count_if( 652 odsOp->getResults(), [](const ods::OperandOrResult &result) { 653 return result.getVariableLengthKind() == 654 ods::VariableLengthKind::Single; 655 }); 656 if (numSingleResults > 1) { 657 return emitConvertError()->attachNote( 658 llvm::formatv("see the definition of `{0}`, which was defined " 659 "with at least {1} results", 660 odsOp->getName(), numSingleResults), 661 odsOp->getLoc()); 662 } 663 } 664 665 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 666 valueTy); 667 return success(); 668 } 669 return emitConvertError(); 670 } 671 672 // FIXME: Decide how to allow/support converting a single result to multiple, 673 // and multiple to a single result. For now, we just allow Single->Range, 674 // but this isn't something really supported in the PDL dialect. We should 675 // figure out some way to support both. 
  // Value and ValueRange expressions are accepted as-is when converting to
  // Value or ValueRange; likewise for Type and TypeRange. No new expression is
  // built for these cases.
  if ((exprType == valueTy || exprType == valueRangeTy) &&
      (type == valueTy || type == valueRangeTy))
    return success();
  if ((exprType == typeTy || exprType == typeRangeTy) &&
      (type == typeTy || type == typeRangeTy))
    return success();

  // Handle tuple types. A tuple only converts to another tuple with the same
  // number of elements, and the conversion is performed element-wise.
  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
    auto tupleType = type.dyn_cast<ast::TupleType>();
    if (!tupleType || tupleType.size() != exprTupleType.size())
      return emitConvertError();

    // Build a new tuple expression using each of the elements of the current
    // tuple. Each element is extracted with a MemberAccessExpr keyed by its
    // index and then converted to the corresponding target element type.
    SmallVector<ast::Expr *> newExprs;
    for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
      newExprs.push_back(ast::MemberAccessExpr::create(
          ctx, expr->getLoc(), expr, llvm::to_string(i),
          exprTupleType.getElementTypes()[i]));

      // On failure, attach a note naming the offending element, then chain to
      // any caller-provided note callback.
      auto diagFn = [&](ast::Diagnostic &diag) {
        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                      i, exprTupleType));
        if (noteAttachFn)
          noteAttachFn(diag);
      };
      if (failed(convertExpressionTo(newExprs.back(),
                                     tupleType.getElementTypes()[i], diagFn)))
        return failure();
    }
    // Reassemble the converted elements, adopting the target tuple's element
    // names.
    expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
                                  tupleType.getElementNames());
    return success();
  }

  return emitConvertError();
}

//===----------------------------------------------------------------------===//
// Directives

/// Parse a directive at the current token. Only `#include` is recognized;
/// any other directive spelling produces an "unknown directive" error.
LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
  StringRef directive = curToken.getSpelling();
  if (directive == "#include")
    return parseInclude(decls);

  return emitError("unknown directive `" + directive + "`");
}

/// Parse an `#include "<file>"` directive, appending any resulting decls to
/// `decls`.
///  * `.pdll` files are pushed onto the lexer's include stack and their body is
///    parsed directly into `decls`.
///  * `.td` files are handed to parseTdInclude.
///  * Any other extension is an error.
LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::directive);

  // Handle code completion of the include file path.
  if (curToken.is(Token::code_complete_string))
    return codeCompleteIncludeFilename(curToken.getStringValue());

  // Parse the file being included.
  if (!curToken.isString())
    return emitError(loc,
                     "expected string file name after `include` directive");
  SMRange fileLoc = curToken.getLoc();
  // Keep an owned copy of the file name; `filename` is only a view into it.
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Check the type of include. If ending with `.pdll`, this is another pdl file
  // to be parsed along with the current module.
  if (filename.endswith(".pdll")) {
    if (failed(lexer.pushInclude(filename, fileLoc)))
      return emitError(fileLoc,
                       "unable to open include file `" + filename + "`");

    // If we added the include successfully, parse it into the current module.
    // Make sure to update to the next token after we finish parsing the nested
    // file.
    curToken = lexer.lexToken();
    LogicalResult result = parseModuleBody(decls);
    curToken = lexer.lexToken();
    return result;
  }

  // Otherwise, this must be a `.td` include.
  if (filename.endswith(".td"))
    return parseTdInclude(filename, fileLoc, decls);

  return emitError(fileLoc,
                   "expected include filename to end with `.pdll` or `.td`");
}

/// Parse a TableGen `.td` include.
///
/// The file is opened through the parser's source manager (so its include
/// directories apply) but parsed with a separate, local llvm::SourceMgr.
/// TableGen diagnostics are re-reported as PDLL errors anchored at `fileLoc`.
/// The parsed records are converted into ODS information and PDLL decls by
/// processTdIncludeRecords, then all TableGen buffers are moved into the
/// parser's source manager.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the source manager to open the file, but don't yet add it.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup the source manager for parsing the tablegen file.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler. It lives on the stack; this is fine because the handler is only
  // installed on the local `tdSrcMgr`, which does not outlive this function.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager. Every TableGen
  // diagnostic is reported as a PDLL error at the include directive location.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file. Errors have already been reported through the
  // diagnostic handler above.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}

/// Convert parsed TableGen records into ODS context entries and PDLL decls:
///  * `Op` records are registered as ods::Operations, including their
///    attributes, operands and results, unless already registered.
///  * Named `Attr` and `Type` records whose names are not already found in the
///    current scope become native attribute/type constraint decls.
///  * Named `Interface` records (excluding `DeclareInterfaceMethods`
///    subclasses) whose names are not already found in the current scope become
///    native Op/Attr/Type constraint decls that check
///    `llvm::isa<Interface>(self)`.
/// Newly created decls are appended to `decls`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(
        cst.constraint.getUniqueDefName(),
        processDoc(cst.constraint.getSummary()),
        cst.constraint.getCPPClassName());
  };
  // TableGen records only carry a point location; widen it to a
  // one-character range.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support type inference.
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) = odsContext.insertOperation(
        op.getOperationName(), processDoc(op.getSummary()),
        processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
        supportsResultTypeInferrence, op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
                             odsContext.insertAttributeConstraint(
                                 attr.attr.getUniqueDefName(),
                                 processDoc(attr.attr.getSummary()),
                                 attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  /// Attr constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      tblgen::Attribute constraint(def);
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              constraint, convertLocToRange(def->getLoc().front()), attrTy,
              constraint.getStorageType()));
    }
  }
  /// Type constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      tblgen::TypeConstraint constraint(def);
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              constraint, convertLocToRange(def->getLoc().front()), typeTy,
              constraint.getCPPClassName()));
    }
  }
  /// Interfaces.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    // The fully qualified C++ interface class, and a native body that checks
    // whether `self` implements it.
    std::string cppClassName =
        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
                      def->getValueAsString("cppClassName"))
            .str();
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
                      cppClassName)
            .str();

    std::string desc =
        processAndFormatDoc(def->getValueAsString("description"));
    // Pick the constraint kind (and the type of `self`) from the interface
    // flavor. Records matching none of the three produce no decl.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy, cppClassName, desc));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy, cppClassName, desc));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy, cppClassName, desc));
    }
  }
}

/// Create a native UserConstraintDecl named `name` that takes a single `self`
/// argument of `type`, constrained by a fresh ConstraintT, and returns no
/// results. `codeBlock` becomes the native body and `docString` the doc
/// comment. The decl is added to the current scope and returned.
template <typename ConstraintT>
ast::Decl *Parser::createODSNativePDLLConstraintDecl(
    StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
    StringRef nativeType, StringRef docString) {
  // Build the single input parameter. The argument scope is pushed and
  // immediately popped; it only serves to hold `self`.
  ast::DeclScope *argScope = pushDeclScope();
  auto *paramVar = ast::VariableDecl::create(
      ctx, ast::Name::create(ctx, "self", loc), type,
      /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
  argScope->add(paramVar);
  popDeclScope();

  // Build the native constraint.
  auto *constraintDecl = ast::UserConstraintDecl::createNative(
      ctx, ast::Name::create(ctx, name, loc), paramVar,
      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
  constraintDecl->setDocComment(ctx, docString);
  curDeclScope->add(constraintDecl);
  return constraintDecl;
}

/// Create a native constraint decl from an ODS tblgen::Constraint.
///
/// The ODS condition template is formatted with the self placeholder bound to
/// `self` and wrapped in `return ::mlir::success(...);`. The decl is named
/// after the constraint's unique def name.
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                          SMRange loc, ast::Type type,
                                          StringRef nativeType) {
  // Format the condition template.
  tblgen::FmtContext fmtContext;
  fmtContext.withSelf("self");
  std::string codeBlock = tblgen::tgfmt(
      "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
      &fmtContext);

  // If documentation was enabled, build the doc string for the generated
  // constraint. It would be nice to do this lazily, but TableGen information is
  // destroyed after we finish parsing the file.
  // The doc is the summary, followed by a blank line and the description when
  // the description is non-empty.
  std::string docString;
  if (enableDocumentation) {
    StringRef desc = constraint.getDescription();
    docString = processAndFormatDoc(
        constraint.getSummary() +
        (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
  }

  return createODSNativePDLLConstraintDecl<ConstraintT>(
      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
      docString);
}

//===----------------------------------------------------------------------===//
// Decls

/// Parse a top-level `Constraint`, `Pattern`, or `Rewrite` declaration. If the
/// resulting decl is named, it is checked for redefinition and added to the
/// current scope.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  FailureOr<ast::Decl *> decl;
  switch (curToken.getKind()) {
  case Token::kw_Constraint:
    decl = parseUserConstraintDecl();
    break;
  case Token::kw_Pattern:
    decl = parsePatternDecl();
    break;
  case Token::kw_Rewrite:
    decl = parseUserRewriteDecl();
    break;
  default:
    return emitError("expected top-level declaration, such as a `Pattern`");
  }
  if (failed(decl))
    return failure();

  // If the decl has a name, add it to the current scope.
  if (const ast::Name *name = (*decl)->getName()) {
    if (failed(checkDefineNamedDecl(*name)))
      return failure();
    curDeclScope->add(*decl);
  }
  return decl;
}

/// Parse a named attribute of the form `name [= <expr>]`.
///
/// The name may be a string, an identifier, or a keyword. Without an explicit
/// value, the attribute value is an AttributeExpr for `unit`. `parentOpName` is
/// only consulted here for attribute-name code completion.
FailureOr<ast::NamedAttributeDecl *>
Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
  // Check for name code completion.
  if (curToken.is(Token::code_complete))
    return codeCompleteAttributeName(parentOpName);

  std::string attrNameStr;
  if (curToken.isString())
    attrNameStr = curToken.getStringValue();
  else if (curToken.is(Token::identifier) || curToken.isKeyword())
    attrNameStr = curToken.getSpelling().str();
  else
    return emitError("expected identifier or string attribute name");
  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
  consumeToken();

  // Check for a value of the attribute.
  ast::Expr *attrValue = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> attrExpr = parseExpr();
    if (failed(attrExpr))
      return failure();
    attrValue = *attrExpr;
  } else {
    // If there isn't a concrete value, create an expression representing a
    // UnitAttr.
    attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
  }

  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}

/// Parse a lambda-style body `=> <stmt>` into a single-statement CompoundStmt.
///
/// The statement is parsed in its own decl scope. `processStatementFn` receives
/// the statement by reference and may reject it or replace it. The body range
/// spans from the start of the statement to the start of the token that
/// follows it.
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
    function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
    bool expectTerminalSemicolon) {
  consumeToken(Token::equal_arrow);

  // Parse the single statement of the lambda body.
  SMLoc bodyStartLoc = curToken.getStartLoc();
  pushDeclScope();
  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
  bool failedToParse =
      failed(singleStatement) || failed(processStatementFn(*singleStatement));
  // Pop the scope before checking for failure so it is balanced on both paths.
  popDeclScope();
  if (failedToParse)
    return failure();

  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}

/// Parse a Constraint/Rewrite argument of the form `name : <constraint>`.
/// Identifiers and dependent keywords are both accepted as argument names.
FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
  // Ensure that the argument is named.
  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
    return emitError("expected identifier argument name");

  // Parse the argument similarly to a normal variable.
  StringRef name = curToken.getSpelling();
  SMRange nameLoc = curToken.getLoc();
  consumeToken();

  if (failed(
          parseToken(Token::colon, "expected `:` before argument constraint")))
    return failure();

  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl(name, nameLoc, *cst);
}

/// Parse a single Constraint/Rewrite result: either `name : <constraint>` or a
/// bare `<constraint>` (an unnamed result).
///
/// A leading identifier that resolves to an existing ConstraintDecl is treated
/// as the start of an unnamed constraint rather than as a result name. That
/// form is rejected inside a Rewrite. `resultNum` is not used by this body.
FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
  // Check to see if this result is named.
  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
    // Check to see if this name actually refers to a Constraint.
    ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
    if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
      // If yes, and this is a Rewrite, give a nice error message as non-Core
      // constraints are not supported on Rewrite results.
      if (parserContext == ParserContext::Rewrite) {
        return emitError(
            "`Rewrite` results are only permitted to use core constraints, "
            "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
      }

      // Otherwise, parse this as an unnamed result variable.
    } else {
      // If it wasn't a constraint, parse the result similarly to a variable. If
      // there is already an existing decl, we will emit an error when defining
      // this variable later.
      StringRef name = curToken.getSpelling();
      SMRange nameLoc = curToken.getLoc();
      consumeToken();

      if (failed(parseToken(Token::colon,
                            "expected `:` before result constraint")))
        return failure();

      FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
      if (failed(cst))
        return failure();

      return createArgOrResultVariableDecl(name, nameLoc, *cst);
    }
  }

  // If it isn't named, we parse the constraint directly and create an unnamed
  // result variable.
  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
  if (failed(cst))
    return failure();

  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
}

/// Parse a `Constraint` declaration. This is a thin wrapper over the shared
/// parseUserConstraintOrRewriteDecl, run in the Constraint parser context.
FailureOr<ast::UserConstraintDecl *>
Parser::parseUserConstraintDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      [&](auto &&...args) {
        return this->parseUserPDLLConstraintDecl(args...);
      },
      ParserContext::Constraint, "constraint", isInline);
}

/// Parse an inline `Constraint` declaration, check its (possibly generated
/// anonymous) name for redefinition, and add it to the current scope.
FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
  FailureOr<ast::UserConstraintDecl *> decl =
      parseUserConstraintDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

/// Parse the body of a PDLL (non-native) Constraint with its argument scope
/// active.
///
/// A lambda body (`=> <expr>`) must be a single expression, which is wrapped
/// in a ReturnStmt. A compound body (`{ ... }`) may contain at most a terminal
/// `return`, which is required when results are declared.
FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  // NOTE(review): the early `return failure()` paths below skip the matching
  // popDeclScope(); presumably harmless because the enclosing parse fails too
  // -- confirm.
  pushDeclScope(argumentScope);

  // Parse the body of the constraint. The body is either defined as a compound
  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&stmt) -> LogicalResult {
          ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
          if (!stmtExpr) {
            return emitError(stmt->getLoc(),
                             "expected `Constraint` lambda body to contain a "
                             "single expression");
          }
          stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the structure of the body: locate the first `return` (if any) and
    // let the shared validator check its placement against the results.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt)
      if (isa<ast::ReturnStmt>(*bodyIt))
        break;
    if (failed(validateUserConstraintOrRewriteReturn(
            "Constraint", body, bodyIt, bodyE, results, resultType)))
      return failure();
  }
  popDeclScope();

  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
      name, arguments, results, resultType, body);
}

/// Parse a `Rewrite` declaration. This is a thin wrapper over the shared
/// parseUserConstraintOrRewriteDecl, run in the Rewrite parser context.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
  // Constraints and rewrites have very similar formats, dispatch to a shared
  // interface for parsing.
  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
      ParserContext::Rewrite, "rewrite", isInline);
}

/// Parse an inline `Rewrite` declaration, check its (possibly generated
/// anonymous) name for redefinition, and add it to the current scope.
FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
  FailureOr<ast::UserRewriteDecl *> decl =
      parseUserRewriteDecl(/*isInline=*/true);
  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
    return failure();

  curDeclScope->add(*decl);
  return decl;
}

/// Parse the body of a PDLL (non-native) Rewrite with its argument scope
/// active.
///
/// A lambda body may be an operation rewrite statement, which is kept as-is,
/// or an expression, which is wrapped in a ReturnStmt. For both body forms, at
/// most a terminal `return` is allowed, and one is required when results are
/// declared.
FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // Push the argument scope back onto the list, so that the body can
  // reference arguments.
  // NOTE(review): parseUserPDLLConstraintDecl uses
  // `pushDeclScope(argumentScope)` for the same purpose, whereas this assigns
  // `curDeclScope` directly; presumably equivalent -- confirm and consider
  // using the same helper for consistency.
  curDeclScope = argumentScope;
  ast::CompoundStmt *body;
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
        [&](ast::Stmt *&statement) -> LogicalResult {
          if (isa<ast::OpRewriteStmt>(statement))
            return success();

          ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
          if (!statementExpr) {
            return emitError(
                statement->getLoc(),
                "expected `Rewrite` lambda body to contain a single expression "
                "or an operation rewrite statement; such as `erase`, "
                "`replace`, or `rewrite`");
          }
          statement =
              ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
          return success();
        },
        /*expectTerminalSemicolon=*/!isInline);
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  }
  popDeclScope();

  // Verify the structure of the body.
  auto bodyIt = body->begin(), bodyE = body->end();
  for (; bodyIt != bodyE; ++bodyIt)
    if (isa<ast::ReturnStmt>(*bodyIt))
      break;
  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
                                                   bodyE, results, resultType)))
    return failure();
  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
      name, arguments, results, resultType, body);
}

/// Shared driver for `Constraint` and `Rewrite` declarations.
///
/// Steps:
///  1. Consume the leading keyword and switch the parser context to
///     `declContext` until this function returns.
///  2. Parse the name. A name is required unless `isInline`; unnamed inline
///     decls get a generated `<anonymous_{prefix}_{N}>` name.
///  3. Parse the signature.
///  4. Dispatch: a following `{` or `=>` selects `parseUserPDLLFn` (a PDLL
///     body); anything else selects the native-decl parser.
template <typename T, typename ParseUserPDLLDeclFnT>
FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
    ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
    StringRef anonymousNamePrefix, bool isInline) {
  SMRange loc = curToken.getLoc();
  consumeToken();
  // The previous parser context is restored when `saveCtx` goes out of scope.
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);

  // Parse the name of the decl.
  const ast::Name *name = nullptr;
  if (curToken.isNot(Token::identifier)) {
    // Only inline decls can be un-named. Inline decls are similar to "lambdas"
    // in C++, so being unnamed is fine.
    if (!isInline)
      return emitError("expected identifier name");

    // Create a unique anonymous name to use, as the name for this decl is not
    // important.
    std::string anonName =
        llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
                      anonymousDeclNameCounter++)
            .str();
    name = &ast::Name::create(ctx, anonName, loc);
  } else {
    // If a name was provided, we can use it directly.
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse the functional signature of the decl.
  SmallVector<ast::VariableDecl *> arguments, results;
  ast::DeclScope *argumentScope;
  ast::Type resultType;
  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
                                                   argumentScope, resultType)))
    return failure();

  // Check to see which type of constraint this is. If the constraint contains a
  // compound body, this is a PDLL decl.
  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
    return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
                           resultType);

  // Otherwise, this is a native decl.
  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
                                                   results, resultType);
}

/// Parse the remainder of a native Constraint/Rewrite: an optional string code
/// block followed by `;`.
///
/// Inline native decls must provide the code block. Native Constraints may not
/// declare results.
template <typename T>
FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> arguments,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // If followed by a string, the native code body has also been specified.
  // `codeStrStorage` owns the text that `optCodeStr` views.
  std::string codeStrStorage;
  Optional<StringRef> optCodeStr;
  if (curToken.isString()) {
    codeStrStorage = curToken.getStringValue();
    optCodeStr = codeStrStorage;
    consumeToken();
  } else if (isInline) {
    return emitError(name.getLoc(),
                     "external declarations must be declared in global scope");
  } else if (curToken.is(Token::error)) {
    // NOTE(review): no new diagnostic here; presumably the lexer already
    // reported the error token -- confirm.
    return failure();
  }
  if (failed(parseToken(Token::semicolon,
                        "expected `;` after native declaration")))
    return failure();
  // TODO: PDL should be able to support constraint results in certain
  // situations, we should revise this.
  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
    return emitError(
        "native Constraints currently do not support returning results");
  }
  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

/// Parse a Constraint/Rewrite signature: `(args)` optionally followed by
/// `-> result` or `-> (result, ...)`.
///
/// Arguments are defined in a new scope that is returned through
/// `argumentScope`, so the body parser can re-enter it. Results are defined in
/// a separate scope that is popped afterwards. `resultType` is computed from
/// the results, and a single labeled result is rejected.
/// NOTE(review): the early `return failure()` paths leave the most recently
/// pushed scope active -- presumably harmless since the parse fails; confirm.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
    SmallVectorImpl<ast::VariableDecl *> &arguments,
    SmallVectorImpl<ast::VariableDecl *> &results,
    ast::DeclScope *&argumentScope, ast::Type &resultType) {
  // Parse the argument list of the decl.
  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
    return failure();

  argumentScope = pushDeclScope();
  if (curToken.isNot(Token::r_paren)) {
    do {
      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
      if (failed(argument))
        return failure();
      arguments.emplace_back(*argument);
    } while (consumeIf(Token::comma));
  }
  popDeclScope();
  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
    return failure();

  // Parse the results of the decl.
  pushDeclScope();
  if (consumeIf(Token::arrow)) {
    auto parseResultFn = [&]() -> LogicalResult {
      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
      if (failed(result))
        return failure();
      results.emplace_back(*result);
      return success();
    };

    // Check for a list of results.
    if (consumeIf(Token::l_paren)) {
      do {
        if (failed(parseResultFn()))
          return failure();
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
        return failure();

      // Otherwise, there is only one result.
    } else if (failed(parseResultFn())) {
      return failure();
    }
  }
  popDeclScope();

  // Compute the result type of the decl.
  resultType = createUserConstraintRewriteResultType(results);

  // Verify that results are only named if there are more than one.
  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
    return emitError(
        results.front()->getLoc(),
        "cannot create a single-element tuple with an element label");
  }
  return success();
}

/// Validate the `return` structure of a Constraint/Rewrite body.
///
/// `bodyIt` points at the first ReturnStmt in `body`, or equals `bodyE` if
/// there is none.
///  * If present, the return must be the final statement.
///  * If absent, the decl must declare no results.
/// `declType` and `resultType` are only used to build diagnostics.
LogicalResult Parser::validateUserConstraintOrRewriteReturn(
    StringRef declType, ast::CompoundStmt *body,
    ArrayRef<ast::Stmt *>::iterator bodyIt,
    ArrayRef<ast::Stmt *>::iterator bodyE,
    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
  // Handle if a `return` was provided.
  if (bodyIt != bodyE) {
    // Emit an error if we have trailing statements after the return.
    if (std::next(bodyIt) != bodyE) {
      return emitError(
          (*std::next(bodyIt))->getLoc(),
          llvm::formatv("`return` terminated the `{0}` body, but found "
                        "trailing statements afterwards",
                        declType));
    }

    // Otherwise if a return wasn't provided, check that no results are
    // expected.
  } else if (!results.empty()) {
    // Point the diagnostic at the end of the body, where the return is missing.
    return emitError(
        {body->getLoc().End, body->getLoc().End},
        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
                      declType, resultType));
  }
  return success();
}

/// Parse a Pattern lambda body `=> <stmt>`. The single statement must be an
/// operation rewrite statement.
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
    if (isa<ast::OpRewriteStmt>(statement))
      return success();
    return emitError(
        statement->getLoc(),
        "expected Pattern lambda body to contain a single operation "
        "rewrite statement, such as `erase`, `replace`, or `rewrite`");
  });
}

/// Parse `Pattern [name] [with <metadata>] (=> <stmt> | { ... })` in the
/// PatternMatch parser context.
///
/// A compound body must not contain `return` statements. It must end with its
/// first operation rewrite statement, with nothing following it.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::PatternMatch);

  // Check for an optional identifier for the pattern name.
  const ast::Name *name = nullptr;
  if (curToken.is(Token::identifier)) {
    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Parse any pattern metadata.
  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
    return failure();

  // Parse the pattern body.
  ast::CompoundStmt *body;

  // Handle a lambda body.
  if (curToken.is(Token::equal_arrow)) {
    FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;
  } else {
    if (curToken.isNot(Token::l_brace))
      return emitError("expected `{` or `=>` to start pattern body");
    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
    if (failed(bodyResult))
      return failure();
    body = *bodyResult;

    // Verify the body of the pattern.
    auto bodyIt = body->begin(), bodyE = body->end();
    for (; bodyIt != bodyE; ++bodyIt) {
      if (isa<ast::ReturnStmt>(*bodyIt)) {
        return emitError((*bodyIt)->getLoc(),
                         "`return` statements are only permitted within a "
                         "`Constraint` or `Rewrite` body");
      }
      // Break when we've found the rewrite statement.
      if (isa<ast::OpRewriteStmt>(*bodyIt))
        break;
    }
    if (bodyIt == bodyE) {
      return emitError(loc,
                       "expected Pattern body to terminate with an operation "
                       "rewrite statement, such as `erase`");
    }
    if (std::next(bodyIt) != bodyE) {
      return emitError((*std::next(bodyIt))->getLoc(),
                       "Pattern body was terminated by an operation "
                       "rewrite statement, but found trailing statements");
    }
  }

  return createPatternDecl(loc, name, metadata, body);
}

/// Parse the comma-separated pattern metadata list that follows `with`.
///
/// Supported entries, each allowed at most once:
///  * `benefit(<integer>)` -- the value must fit in a uint16_t.
///  * `recursion` -- marks the pattern as having bounded recursion.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  // Locations of previously seen entries, used to diagnose duplicates.
  Optional<SMRange> benefitLoc;
  Optional<SMRange> hasBoundedRecursionLoc;

  do {
    // Handle metadata code completion.
    if (curToken.is(Token::code_complete))
      return codeCompletePatternMetadata();

    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef metadataStr = curToken.getSpelling();
    SMRange metadataLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    // Parse the benefit metadata: benefit(<integer-value>)
    if (metadataStr == "benefit") {
      if (benefitLoc) {
        return emitErrorAndNote(metadataLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      }
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      uint16_t benefitValue = 0;
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      // getAsInteger returns true on failure, including values that do not
      // fit in the 16-bit destination.
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = benefitValue;
      benefitLoc = metadataLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
      continue;
    }

    // Parse the bounded recursion metadata: recursion
    if (metadataStr == "recursion") {
      if (hasBoundedRecursionLoc) {
        return emitErrorAndNote(
            metadataLoc,
            "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      }
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = metadataLoc;
      continue;
    }

    return emitError(metadataLoc, "unknown pattern metadata");
  } while (consumeIf(Token::comma));

  return success();
}

/// Parse an inline type constraint `<expr>`. The current token is expected to
/// be `<`.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr) ||
      failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}

/// Emit an error (with a note at the previous definition) if `name` is already
/// found by `curDeclScope->lookup`; otherwise succeed.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");
  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
    return emitErrorAndNote(
        name.getLoc(), "`" + name.getName() + "` has already been defined",
        lastDecl->getName()->getLoc(), "see previous definition here");
  }
  return success();
}

/// Create a VariableDecl and, unless its name is empty or `_`, check it for
/// redefinition and add it to the current scope.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // If the name of the variable indicates a special variable, we don't add it
  // to the scope. This variable is local to the definition point.
  if (name.empty() || name == "_") {
    return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
                                     constraints);
  }
  if (failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  curDeclScope->add(varDecl);
  return varDecl;
}

/// Convenience overload of defineVariableDecl for a variable without an
/// initializer expression.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
                            constraints);
}

/// Parse a variable's constraints, appending them to `constraints`.
///
/// Accepts either a single constraint or a bracketed list `[c1, c2, ...]`.
/// Inline type constraints and non-core constraints are allowed. The shared
/// `typeConstraint` ensures at most one inline type constraint across the list.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  Optional<SMRange> typeConstraint;
  auto parseSingleConstraint = [&] {
    FailureOr<ast::ConstraintRef> constraint = parseConstraint(
        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
        /*allowNonCoreConstraints=*/true);
    if (failed(constraint))
      return failure();
    constraints.push_back(*constraint);
    return success();
  };

  // Check to see if this is a single constraint, or a list.
  if (!consumeIf(Token::l_square))
    return parseSingleConstraint();

  do {
    if (failed(parseSingleConstraint()))
      return failure();
  } while (consumeIf(Token::comma));
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference. The accepted forms are:
///  * a core constraint: `Attr` (optional `<type>`), `Op` (optional wrapped
///    operation name), `Type`, `TypeRange`, `Value` (optional `<type>`), or
///    `ValueRange` (optional `<type>`);
///  * an inline `Constraint` declaration;
///  * an identifier naming an existing ConstraintDecl.
///
/// `typeConstraint` records the location of an inline type constraint already
/// applied to the variable, so a second one is diagnosed.
/// `existingConstraints` and `allowNonCoreConstraints` are only consulted for
/// code completion in this function.
/// NOTE(review): `allowNonCoreConstraints` is not enforced for identifier
/// references here; presumably that happens in the caller -- confirm.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse an inline `<type>` constraint into `typeExpr`, enforcing that such
  // constraints are permitted and that none has been applied already.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}

/// Parse the constraint of a Constraint/Rewrite argument or result.
///
/// Inline type constraints are never permitted here. Non-core constraints are
/// flagged as allowed only while parsing a Constraint.
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
  // Constraint arguments may apply more complex constraints via the arguments.
  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;

  Optional<SMRange> typeConstraint;
  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
                         /*allowInlineTypeConstraints=*/false,
                         allowNonCoreConstraints);
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression. A leading `_` is parsed as an underscore expression.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Parse the LHS expression.
1779 FailureOr<ast::Expr *> lhsExpr; 1780 switch (curToken.getKind()) { 1781 case Token::kw_attr: 1782 lhsExpr = parseAttributeExpr(); 1783 break; 1784 case Token::kw_Constraint: 1785 lhsExpr = parseInlineConstraintLambdaExpr(); 1786 break; 1787 case Token::identifier: 1788 lhsExpr = parseIdentifierExpr(); 1789 break; 1790 case Token::kw_op: 1791 lhsExpr = parseOperationExpr(); 1792 break; 1793 case Token::kw_Rewrite: 1794 lhsExpr = parseInlineRewriteLambdaExpr(); 1795 break; 1796 case Token::kw_type: 1797 lhsExpr = parseTypeExpr(); 1798 break; 1799 case Token::l_paren: 1800 lhsExpr = parseTupleExpr(); 1801 break; 1802 default: 1803 return emitError("expected expression"); 1804 } 1805 if (failed(lhsExpr)) 1806 return failure(); 1807 1808 // Check for an operator expression. 1809 while (true) { 1810 switch (curToken.getKind()) { 1811 case Token::dot: 1812 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1813 break; 1814 case Token::l_paren: 1815 lhsExpr = parseCallExpr(*lhsExpr); 1816 break; 1817 default: 1818 return lhsExpr; 1819 } 1820 if (failed(lhsExpr)) 1821 return failure(); 1822 } 1823 } 1824 1825 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1826 SMRange loc = curToken.getLoc(); 1827 consumeToken(Token::kw_attr); 1828 1829 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1830 // identifier. 
1831 if (!consumeIf(Token::less)) { 1832 resetToken(loc); 1833 return parseIdentifierExpr(); 1834 } 1835 1836 if (!curToken.isString()) 1837 return emitError("expected string literal containing MLIR attribute"); 1838 std::string attrExpr = curToken.getStringValue(); 1839 consumeToken(); 1840 1841 loc.End = curToken.getEndLoc(); 1842 if (failed( 1843 parseToken(Token::greater, "expected `>` after attribute literal"))) 1844 return failure(); 1845 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1846 } 1847 1848 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1849 consumeToken(Token::l_paren); 1850 1851 // Parse the arguments of the call. 1852 SmallVector<ast::Expr *> arguments; 1853 if (curToken.isNot(Token::r_paren)) { 1854 do { 1855 // Handle code completion for the call arguments. 1856 if (curToken.is(Token::code_complete)) { 1857 codeCompleteCallSignature(parentExpr, arguments.size()); 1858 return failure(); 1859 } 1860 1861 FailureOr<ast::Expr *> argument = parseExpr(); 1862 if (failed(argument)) 1863 return failure(); 1864 arguments.push_back(*argument); 1865 } while (consumeIf(Token::comma)); 1866 } 1867 1868 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1869 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1870 return failure(); 1871 1872 return createCallExpr(loc, parentExpr, arguments); 1873 } 1874 1875 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1876 ast::Decl *decl = curDeclScope->lookup(name); 1877 if (!decl) 1878 return emitError(loc, "undefined reference to `" + name + "`"); 1879 1880 return createDeclRefExpr(loc, decl); 1881 } 1882 1883 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1884 StringRef name = curToken.getSpelling(); 1885 SMRange nameLoc = curToken.getLoc(); 1886 consumeToken(); 1887 1888 // Check to see if this is a decl ref expression that defines a variable 1889 // inline. 
1890 if (consumeIf(Token::colon)) { 1891 SmallVector<ast::ConstraintRef> constraints; 1892 if (failed(parseVariableDeclConstraintList(constraints))) 1893 return failure(); 1894 ast::Type type; 1895 if (failed(validateVariableConstraints(constraints, type))) 1896 return failure(); 1897 return createInlineVariableExpr(type, name, nameLoc, constraints); 1898 } 1899 1900 return parseDeclRefExpr(name, nameLoc); 1901 } 1902 1903 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1904 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1905 if (failed(decl)) 1906 return failure(); 1907 1908 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1909 ast::ConstraintType::get(ctx)); 1910 } 1911 1912 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1913 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1914 if (failed(decl)) 1915 return failure(); 1916 1917 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1918 ast::RewriteType::get(ctx)); 1919 } 1920 1921 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1922 SMRange dotLoc = curToken.getLoc(); 1923 consumeToken(Token::dot); 1924 1925 // Check for code completion of the member name. 1926 if (curToken.is(Token::code_complete)) 1927 return codeCompleteMemberAccess(parentExpr); 1928 1929 // Parse the member name. 
1930 Token memberNameTok = curToken; 1931 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1932 !memberNameTok.isKeyword()) 1933 return emitError(dotLoc, "expected identifier or numeric member name"); 1934 StringRef memberName = memberNameTok.getSpelling(); 1935 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1936 consumeToken(); 1937 1938 return createMemberAccessExpr(parentExpr, memberName, loc); 1939 } 1940 1941 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1942 SMRange loc = curToken.getLoc(); 1943 1944 // Check for code completion for the dialect name. 1945 if (curToken.is(Token::code_complete)) 1946 return codeCompleteDialectName(); 1947 1948 // Handle the case of an no operation name. 1949 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1950 if (allowEmptyName) 1951 return ast::OpNameDecl::create(ctx, SMRange()); 1952 return emitError("expected dialect namespace"); 1953 } 1954 StringRef name = curToken.getSpelling(); 1955 consumeToken(); 1956 1957 // Otherwise, this is a literal operation name. 1958 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1959 return failure(); 1960 1961 // Check for code completion for the operation name. 
1962 if (curToken.is(Token::code_complete)) 1963 return codeCompleteOperationName(name); 1964 1965 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1966 return emitError("expected operation name after dialect namespace"); 1967 1968 name = StringRef(name.data(), name.size() + 1); 1969 do { 1970 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1971 loc.End = curToken.getEndLoc(); 1972 consumeToken(); 1973 } while (curToken.isAny(Token::identifier, Token::dot) || 1974 curToken.isKeyword()); 1975 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1976 } 1977 1978 FailureOr<ast::OpNameDecl *> 1979 Parser::parseWrappedOperationName(bool allowEmptyName) { 1980 if (!consumeIf(Token::less)) 1981 return ast::OpNameDecl::create(ctx, SMRange()); 1982 1983 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1984 if (failed(opNameDecl)) 1985 return failure(); 1986 1987 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1988 return failure(); 1989 return opNameDecl; 1990 } 1991 1992 FailureOr<ast::Expr *> 1993 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { 1994 SMRange loc = curToken.getLoc(); 1995 consumeToken(Token::kw_op); 1996 1997 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1998 // identifier. 1999 if (curToken.isNot(Token::less)) { 2000 resetToken(loc); 2001 return parseIdentifierExpr(); 2002 } 2003 2004 // Parse the operation name. The name may be elided, in which case the 2005 // operation refers to "any" operation(i.e. a difference between `MyOp` and 2006 // `Operation*`). Operation names within a rewrite context must be named. 
2007 bool allowEmptyName = parserContext != ParserContext::Rewrite; 2008 FailureOr<ast::OpNameDecl *> opNameDecl = 2009 parseWrappedOperationName(allowEmptyName); 2010 if (failed(opNameDecl)) 2011 return failure(); 2012 Optional<StringRef> opName = (*opNameDecl)->getName(); 2013 2014 // Functor used to create an implicit range variable, used for implicit "all" 2015 // operand or results variables. 2016 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 2017 FailureOr<ast::VariableDecl *> rangeVar = 2018 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 2019 assert(succeeded(rangeVar) && "expected range variable to be valid"); 2020 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 2021 }; 2022 2023 // Check for the optional list of operands. 2024 SmallVector<ast::Expr *> operands; 2025 if (!consumeIf(Token::l_paren)) { 2026 // If the operand list isn't specified and we are in a match context, define 2027 // an inplace unconstrained operand range corresponding to all of the 2028 // operands of the operation. This avoids treating zero operands the same 2029 // way as "unconstrained operands". 2030 if (parserContext != ParserContext::Rewrite) { 2031 operands.push_back(createImplicitRangeVar( 2032 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 2033 } 2034 } else if (!consumeIf(Token::r_paren)) { 2035 // If the operand list was specified and non-empty, parse the operands. 2036 do { 2037 // Check for operand signature code completion. 
2038 if (curToken.is(Token::code_complete)) { 2039 codeCompleteOperationOperandsSignature(opName, operands.size()); 2040 return failure(); 2041 } 2042 2043 FailureOr<ast::Expr *> operand = parseExpr(); 2044 if (failed(operand)) 2045 return failure(); 2046 operands.push_back(*operand); 2047 } while (consumeIf(Token::comma)); 2048 2049 if (failed(parseToken(Token::r_paren, 2050 "expected `)` after operation operand list"))) 2051 return failure(); 2052 } 2053 2054 // Check for the optional list of attributes. 2055 SmallVector<ast::NamedAttributeDecl *> attributes; 2056 if (consumeIf(Token::l_brace)) { 2057 do { 2058 FailureOr<ast::NamedAttributeDecl *> decl = 2059 parseNamedAttributeDecl(opName); 2060 if (failed(decl)) 2061 return failure(); 2062 attributes.emplace_back(*decl); 2063 } while (consumeIf(Token::comma)); 2064 2065 if (failed(parseToken(Token::r_brace, 2066 "expected `}` after operation attribute list"))) 2067 return failure(); 2068 } 2069 2070 // Handle the result types of the operation. 2071 SmallVector<ast::Expr *> resultTypes; 2072 OpResultTypeContext resultTypeContext = inputResultTypeContext; 2073 2074 // Check for an explicit list of result types. 2075 if (consumeIf(Token::arrow)) { 2076 if (failed(parseToken(Token::l_paren, 2077 "expected `(` before operation result type list"))) 2078 return failure(); 2079 2080 // If result types are provided, initially assume that the operation does 2081 // not rely on type inferrence. We don't assert that it isn't, because we 2082 // may be inferring the value of some type/type range variables, but given 2083 // that these variables may be defined in calls we can't always discern when 2084 // this is the case. 2085 resultTypeContext = OpResultTypeContext::Explicit; 2086 2087 // Handle the case of an empty result list. 2088 if (!consumeIf(Token::r_paren)) { 2089 do { 2090 // Check for result signature code completion. 
2091 if (curToken.is(Token::code_complete)) { 2092 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2093 return failure(); 2094 } 2095 2096 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2097 if (failed(resultTypeExpr)) 2098 return failure(); 2099 resultTypes.push_back(*resultTypeExpr); 2100 } while (consumeIf(Token::comma)); 2101 2102 if (failed(parseToken(Token::r_paren, 2103 "expected `)` after operation result type list"))) 2104 return failure(); 2105 } 2106 } else if (parserContext != ParserContext::Rewrite) { 2107 // If the result list isn't specified and we are in a match context, define 2108 // an inplace unconstrained result range corresponding to all of the results 2109 // of the operation. This avoids treating zero results the same way as 2110 // "unconstrained results". 2111 resultTypes.push_back(createImplicitRangeVar( 2112 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2113 } else if (resultTypeContext == OpResultTypeContext::Explicit) { 2114 // If the result list isn't specified and we are in a rewrite, try to infer 2115 // them at runtime instead. 2116 resultTypeContext = OpResultTypeContext::Interface; 2117 } 2118 2119 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, 2120 attributes, resultTypes); 2121 } 2122 2123 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2124 SMRange loc = curToken.getLoc(); 2125 consumeToken(Token::l_paren); 2126 2127 DenseMap<StringRef, SMRange> usedNames; 2128 SmallVector<StringRef> elementNames; 2129 SmallVector<ast::Expr *> elements; 2130 if (curToken.isNot(Token::r_paren)) { 2131 do { 2132 // Check for the optional element name assignment before the value. 2133 StringRef elementName; 2134 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2135 Token elementNameTok = curToken; 2136 consumeToken(); 2137 2138 // The element name is only present if followed by an `=`. 
2139 if (consumeIf(Token::equal)) { 2140 elementName = elementNameTok.getSpelling(); 2141 2142 // Check to see if this name is already used. 2143 auto elementNameIt = 2144 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2145 if (!elementNameIt.second) { 2146 return emitErrorAndNote( 2147 elementNameTok.getLoc(), 2148 llvm::formatv("duplicate tuple element label `{0}`", 2149 elementName), 2150 elementNameIt.first->getSecond(), 2151 "see previous label use here"); 2152 } 2153 } else { 2154 // Otherwise, we treat this as part of an expression so reset the 2155 // lexer. 2156 resetToken(elementNameTok.getLoc()); 2157 } 2158 } 2159 elementNames.push_back(elementName); 2160 2161 // Parse the tuple element value. 2162 FailureOr<ast::Expr *> element = parseExpr(); 2163 if (failed(element)) 2164 return failure(); 2165 elements.push_back(*element); 2166 } while (consumeIf(Token::comma)); 2167 } 2168 loc.End = curToken.getEndLoc(); 2169 if (failed( 2170 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2171 return failure(); 2172 return createTupleExpr(loc, elements, elementNames); 2173 } 2174 2175 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2176 SMRange loc = curToken.getLoc(); 2177 consumeToken(Token::kw_type); 2178 2179 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2180 // identifier. 
2181 if (!consumeIf(Token::less)) { 2182 resetToken(loc); 2183 return parseIdentifierExpr(); 2184 } 2185 2186 if (!curToken.isString()) 2187 return emitError("expected string literal containing MLIR type"); 2188 std::string attrExpr = curToken.getStringValue(); 2189 consumeToken(); 2190 2191 loc.End = curToken.getEndLoc(); 2192 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2193 return failure(); 2194 return ast::TypeExpr::create(ctx, loc, attrExpr); 2195 } 2196 2197 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2198 StringRef name = curToken.getSpelling(); 2199 SMRange nameLoc = curToken.getLoc(); 2200 consumeToken(Token::underscore); 2201 2202 // Underscore expressions require a constraint list. 2203 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2204 return failure(); 2205 2206 // Parse the constraints for the expression. 2207 SmallVector<ast::ConstraintRef> constraints; 2208 if (failed(parseVariableDeclConstraintList(constraints))) 2209 return failure(); 2210 2211 ast::Type type; 2212 if (failed(validateVariableConstraints(constraints, type))) 2213 return failure(); 2214 return createInlineVariableExpr(type, name, nameLoc, constraints); 2215 } 2216 2217 //===----------------------------------------------------------------------===// 2218 // Stmts 2219 2220 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2221 FailureOr<ast::Stmt *> stmt; 2222 switch (curToken.getKind()) { 2223 case Token::kw_erase: 2224 stmt = parseEraseStmt(); 2225 break; 2226 case Token::kw_let: 2227 stmt = parseLetStmt(); 2228 break; 2229 case Token::kw_replace: 2230 stmt = parseReplaceStmt(); 2231 break; 2232 case Token::kw_return: 2233 stmt = parseReturnStmt(); 2234 break; 2235 case Token::kw_rewrite: 2236 stmt = parseRewriteStmt(); 2237 break; 2238 default: 2239 stmt = parseExpr(); 2240 break; 2241 } 2242 if (failed(stmt) || 2243 (expectTerminalSemicolon && 2244 failed(parseToken(Token::semicolon, 
"expected `;` after statement")))) 2245 return failure(); 2246 return stmt; 2247 } 2248 2249 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2250 SMLoc startLoc = curToken.getStartLoc(); 2251 consumeToken(Token::l_brace); 2252 2253 // Push a new block scope and parse any nested statements. 2254 pushDeclScope(); 2255 SmallVector<ast::Stmt *> statements; 2256 while (curToken.isNot(Token::r_brace)) { 2257 FailureOr<ast::Stmt *> statement = parseStmt(); 2258 if (failed(statement)) 2259 return popDeclScope(), failure(); 2260 statements.push_back(*statement); 2261 } 2262 popDeclScope(); 2263 2264 // Consume the end brace. 2265 SMRange location(startLoc, curToken.getEndLoc()); 2266 consumeToken(Token::r_brace); 2267 2268 return ast::CompoundStmt::create(ctx, location, statements); 2269 } 2270 2271 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2272 if (parserContext == ParserContext::Constraint) 2273 return emitError("`erase` cannot be used within a Constraint"); 2274 SMRange loc = curToken.getLoc(); 2275 consumeToken(Token::kw_erase); 2276 2277 // Parse the root operation expression. 2278 FailureOr<ast::Expr *> rootOp = parseExpr(); 2279 if (failed(rootOp)) 2280 return failure(); 2281 2282 return createEraseStmt(loc, *rootOp); 2283 } 2284 2285 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2286 SMRange loc = curToken.getLoc(); 2287 consumeToken(Token::kw_let); 2288 2289 // Parse the name of the new variable. 2290 SMRange varLoc = curToken.getLoc(); 2291 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2292 // `_` is a reserved variable name. 2293 if (curToken.is(Token::underscore)) { 2294 return emitError(varLoc, 2295 "`_` may only be used to define \"inline\" variables"); 2296 } 2297 return emitError(varLoc, 2298 "expected identifier after `let` to name a new variable"); 2299 } 2300 StringRef varName = curToken.getSpelling(); 2301 consumeToken(); 2302 2303 // Parse the optional set of constraints. 
2304 SmallVector<ast::ConstraintRef> constraints; 2305 if (consumeIf(Token::colon) && 2306 failed(parseVariableDeclConstraintList(constraints))) 2307 return failure(); 2308 2309 // Parse the optional initializer expression. 2310 ast::Expr *initializer = nullptr; 2311 if (consumeIf(Token::equal)) { 2312 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2313 if (failed(initOrFailure)) 2314 return failure(); 2315 initializer = *initOrFailure; 2316 2317 // Check that the constraints are compatible with having an initializer, 2318 // e.g. type constraints cannot be used with initializers. 2319 for (ast::ConstraintRef constraint : constraints) { 2320 LogicalResult result = 2321 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2322 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2323 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2324 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2325 return this->emitError( 2326 constraint.referenceLoc, 2327 "type constraints are not permitted on variables with " 2328 "initializers"); 2329 } 2330 return success(); 2331 }) 2332 .Default(success()); 2333 if (failed(result)) 2334 return failure(); 2335 } 2336 } 2337 2338 FailureOr<ast::VariableDecl *> varDecl = 2339 createVariableDecl(varName, varLoc, initializer, constraints); 2340 if (failed(varDecl)) 2341 return failure(); 2342 return ast::LetStmt::create(ctx, loc, *varDecl); 2343 } 2344 2345 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2346 if (parserContext == ParserContext::Constraint) 2347 return emitError("`replace` cannot be used within a Constraint"); 2348 SMRange loc = curToken.getLoc(); 2349 consumeToken(Token::kw_replace); 2350 2351 // Parse the root operation expression. 
2352 FailureOr<ast::Expr *> rootOp = parseExpr(); 2353 if (failed(rootOp)) 2354 return failure(); 2355 2356 if (failed( 2357 parseToken(Token::kw_with, "expected `with` after root operation"))) 2358 return failure(); 2359 2360 // The replacement portion of this statement is within a rewrite context. 2361 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2362 ParserContext::Rewrite); 2363 2364 // Parse the replacement values. 2365 SmallVector<ast::Expr *> replValues; 2366 if (consumeIf(Token::l_paren)) { 2367 if (consumeIf(Token::r_paren)) { 2368 return emitError( 2369 loc, "expected at least one replacement value, consider using " 2370 "`erase` if no replacement values are desired"); 2371 } 2372 2373 do { 2374 FailureOr<ast::Expr *> replExpr = parseExpr(); 2375 if (failed(replExpr)) 2376 return failure(); 2377 replValues.emplace_back(*replExpr); 2378 } while (consumeIf(Token::comma)); 2379 2380 if (failed(parseToken(Token::r_paren, 2381 "expected `)` after replacement values"))) 2382 return failure(); 2383 } else { 2384 // Handle replacement with an operation uniquely, as the replacement 2385 // operation supports type inferrence from the root operation. 2386 FailureOr<ast::Expr *> replExpr; 2387 if (curToken.is(Token::kw_op)) 2388 replExpr = parseOperationExpr(OpResultTypeContext::Replacement); 2389 else 2390 replExpr = parseExpr(); 2391 if (failed(replExpr)) 2392 return failure(); 2393 replValues.emplace_back(*replExpr); 2394 } 2395 2396 return createReplaceStmt(loc, *rootOp, replValues); 2397 } 2398 2399 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2400 SMRange loc = curToken.getLoc(); 2401 consumeToken(Token::kw_return); 2402 2403 // Parse the result value. 
2404 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2405 if (failed(resultExpr)) 2406 return failure(); 2407 2408 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2409 } 2410 2411 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2412 if (parserContext == ParserContext::Constraint) 2413 return emitError("`rewrite` cannot be used within a Constraint"); 2414 SMRange loc = curToken.getLoc(); 2415 consumeToken(Token::kw_rewrite); 2416 2417 // Parse the root operation. 2418 FailureOr<ast::Expr *> rootOp = parseExpr(); 2419 if (failed(rootOp)) 2420 return failure(); 2421 2422 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2423 return failure(); 2424 2425 if (curToken.isNot(Token::l_brace)) 2426 return emitError("expected `{` to start rewrite body"); 2427 2428 // The rewrite body of this statement is within a rewrite context. 2429 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2430 ParserContext::Rewrite); 2431 2432 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2433 if (failed(rewriteBody)) 2434 return failure(); 2435 2436 // Verify the rewrite body. 2437 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2438 if (isa<ast::ReturnStmt>(stmt)) { 2439 return emitError(stmt->getLoc(), 2440 "`return` statements are only permitted within a " 2441 "`Constraint` or `Rewrite` body"); 2442 } 2443 } 2444 2445 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2446 } 2447 2448 //===----------------------------------------------------------------------===// 2449 // Creation+Analysis 2450 //===----------------------------------------------------------------------===// 2451 2452 //===----------------------------------------------------------------------===// 2453 // Decls 2454 2455 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2456 // Unwrap reference expressions. 
2457 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2458 node = init->getDecl(); 2459 return dyn_cast<ast::CallableDecl>(node); 2460 } 2461 2462 FailureOr<ast::PatternDecl *> 2463 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2464 const ParsedPatternMetadata &metadata, 2465 ast::CompoundStmt *body) { 2466 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2467 metadata.hasBoundedRecursion, body); 2468 } 2469 2470 ast::Type Parser::createUserConstraintRewriteResultType( 2471 ArrayRef<ast::VariableDecl *> results) { 2472 // Single result decls use the type of the single result. 2473 if (results.size() == 1) 2474 return results[0]->getType(); 2475 2476 // Multiple results use a tuple type, with the types and names grabbed from 2477 // the result variable decls. 2478 auto resultTypes = llvm::map_range( 2479 results, [&](const auto *result) { return result->getType(); }); 2480 auto resultNames = llvm::map_range( 2481 results, [&](const auto *result) { return result->getName().getName(); }); 2482 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2483 llvm::to_vector(resultNames)); 2484 } 2485 2486 template <typename T> 2487 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2488 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2489 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2490 ast::CompoundStmt *body) { 2491 if (!body->getChildren().empty()) { 2492 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2493 ast::Expr *resultExpr = retStmt->getResultExpr(); 2494 2495 // Process the result of the decl. If no explicit signature results 2496 // were provided, check for return type inference. Otherwise, check that 2497 // the return expression can be converted to the expected type. 
2498 if (results.empty()) 2499 resultType = resultExpr->getType(); 2500 else if (failed(convertExpressionTo(resultExpr, resultType))) 2501 return failure(); 2502 else 2503 retStmt->setResultExpr(resultExpr); 2504 } 2505 } 2506 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2507 } 2508 2509 FailureOr<ast::VariableDecl *> 2510 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2511 ArrayRef<ast::ConstraintRef> constraints) { 2512 // The type of the variable, which is expected to be inferred by either a 2513 // constraint or an initializer expression. 2514 ast::Type type; 2515 if (failed(validateVariableConstraints(constraints, type))) 2516 return failure(); 2517 2518 if (initializer) { 2519 // Update the variable type based on the initializer, or try to convert the 2520 // initializer to the existing type. 2521 if (!type) 2522 type = initializer->getType(); 2523 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2524 type = mergedType; 2525 else if (failed(convertExpressionTo(initializer, type))) 2526 return failure(); 2527 2528 // Otherwise, if there is no initializer check that the type has already 2529 // been resolved from the constraint list. 2530 } else if (!type) { 2531 return emitErrorAndNote( 2532 loc, "unable to infer type for variable `" + name + "`", loc, 2533 "the type of a variable must be inferable from the constraint " 2534 "list or the initializer"); 2535 } 2536 2537 // Constraint types cannot be used when defining variables. 2538 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2539 return emitError( 2540 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2541 } 2542 2543 // Try to define a variable with the given name. 
2544 FailureOr<ast::VariableDecl *> varDecl = 2545 defineVariableDecl(name, loc, type, initializer, constraints); 2546 if (failed(varDecl)) 2547 return failure(); 2548 2549 return *varDecl; 2550 } 2551 2552 FailureOr<ast::VariableDecl *> 2553 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2554 const ast::ConstraintRef &constraint) { 2555 // Constraint arguments may apply more complex constraints via the arguments. 2556 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2557 ast::Type argType; 2558 if (failed(validateVariableConstraint(constraint, argType, 2559 allowNonCoreConstraints))) 2560 return failure(); 2561 return defineVariableDecl(name, loc, argType, constraint); 2562 } 2563 2564 LogicalResult 2565 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2566 ast::Type &inferredType, 2567 bool allowNonCoreConstraints) { 2568 for (const ast::ConstraintRef &ref : constraints) 2569 if (failed(validateVariableConstraint(ref, inferredType, 2570 allowNonCoreConstraints))) 2571 return failure(); 2572 return success(); 2573 } 2574 2575 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2576 ast::Type &inferredType, 2577 bool allowNonCoreConstraints) { 2578 ast::Type constraintType; 2579 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2580 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2581 if (failed(validateTypeConstraintExpr(typeExpr))) 2582 return failure(); 2583 } 2584 constraintType = ast::AttributeType::get(ctx); 2585 } else if (const auto *cst = 2586 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2587 constraintType = ast::OperationType::get( 2588 ctx, cst->getName(), lookupODSOperation(cst->getName())); 2589 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2590 constraintType = typeTy; 2591 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2592 constraintType = typeRangeTy; 2593 } else if (const auto 
*cst = 2594 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2595 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2596 if (failed(validateTypeConstraintExpr(typeExpr))) 2597 return failure(); 2598 } 2599 constraintType = valueTy; 2600 } else if (const auto *cst = 2601 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2602 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2603 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2604 return failure(); 2605 } 2606 constraintType = valueRangeTy; 2607 } else if (const auto *cst = 2608 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2609 if (!allowNonCoreConstraints) { 2610 return emitError(ref.referenceLoc, 2611 "`Rewrite` arguments and results are only permitted to " 2612 "use core constraints, such as `Attr`, `Op`, `Type`, " 2613 "`TypeRange`, `Value`, `ValueRange`"); 2614 } 2615 2616 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2617 if (inputs.size() != 1) { 2618 return emitErrorAndNote(ref.referenceLoc, 2619 "`Constraint`s applied via a variable constraint " 2620 "list must take a single input, but got " + 2621 Twine(inputs.size()), 2622 cst->getLoc(), 2623 "see definition of constraint here"); 2624 } 2625 constraintType = inputs.front()->getType(); 2626 } else { 2627 llvm_unreachable("unknown constraint type"); 2628 } 2629 2630 // Check that the constraint type is compatible with the current inferred 2631 // type. 
2632 if (!inferredType) { 2633 inferredType = constraintType; 2634 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2635 inferredType = mergedTy; 2636 } else { 2637 return emitError(ref.referenceLoc, 2638 llvm::formatv("constraint type `{0}` is incompatible " 2639 "with the previously inferred type `{1}`", 2640 constraintType, inferredType)); 2641 } 2642 return success(); 2643 } 2644 2645 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2646 ast::Type typeExprType = typeExpr->getType(); 2647 if (typeExprType != typeTy) { 2648 return emitError(typeExpr->getLoc(), 2649 "expected expression of `Type` in type constraint"); 2650 } 2651 return success(); 2652 } 2653 2654 LogicalResult 2655 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2656 ast::Type typeExprType = typeExpr->getType(); 2657 if (typeExprType != typeRangeTy) { 2658 return emitError(typeExpr->getLoc(), 2659 "expected expression of `TypeRange` in type constraint"); 2660 } 2661 return success(); 2662 } 2663 2664 //===----------------------------------------------------------------------===// 2665 // Exprs 2666 2667 FailureOr<ast::CallExpr *> 2668 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2669 MutableArrayRef<ast::Expr *> arguments) { 2670 ast::Type parentType = parentExpr->getType(); 2671 2672 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2673 if (!callableDecl) { 2674 return emitError(loc, 2675 llvm::formatv("expected a reference to a callable " 2676 "`Constraint` or `Rewrite`, but got: `{0}`", 2677 parentType)); 2678 } 2679 if (parserContext == ParserContext::Rewrite) { 2680 if (isa<ast::UserConstraintDecl>(callableDecl)) 2681 return emitError( 2682 loc, "unable to invoke `Constraint` within a rewrite section"); 2683 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2684 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2685 } 2686 2687 // Verify the arguments 
of the call. 2688 /// Handle size mismatch. 2689 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2690 if (callArgs.size() != arguments.size()) { 2691 return emitErrorAndNote( 2692 loc, 2693 llvm::formatv("invalid number of arguments for {0} call; expected " 2694 "{1}, but got {2}", 2695 callableDecl->getCallableType(), callArgs.size(), 2696 arguments.size()), 2697 callableDecl->getLoc(), 2698 llvm::formatv("see the definition of {0} here", 2699 callableDecl->getName()->getName())); 2700 } 2701 2702 /// Handle argument type mismatch. 2703 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2704 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2705 callableDecl->getName()->getName()), 2706 callableDecl->getLoc()); 2707 }; 2708 for (auto it : llvm::zip(callArgs, arguments)) { 2709 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2710 attachDiagFn))) 2711 return failure(); 2712 } 2713 2714 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2715 callableDecl->getResultType()); 2716 } 2717 2718 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2719 ast::Decl *decl) { 2720 // Check the type of decl being referenced. 
2721 ast::Type declType; 2722 if (isa<ast::ConstraintDecl>(decl)) 2723 declType = ast::ConstraintType::get(ctx); 2724 else if (isa<ast::UserRewriteDecl>(decl)) 2725 declType = ast::RewriteType::get(ctx); 2726 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2727 declType = varDecl->getType(); 2728 else 2729 return emitError(loc, "invalid reference to `" + 2730 decl->getName()->getName() + "`"); 2731 2732 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2733 } 2734 2735 FailureOr<ast::DeclRefExpr *> 2736 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2737 ArrayRef<ast::ConstraintRef> constraints) { 2738 FailureOr<ast::VariableDecl *> decl = 2739 defineVariableDecl(name, loc, type, constraints); 2740 if (failed(decl)) 2741 return failure(); 2742 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2743 } 2744 2745 FailureOr<ast::MemberAccessExpr *> 2746 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2747 SMRange loc) { 2748 // Validate the member name for the given parent expression. 2749 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2750 if (failed(memberType)) 2751 return failure(); 2752 2753 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2754 } 2755 2756 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2757 StringRef name, SMRange loc) { 2758 ast::Type parentType = parentExpr->getType(); 2759 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2760 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2761 return valueRangeTy; 2762 2763 // Verify member access based on the operation type. 2764 if (const ods::Operation *odsOp = opType.getODSOperation()) { 2765 auto results = odsOp->getResults(); 2766 2767 // Handle indexed results. 
2768 unsigned index = 0; 2769 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2770 index < results.size()) { 2771 return results[index].isVariadic() ? valueRangeTy : valueTy; 2772 } 2773 2774 // Handle named results. 2775 const auto *it = llvm::find_if(results, [&](const auto &result) { 2776 return result.getName() == name; 2777 }); 2778 if (it != results.end()) 2779 return it->isVariadic() ? valueRangeTy : valueTy; 2780 } else if (llvm::isDigit(name[0])) { 2781 // Allow unchecked numeric indexing of the results of unregistered 2782 // operations. It returns a single value. 2783 return valueTy; 2784 } 2785 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2786 // Handle indexed results. 2787 unsigned index = 0; 2788 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2789 index < tupleType.size()) { 2790 return tupleType.getElementTypes()[index]; 2791 } 2792 2793 // Handle named results. 2794 auto elementNames = tupleType.getElementNames(); 2795 const auto *it = llvm::find(elementNames, name); 2796 if (it != elementNames.end()) 2797 return tupleType.getElementTypes()[it - elementNames.begin()]; 2798 } 2799 return emitError( 2800 loc, 2801 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2802 name, parentType)); 2803 } 2804 2805 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2806 SMRange loc, const ast::OpNameDecl *name, 2807 OpResultTypeContext resultTypeContext, 2808 MutableArrayRef<ast::Expr *> operands, 2809 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2810 MutableArrayRef<ast::Expr *> results) { 2811 Optional<StringRef> opNameRef = name->getName(); 2812 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2813 2814 // Verify the inputs operands. 2815 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2816 return failure(); 2817 2818 // Verify the attribute list. 
2819 for (ast::NamedAttributeDecl *attr : attributes) { 2820 // Check for an attribute type, or a type awaiting resolution. 2821 ast::Type attrType = attr->getValue()->getType(); 2822 if (!attrType.isa<ast::AttributeType>()) { 2823 return emitError( 2824 attr->getValue()->getLoc(), 2825 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2826 } 2827 } 2828 2829 assert( 2830 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && 2831 "unexpected inferrence when results were explicitly specified"); 2832 2833 // If we aren't relying on type inferrence, or explicit results were provided, 2834 // validate them. 2835 if (resultTypeContext == OpResultTypeContext::Explicit) { 2836 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2837 return failure(); 2838 2839 // Validate the use of interface based type inferrence for this operation. 2840 } else if (resultTypeContext == OpResultTypeContext::Interface) { 2841 assert(opNameRef && 2842 "expected valid operation name when inferring operation results"); 2843 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp); 2844 } 2845 2846 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results, 2847 attributes); 2848 } 2849 2850 LogicalResult 2851 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2852 const ods::Operation *odsOp, 2853 MutableArrayRef<ast::Expr *> operands) { 2854 return validateOperationOperandsOrResults( 2855 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2856 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2857 valueRangeTy); 2858 } 2859 2860 LogicalResult 2861 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2862 const ods::Operation *odsOp, 2863 MutableArrayRef<ast::Expr *> results) { 2864 return validateOperationOperandsOrResults( 2865 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2866 results, odsOp ? 
odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2867 } 2868 2869 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, 2870 const ods::Operation *odsOp) { 2871 // If the operation might not have inferrence support, emit a warning to the 2872 // user. We don't emit an error because the interface might be added to the 2873 // operation at runtime. It's rare, but it could still happen. We emit a 2874 // warning here instead. 2875 2876 // Handle inferrence warnings for unknown operations. 2877 if (!odsOp) { 2878 ctx.getDiagEngine().emitWarning( 2879 loc, llvm::formatv( 2880 "operation result types are marked to be inferred, but " 2881 "`{0}` is unknown. Ensure that `{0}` supports zero " 2882 "results or implements `InferTypeOpInterface`. Include " 2883 "the ODS definition of this operation to remove this warning.", 2884 opName)); 2885 return; 2886 } 2887 2888 // Handle inferrence warnings for known operations that expected at least one 2889 // result, but don't have inference support. An elided results list can mean 2890 // "zero-results", and we don't want to warn when that is the expected 2891 // behavior. 2892 bool requiresInferrence = 2893 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) { 2894 return !result.isVariableLength(); 2895 }); 2896 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { 2897 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( 2898 loc, 2899 llvm::formatv("operation result types are marked to be inferred, but " 2900 "`{0}` does not provide an implementation of " 2901 "`InferTypeOpInterface`. 
Ensure that `{0}` attaches " 2902 "`InferTypeOpInterface` at runtime, or add support to " 2903 "the ODS definition to remove this warning.", 2904 opName)); 2905 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName), 2906 odsOp->getLoc()); 2907 return; 2908 } 2909 } 2910 2911 LogicalResult Parser::validateOperationOperandsOrResults( 2912 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2913 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2914 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2915 ast::Type rangeTy) { 2916 // All operation types accept a single range parameter. 2917 if (values.size() == 1) { 2918 if (failed(convertExpressionTo(values[0], rangeTy))) 2919 return failure(); 2920 return success(); 2921 } 2922 2923 /// If the operation has ODS information, we can more accurately verify the 2924 /// values. 2925 if (odsOpLoc) { 2926 if (odsValues.size() != values.size()) { 2927 return emitErrorAndNote( 2928 loc, 2929 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2930 "{2}, but got {3}", 2931 groupName, *name, odsValues.size(), values.size()), 2932 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2933 } 2934 auto diagFn = [&](ast::Diagnostic &diag) { 2935 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2936 *odsOpLoc); 2937 }; 2938 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2939 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2940 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2941 return failure(); 2942 } 2943 return success(); 2944 } 2945 2946 // Otherwise, accept the value groups as they have been defined and just 2947 // ensure they are one of the expected types. 2948 for (ast::Expr *&valueExpr : values) { 2949 ast::Type valueExprType = valueExpr->getType(); 2950 2951 // Check if this is one of the expected types. 
2952 if (valueExprType == rangeTy || valueExprType == singleTy) 2953 continue; 2954 2955 // If the operand is an Operation, allow converting to a Value or 2956 // ValueRange. This situations arises quite often with nested operation 2957 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2958 if (singleTy == valueTy) { 2959 if (valueExprType.isa<ast::OperationType>()) { 2960 valueExpr = convertOpToValue(valueExpr); 2961 continue; 2962 } 2963 } 2964 2965 return emitError( 2966 valueExpr->getLoc(), 2967 llvm::formatv( 2968 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2969 singleTy, rangeTy, valueExprType)); 2970 } 2971 return success(); 2972 } 2973 2974 FailureOr<ast::TupleExpr *> 2975 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2976 ArrayRef<StringRef> elementNames) { 2977 for (const ast::Expr *element : elements) { 2978 ast::Type eleTy = element->getType(); 2979 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2980 return emitError( 2981 element->getLoc(), 2982 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2983 } 2984 } 2985 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2986 } 2987 2988 //===----------------------------------------------------------------------===// 2989 // Stmts 2990 2991 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2992 ast::Expr *rootOp) { 2993 // Check that root is an Operation. 2994 ast::Type rootType = rootOp->getType(); 2995 if (!rootType.isa<ast::OperationType>()) 2996 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2997 2998 return ast::EraseStmt::create(ctx, loc, rootOp); 2999 } 3000 3001 FailureOr<ast::ReplaceStmt *> 3002 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 3003 MutableArrayRef<ast::Expr *> replValues) { 3004 // Check that root is an Operation. 
3005 ast::Type rootType = rootOp->getType(); 3006 if (!rootType.isa<ast::OperationType>()) { 3007 return emitError( 3008 rootOp->getLoc(), 3009 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 3010 } 3011 3012 // If there are multiple replacement values, we implicitly convert any Op 3013 // expressions to the value form. 3014 bool shouldConvertOpToValues = replValues.size() > 1; 3015 for (ast::Expr *&replExpr : replValues) { 3016 ast::Type replType = replExpr->getType(); 3017 3018 // Check that replExpr is an Operation, Value, or ValueRange. 3019 if (replType.isa<ast::OperationType>()) { 3020 if (shouldConvertOpToValues) 3021 replExpr = convertOpToValue(replExpr); 3022 continue; 3023 } 3024 3025 if (replType != valueTy && replType != valueRangeTy) { 3026 return emitError(replExpr->getLoc(), 3027 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 3028 "expression, but got `{0}`", 3029 replType)); 3030 } 3031 } 3032 3033 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 3034 } 3035 3036 FailureOr<ast::RewriteStmt *> 3037 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 3038 ast::CompoundStmt *rewriteBody) { 3039 // Check that root is an Operation. 
3040 ast::Type rootType = rootOp->getType(); 3041 if (!rootType.isa<ast::OperationType>()) { 3042 return emitError( 3043 rootOp->getLoc(), 3044 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 3045 } 3046 3047 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 3048 } 3049 3050 //===----------------------------------------------------------------------===// 3051 // Code Completion 3052 //===----------------------------------------------------------------------===// 3053 3054 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { 3055 ast::Type parentType = parentExpr->getType(); 3056 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) 3057 codeCompleteContext->codeCompleteOperationMemberAccess(opType); 3058 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) 3059 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); 3060 return failure(); 3061 } 3062 3063 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) { 3064 if (opName) 3065 codeCompleteContext->codeCompleteOperationAttributeName(*opName); 3066 return failure(); 3067 } 3068 3069 LogicalResult 3070 Parser::codeCompleteConstraintName(ast::Type inferredType, 3071 bool allowNonCoreConstraints, 3072 bool allowInlineTypeConstraints) { 3073 codeCompleteContext->codeCompleteConstraintName( 3074 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, 3075 curDeclScope); 3076 return failure(); 3077 } 3078 3079 LogicalResult Parser::codeCompleteDialectName() { 3080 codeCompleteContext->codeCompleteDialectName(); 3081 return failure(); 3082 } 3083 3084 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { 3085 codeCompleteContext->codeCompleteOperationName(dialectName); 3086 return failure(); 3087 } 3088 3089 LogicalResult Parser::codeCompletePatternMetadata() { 3090 codeCompleteContext->codeCompletePatternMetadata(); 3091 return failure(); 3092 } 3093 3094 
LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 3095 codeCompleteContext->codeCompleteIncludeFilename(curPath); 3096 return failure(); 3097 } 3098 3099 void Parser::codeCompleteCallSignature(ast::Node *parent, 3100 unsigned currentNumArgs) { 3101 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 3102 if (!callableDecl) 3103 return; 3104 3105 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 3106 } 3107 3108 void Parser::codeCompleteOperationOperandsSignature( 3109 Optional<StringRef> opName, unsigned currentNumOperands) { 3110 codeCompleteContext->codeCompleteOperationOperandsSignature( 3111 opName, currentNumOperands); 3112 } 3113 3114 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 3115 unsigned currentNumResults) { 3116 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 3117 currentNumResults); 3118 } 3119 3120 //===----------------------------------------------------------------------===// 3121 // Parser 3122 //===----------------------------------------------------------------------===// 3123 3124 FailureOr<ast::Module *> 3125 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 3126 bool enableDocumentation, 3127 CodeCompleteContext *codeCompleteContext) { 3128 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext); 3129 return parser.parseModule(); 3130 } 3131