1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/IndentedOstream.h" 12 #include "mlir/Support/LogicalResult.h" 13 #include "mlir/TableGen/Argument.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Constraint.h" 16 #include "mlir/TableGen/Format.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/Tools/PDLL/AST/Context.h" 19 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 20 #include "mlir/Tools/PDLL/AST/Nodes.h" 21 #include "mlir/Tools/PDLL/AST/Types.h" 22 #include "mlir/Tools/PDLL/ODS/Constraint.h" 23 #include "mlir/Tools/PDLL/ODS/Context.h" 24 #include "mlir/Tools/PDLL/ODS/Operation.h" 25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/ManagedStatic.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/ScopedPrinter.h" 32 #include "llvm/TableGen/Error.h" 33 #include "llvm/TableGen/Parser.h" 34 #include <string> 35 36 using namespace mlir; 37 using namespace mlir::pdll; 38 39 //===----------------------------------------------------------------------===// 40 // Parser 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 class Parser { 45 public: 46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext) 48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), 49 curToken(lexer.lexToken()), 
enableDocumentation(enableDocumentation), 50 valueTy(ast::ValueType::get(ctx)), 51 valueRangeTy(ast::ValueRangeType::get(ctx)), 52 typeTy(ast::TypeType::get(ctx)), 53 typeRangeTy(ast::TypeRangeType::get(ctx)), 54 attrTy(ast::AttributeType::get(ctx)), 55 codeCompleteContext(codeCompleteContext) {} 56 57 /// Try to parse a new module. Returns nullptr in the case of failure. 58 FailureOr<ast::Module *> parseModule(); 59 60 private: 61 /// The current context of the parser. It allows for the parser to know a bit 62 /// about the construct it is nested within during parsing. This is used 63 /// specifically to provide additional verification during parsing, e.g. to 64 /// prevent using rewrites within a match context, matcher constraints within 65 /// a rewrite section, etc. 66 enum class ParserContext { 67 /// The parser is in the global context. 68 Global, 69 /// The parser is currently within a Constraint, which disallows all types 70 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 71 Constraint, 72 /// The parser is currently within the matcher portion of a Pattern, which 73 /// is allows a terminal operation rewrite statement but no other rewrite 74 /// transformations. 75 PatternMatch, 76 /// The parser is currently within a Rewrite, which disallows calls to 77 /// constraints, requires operation expressions to have names, etc. 78 Rewrite, 79 }; 80 81 /// The current specification context of an operations result type. This 82 /// indicates how the result types of an operation may be inferred. 83 enum class OpResultTypeContext { 84 /// The result types of the operation are not known to be inferred. 85 Explicit, 86 /// The result types of the operation are inferred from the root input of a 87 /// `replace` statement. 88 Replacement, 89 /// The result types of the operation are inferred by using the 90 /// `InferTypeOpInterface` interface provided by the operation. 
91 Interface, 92 }; 93 94 //===--------------------------------------------------------------------===// 95 // Parsing 96 //===--------------------------------------------------------------------===// 97 98 /// Push a new decl scope onto the lexer. 99 ast::DeclScope *pushDeclScope() { 100 ast::DeclScope *newScope = 101 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 102 return (curDeclScope = newScope); 103 } 104 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 105 106 /// Pop the last decl scope from the lexer. 107 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 108 109 /// Parse the body of an AST module. 110 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 111 112 /// Try to convert the given expression to `type`. Returns failure and emits 113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 114 /// invoked to attach notes to the emitted error diagnostic. On success, 115 /// `expr` is updated to the expression used to convert to `type`. 116 LogicalResult convertExpressionTo( 117 ast::Expr *&expr, ast::Type type, 118 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 119 120 /// Given an operation expression, convert it to a Value or ValueRange 121 /// typed expression. 122 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 123 124 /// Lookup ODS information for the given operation, returns nullptr if no 125 /// information is found. 126 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 127 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 128 } 129 130 /// Process the given documentation string, or return an empty string if 131 /// documentation isn't enabled. 132 StringRef processDoc(StringRef doc) { 133 return enableDocumentation ? doc : StringRef(); 134 } 135 136 /// Process the given documentation string and format it, or return an empty 137 /// string if documentation isn't enabled. 
138 std::string processAndFormatDoc(const Twine &doc) { 139 if (!enableDocumentation) 140 return ""; 141 std::string docStr; 142 { 143 llvm::raw_string_ostream docOS(docStr); 144 std::string tmpDocStr = doc.str(); 145 raw_indented_ostream(docOS).printReindented( 146 StringRef(tmpDocStr).rtrim(" \t")); 147 } 148 return docStr; 149 } 150 151 //===--------------------------------------------------------------------===// 152 // Directives 153 154 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 155 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 156 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 157 SmallVectorImpl<ast::Decl *> &decls); 158 159 /// Process the records of a parsed tablegen include file. 160 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 161 SmallVectorImpl<ast::Decl *> &decls); 162 163 /// Create a user defined native constraint for a constraint imported from 164 /// ODS. 165 template <typename ConstraintT> 166 ast::Decl * 167 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 168 SMRange loc, ast::Type type, 169 StringRef nativeType, StringRef docString); 170 template <typename ConstraintT> 171 ast::Decl * 172 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 173 SMRange loc, ast::Type type, 174 StringRef nativeType); 175 176 //===--------------------------------------------------------------------===// 177 // Decls 178 179 /// This structure contains the set of pattern metadata that may be parsed. 180 struct ParsedPatternMetadata { 181 Optional<uint16_t> benefit; 182 bool hasBoundedRecursion = false; 183 }; 184 185 FailureOr<ast::Decl *> parseTopLevelDecl(); 186 FailureOr<ast::NamedAttributeDecl *> 187 parseNamedAttributeDecl(Optional<StringRef> parentOpName); 188 189 /// Parse an argument variable as part of the signature of a 190 /// UserConstraintDecl or UserRewriteDecl. 
191 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 192 193 /// Parse a result variable as part of the signature of a UserConstraintDecl 194 /// or UserRewriteDecl. 195 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 196 197 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 198 /// defined in a non-global context. 199 FailureOr<ast::UserConstraintDecl *> 200 parseUserConstraintDecl(bool isInline = false); 201 202 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 203 /// non-global context, such as within a Pattern/Constraint/etc. 204 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 205 206 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 207 /// PDLL constructs. 208 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 209 const ast::Name &name, bool isInline, 210 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 211 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 212 213 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 214 /// defined in a non-global context. 215 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 216 217 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 218 /// non-global context, such as within a Pattern/Rewrite/etc. 219 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 220 221 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 222 /// PDLL constructs. 223 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 224 const ast::Name &name, bool isInline, 225 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 226 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 227 228 /// Parse either a UserConstraintDecl or UserRewriteDecl. 
These decls have 229 /// effectively the same syntax, and only differ on slight semantics (given 230 /// the different parsing contexts). 231 template <typename T, typename ParseUserPDLLDeclFnT> 232 FailureOr<T *> parseUserConstraintOrRewriteDecl( 233 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 234 StringRef anonymousNamePrefix, bool isInline); 235 236 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 237 /// These decls have effectively the same syntax. 238 template <typename T> 239 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 240 const ast::Name &name, bool isInline, 241 ArrayRef<ast::VariableDecl *> arguments, 242 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 243 244 /// Parse the functional signature (i.e. the arguments and results) of a 245 /// UserConstraintDecl or UserRewriteDecl. 246 LogicalResult parseUserConstraintOrRewriteSignature( 247 SmallVectorImpl<ast::VariableDecl *> &arguments, 248 SmallVectorImpl<ast::VariableDecl *> &results, 249 ast::DeclScope *&argumentScope, ast::Type &resultType); 250 251 /// Validate the return (which if present is specified by bodyIt) of a 252 /// UserConstraintDecl or UserRewriteDecl. 253 LogicalResult validateUserConstraintOrRewriteReturn( 254 StringRef declType, ast::CompoundStmt *body, 255 ArrayRef<ast::Stmt *>::iterator bodyIt, 256 ArrayRef<ast::Stmt *>::iterator bodyE, 257 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 258 259 FailureOr<ast::CompoundStmt *> 260 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 261 bool expectTerminalSemicolon = true); 262 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 263 FailureOr<ast::Decl *> parsePatternDecl(); 264 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 265 266 /// Check to see if a decl has already been defined with the given name, if 267 /// one has emit and error and return failure. Returns success otherwise. 
268 LogicalResult checkDefineNamedDecl(const ast::Name &name); 269 270 /// Try to define a variable decl with the given components, returns the 271 /// variable on success. 272 FailureOr<ast::VariableDecl *> 273 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 274 ast::Expr *initExpr, 275 ArrayRef<ast::ConstraintRef> constraints); 276 FailureOr<ast::VariableDecl *> 277 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 278 ArrayRef<ast::ConstraintRef> constraints); 279 280 /// Parse the constraint reference list for a variable decl. 281 LogicalResult parseVariableDeclConstraintList( 282 SmallVectorImpl<ast::ConstraintRef> &constraints); 283 284 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 285 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 286 287 /// Try to parse a single reference to a constraint. `typeConstraint` is the 288 /// location of a previously parsed type constraint for the entity that will 289 /// be constrained by the parsed constraint. `existingConstraints` are any 290 /// existing constraints that have already been parsed for the same entity 291 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 292 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 293 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined 294 /// constraints) may be used with the variable. 295 FailureOr<ast::ConstraintRef> 296 parseConstraint(Optional<SMRange> &typeConstraint, 297 ArrayRef<ast::ConstraintRef> existingConstraints, 298 bool allowInlineTypeConstraints, 299 bool allowNonCoreConstraints); 300 301 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 302 /// argument or result variable. The constraints for these variables do not 303 /// allow inline type constraints, and only permit a single constraint. 
304 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 305 306 //===--------------------------------------------------------------------===// 307 // Exprs 308 309 FailureOr<ast::Expr *> parseExpr(); 310 311 /// Identifier expressions. 312 FailureOr<ast::Expr *> parseAttributeExpr(); 313 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 314 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 315 FailureOr<ast::Expr *> parseIdentifierExpr(); 316 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 317 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 318 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 319 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 320 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 321 FailureOr<ast::Expr *> 322 parseOperationExpr(OpResultTypeContext inputResultTypeContext = 323 OpResultTypeContext::Explicit); 324 FailureOr<ast::Expr *> parseTupleExpr(); 325 FailureOr<ast::Expr *> parseTypeExpr(); 326 FailureOr<ast::Expr *> parseUnderscoreExpr(); 327 328 //===--------------------------------------------------------------------===// 329 // Stmts 330 331 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 332 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 333 FailureOr<ast::EraseStmt *> parseEraseStmt(); 334 FailureOr<ast::LetStmt *> parseLetStmt(); 335 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 336 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 337 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 338 339 //===--------------------------------------------------------------------===// 340 // Creation+Analysis 341 //===--------------------------------------------------------------------===// 342 343 //===--------------------------------------------------------------------===// 344 // Decls 345 346 /// Try to extract a callable from the given AST node. Returns nullptr on 347 /// failure. 
348 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 349 350 /// Try to create a pattern decl with the given components, returning the 351 /// Pattern on success. 352 FailureOr<ast::PatternDecl *> 353 createPatternDecl(SMRange loc, const ast::Name *name, 354 const ParsedPatternMetadata &metadata, 355 ast::CompoundStmt *body); 356 357 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 358 /// of results, defined as part of the signature. 359 ast::Type 360 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 361 362 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 363 template <typename T> 364 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 365 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 366 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 367 ast::CompoundStmt *body); 368 369 /// Try to create a variable decl with the given components, returning the 370 /// Variable on success. 371 FailureOr<ast::VariableDecl *> 372 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 373 ArrayRef<ast::ConstraintRef> constraints); 374 375 /// Create a variable for an argument or result defined as part of the 376 /// signature of a UserConstraintDecl/UserRewriteDecl. 377 FailureOr<ast::VariableDecl *> 378 createArgOrResultVariableDecl(StringRef name, SMRange loc, 379 const ast::ConstraintRef &constraint); 380 381 /// Validate the constraints used to constraint a variable decl. 382 /// `inferredType` is the type of the variable inferred by the constraints 383 /// within the list, and is updated to the most refined type as determined by 384 /// the constraints. Returns success if the constraint list is valid, failure 385 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 386 /// defined constraints) may be used with the variable. 
387 LogicalResult 388 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 389 ast::Type &inferredType, 390 bool allowNonCoreConstraints = true); 391 /// Validate a single reference to a constraint. `inferredType` contains the 392 /// currently inferred variabled type and is refined within the type defined 393 /// by the constraint. Returns success if the constraint is valid, failure 394 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 395 /// defined constraints) may be used with the variable. 396 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 397 ast::Type &inferredType, 398 bool allowNonCoreConstraints = true); 399 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 400 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 401 402 //===--------------------------------------------------------------------===// 403 // Exprs 404 405 FailureOr<ast::CallExpr *> 406 createCallExpr(SMRange loc, ast::Expr *parentExpr, 407 MutableArrayRef<ast::Expr *> arguments); 408 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 409 FailureOr<ast::DeclRefExpr *> 410 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 411 ArrayRef<ast::ConstraintRef> constraints); 412 FailureOr<ast::MemberAccessExpr *> 413 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 414 415 /// Validate the member access `name` into the given parent expression. On 416 /// success, this also returns the type of the member accessed. 
417 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 418 StringRef name, SMRange loc); 419 FailureOr<ast::OperationExpr *> 420 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 421 OpResultTypeContext resultTypeContext, 422 MutableArrayRef<ast::Expr *> operands, 423 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 424 MutableArrayRef<ast::Expr *> results); 425 LogicalResult 426 validateOperationOperands(SMRange loc, Optional<StringRef> name, 427 const ods::Operation *odsOp, 428 MutableArrayRef<ast::Expr *> operands); 429 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 430 const ods::Operation *odsOp, 431 MutableArrayRef<ast::Expr *> results); 432 void checkOperationResultTypeInferrence(SMRange loc, StringRef name, 433 const ods::Operation *odsOp); 434 LogicalResult validateOperationOperandsOrResults( 435 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 436 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 437 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 438 ast::Type rangeTy); 439 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 440 ArrayRef<ast::Expr *> elements, 441 ArrayRef<StringRef> elementNames); 442 443 //===--------------------------------------------------------------------===// 444 // Stmts 445 446 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 447 FailureOr<ast::ReplaceStmt *> 448 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 449 MutableArrayRef<ast::Expr *> replValues); 450 FailureOr<ast::RewriteStmt *> 451 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 452 ast::CompoundStmt *rewriteBody); 453 454 //===--------------------------------------------------------------------===// 455 // Code Completion 456 //===--------------------------------------------------------------------===// 457 458 /// The set of various code completion methods. 
Every completion method 459 /// returns `failure` to stop the parsing process after providing completion 460 /// results. 461 462 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); 463 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName); 464 LogicalResult codeCompleteConstraintName(ast::Type inferredType, 465 bool allowNonCoreConstraints, 466 bool allowInlineTypeConstraints); 467 LogicalResult codeCompleteDialectName(); 468 LogicalResult codeCompleteOperationName(StringRef dialectName); 469 LogicalResult codeCompletePatternMetadata(); 470 LogicalResult codeCompleteIncludeFilename(StringRef curPath); 471 472 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); 473 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName, 474 unsigned currentNumOperands); 475 void codeCompleteOperationResultsSignature(Optional<StringRef> opName, 476 unsigned currentNumResults); 477 478 //===--------------------------------------------------------------------===// 479 // Lexer Utilities 480 //===--------------------------------------------------------------------===// 481 482 /// If the current token has the specified kind, consume it and return true. 483 /// If not, return false. 484 bool consumeIf(Token::Kind kind) { 485 if (curToken.isNot(kind)) 486 return false; 487 consumeToken(kind); 488 return true; 489 } 490 491 /// Advance the current lexer onto the next token. 492 void consumeToken() { 493 assert(curToken.isNot(Token::eof, Token::error) && 494 "shouldn't advance past EOF or errors"); 495 curToken = lexer.lexToken(); 496 } 497 498 /// Advance the current lexer onto the next token, asserting what the expected 499 /// current token is. This is preferred to the above method because it leads 500 /// to more self-documenting code with better checking. 
501 void consumeToken(Token::Kind kind) { 502 assert(curToken.is(kind) && "consumed an unexpected token"); 503 consumeToken(); 504 } 505 506 /// Reset the lexer to the location at the given position. 507 void resetToken(SMRange tokLoc) { 508 lexer.resetPointer(tokLoc.Start.getPointer()); 509 curToken = lexer.lexToken(); 510 } 511 512 /// Consume the specified token if present and return success. On failure, 513 /// output a diagnostic and return failure. 514 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 515 if (curToken.getKind() != kind) 516 return emitError(curToken.getLoc(), msg); 517 consumeToken(); 518 return success(); 519 } 520 LogicalResult emitError(SMRange loc, const Twine &msg) { 521 lexer.emitError(loc, msg); 522 return failure(); 523 } 524 LogicalResult emitError(const Twine &msg) { 525 return emitError(curToken.getLoc(), msg); 526 } 527 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 528 const Twine ¬e) { 529 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 530 return failure(); 531 } 532 533 //===--------------------------------------------------------------------===// 534 // Fields 535 //===--------------------------------------------------------------------===// 536 537 /// The owning AST context. 538 ast::Context &ctx; 539 540 /// The lexer of this parser. 541 Lexer lexer; 542 543 /// The current token within the lexer. 544 Token curToken; 545 546 /// A flag indicating if the parser should add documentation to AST nodes when 547 /// viable. 548 bool enableDocumentation; 549 550 /// The most recently defined decl scope. 551 ast::DeclScope *curDeclScope = nullptr; 552 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 553 554 /// The current context of the parser. 555 ParserContext parserContext = ParserContext::Global; 556 557 /// Cached types to simplify verification and expression creation. 
558 ast::Type valueTy, valueRangeTy; 559 ast::Type typeTy, typeRangeTy; 560 ast::Type attrTy; 561 562 /// A counter used when naming anonymous constraints and rewrites. 563 unsigned anonymousDeclNameCounter = 0; 564 565 /// The optional code completion context. 566 CodeCompleteContext *codeCompleteContext; 567 }; 568 } // namespace 569 570 FailureOr<ast::Module *> Parser::parseModule() { 571 SMLoc moduleLoc = curToken.getStartLoc(); 572 pushDeclScope(); 573 574 // Parse the top-level decls of the module. 575 SmallVector<ast::Decl *> decls; 576 if (failed(parseModuleBody(decls))) 577 return popDeclScope(), failure(); 578 579 popDeclScope(); 580 return ast::Module::create(ctx, moduleLoc, decls); 581 } 582 583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 584 while (curToken.isNot(Token::eof)) { 585 if (curToken.is(Token::directive)) { 586 if (failed(parseDirective(decls))) 587 return failure(); 588 continue; 589 } 590 591 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 592 if (failed(decl)) 593 return failure(); 594 decls.push_back(*decl); 595 } 596 return success(); 597 } 598 599 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 600 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 601 valueRangeTy); 602 } 603 604 LogicalResult Parser::convertExpressionTo( 605 ast::Expr *&expr, ast::Type type, 606 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 607 ast::Type exprType = expr->getType(); 608 if (exprType == type) 609 return success(); 610 611 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 612 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 613 expr->getLoc(), llvm::formatv("unable to convert expression of type " 614 "`{0}` to the expected type of " 615 "`{1}`", 616 exprType, type)); 617 if (noteAttachFn) 618 noteAttachFn(*diag); 619 return diag; 620 }; 621 622 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 623 // Two operation types are 
compatible if they have the same name, or if the 624 // expected type is more general. 625 if (auto opType = type.dyn_cast<ast::OperationType>()) { 626 if (opType.getName()) 627 return emitConvertError(); 628 return success(); 629 } 630 631 // An operation can always convert to a ValueRange. 632 if (type == valueRangeTy) { 633 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 634 valueRangeTy); 635 return success(); 636 } 637 638 // Allow conversion to a single value by constraining the result range. 639 if (type == valueTy) { 640 // If the operation is registered, we can verify if it can ever have a 641 // single result. 642 if (const ods::Operation *odsOp = exprOpType.getODSOperation()) { 643 if (odsOp->getResults().empty()) { 644 return emitConvertError()->attachNote( 645 llvm::formatv("see the definition of `{0}`, which was defined " 646 "with zero results", 647 odsOp->getName()), 648 odsOp->getLoc()); 649 } 650 651 unsigned numSingleResults = llvm::count_if( 652 odsOp->getResults(), [](const ods::OperandOrResult &result) { 653 return result.getVariableLengthKind() == 654 ods::VariableLengthKind::Single; 655 }); 656 if (numSingleResults > 1) { 657 return emitConvertError()->attachNote( 658 llvm::formatv("see the definition of `{0}`, which was defined " 659 "with at least {1} results", 660 odsOp->getName(), numSingleResults), 661 odsOp->getLoc()); 662 } 663 } 664 665 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 666 valueTy); 667 return success(); 668 } 669 return emitConvertError(); 670 } 671 672 // FIXME: Decide how to allow/support converting a single result to multiple, 673 // and multiple to a single result. For now, we just allow Single->Range, 674 // but this isn't something really supported in the PDL dialect. We should 675 // figure out some way to support both. 
676 if ((exprType == valueTy || exprType == valueRangeTy) && 677 (type == valueTy || type == valueRangeTy)) 678 return success(); 679 if ((exprType == typeTy || exprType == typeRangeTy) && 680 (type == typeTy || type == typeRangeTy)) 681 return success(); 682 683 // Handle tuple types. 684 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 685 auto tupleType = type.dyn_cast<ast::TupleType>(); 686 if (!tupleType || tupleType.size() != exprTupleType.size()) 687 return emitConvertError(); 688 689 // Build a new tuple expression using each of the elements of the current 690 // tuple. 691 SmallVector<ast::Expr *> newExprs; 692 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 693 newExprs.push_back(ast::MemberAccessExpr::create( 694 ctx, expr->getLoc(), expr, llvm::to_string(i), 695 exprTupleType.getElementTypes()[i])); 696 697 auto diagFn = [&](ast::Diagnostic &diag) { 698 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 699 i, exprTupleType)); 700 if (noteAttachFn) 701 noteAttachFn(diag); 702 }; 703 if (failed(convertExpressionTo(newExprs.back(), 704 tupleType.getElementTypes()[i], diagFn))) 705 return failure(); 706 } 707 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 708 tupleType.getElementNames()); 709 return success(); 710 } 711 712 return emitConvertError(); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Directives 717 718 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 719 StringRef directive = curToken.getSpelling(); 720 if (directive == "#include") 721 return parseInclude(decls); 722 723 return emitError("unknown directive `" + directive + "`"); 724 } 725 726 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 727 SMRange loc = curToken.getLoc(); 728 consumeToken(Token::directive); 729 730 // Handle code completion of the include file path. 
731 if (curToken.is(Token::code_complete_string)) 732 return codeCompleteIncludeFilename(curToken.getStringValue()); 733 734 // Parse the file being included. 735 if (!curToken.isString()) 736 return emitError(loc, 737 "expected string file name after `include` directive"); 738 SMRange fileLoc = curToken.getLoc(); 739 std::string filenameStr = curToken.getStringValue(); 740 StringRef filename = filenameStr; 741 consumeToken(); 742 743 // Check the type of include. If ending with `.pdll`, this is another pdl file 744 // to be parsed along with the current module. 745 if (filename.endswith(".pdll")) { 746 if (failed(lexer.pushInclude(filename, fileLoc))) 747 return emitError(fileLoc, 748 "unable to open include file `" + filename + "`"); 749 750 // If we added the include successfully, parse it into the current module. 751 // Make sure to update to the next token after we finish parsing the nested 752 // file. 753 curToken = lexer.lexToken(); 754 LogicalResult result = parseModuleBody(decls); 755 curToken = lexer.lexToken(); 756 return result; 757 } 758 759 // Otherwise, this must be a `.td` include. 760 if (filename.endswith(".td")) 761 return parseTdInclude(filename, fileLoc, decls); 762 763 return emitError(fileLoc, 764 "expected include filename to end with `.pdll` or `.td`"); 765 } 766 767 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 768 SmallVectorImpl<ast::Decl *> &decls) { 769 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 770 771 // Use the source manager to open the file, but don't yet add it. 772 std::string includedFile; 773 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 774 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 775 if (!includeBuffer) 776 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 777 778 // Setup the source manager for parsing the tablegen file. 
779 llvm::SourceMgr tdSrcMgr; 780 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); 781 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); 782 783 // This class provides a context argument for the llvm::SourceMgr diagnostic 784 // handler. 785 struct DiagHandlerContext { 786 Parser &parser; 787 StringRef filename; 788 llvm::SMRange loc; 789 } handlerContext{*this, filename, fileLoc}; 790 791 // Set the diagnostic handler for the tablegen source manager. 792 tdSrcMgr.setDiagHandler( 793 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 794 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 795 (void)ctx->parser.emitError( 796 ctx->loc, 797 llvm::formatv("error while processing include file `{0}`: {1}", 798 ctx->filename, diag.getMessage())); 799 }, 800 &handlerContext); 801 802 // Parse the tablegen file. 803 llvm::RecordKeeper tdRecords; 804 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords)) 805 return failure(); 806 807 // Process the parsed records. 808 processTdIncludeRecords(tdRecords, decls); 809 810 // After we are done processing, move all of the tablegen source buffers to 811 // the main parser source mgr. This allows for directly using source locations 812 // from the .td files without needing to remap them. 813 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End); 814 return success(); 815 } 816 817 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 818 SmallVectorImpl<ast::Decl *> &decls) { 819 // Return the length kind of the given value. 820 auto getLengthKind = [](const auto &value) { 821 if (value.isOptional()) 822 return ods::VariableLengthKind::Optional; 823 return value.isVariadic() ? ods::VariableLengthKind::Variadic 824 : ods::VariableLengthKind::Single; 825 }; 826 827 // Insert a type constraint into the ODS context. 
828 ods::Context &odsContext = ctx.getODSContext(); 829 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 830 -> const ods::TypeConstraint & { 831 return odsContext.insertTypeConstraint( 832 cst.constraint.getUniqueDefName(), 833 processDoc(cst.constraint.getSummary()), 834 cst.constraint.getCPPClassName()); 835 }; 836 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 837 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 838 }; 839 840 // Process the parsed tablegen records to build ODS information. 841 /// Operations. 842 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 843 tblgen::Operator op(def); 844 845 // Check to see if this operation is known to support type inferrence. 846 bool supportsResultTypeInferrence = 847 op.getTrait("::mlir::InferTypeOpInterface::Trait"); 848 849 bool inserted = false; 850 ods::Operation *odsOp = nullptr; 851 std::tie(odsOp, inserted) = odsContext.insertOperation( 852 op.getOperationName(), processDoc(op.getSummary()), 853 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(), 854 supportsResultTypeInferrence, op.getLoc().front()); 855 856 // Ignore operations that have already been added. 
857 if (!inserted) 858 continue; 859 860 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 861 odsOp->appendAttribute(attr.name, attr.attr.isOptional(), 862 odsContext.insertAttributeConstraint( 863 attr.attr.getUniqueDefName(), 864 processDoc(attr.attr.getSummary()), 865 attr.attr.getStorageType())); 866 } 867 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 868 odsOp->appendOperand(operand.name, getLengthKind(operand), 869 addTypeConstraint(operand)); 870 } 871 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 872 odsOp->appendResult(result.name, getLengthKind(result), 873 addTypeConstraint(result)); 874 } 875 } 876 877 auto shouldBeSkipped = [this](llvm::Record *def) { 878 return def->isAnonymous() || curDeclScope->lookup(def->getName()) || 879 def->isSubClassOf("DeclareInterfaceMethods"); 880 }; 881 882 /// Attr constraints. 883 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 884 if (shouldBeSkipped(def)) 885 continue; 886 887 tblgen::Attribute constraint(def); 888 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 889 constraint, convertLocToRange(def->getLoc().front()), attrTy, 890 constraint.getStorageType())); 891 } 892 /// Type constraints. 893 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 894 if (shouldBeSkipped(def)) 895 continue; 896 897 tblgen::TypeConstraint constraint(def); 898 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 899 constraint, convertLocToRange(def->getLoc().front()), typeTy, 900 constraint.getCPPClassName())); 901 } 902 /// OpInterfaces. 
903 ast::Type opTy = ast::OperationType::get(ctx); 904 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) { 905 if (shouldBeSkipped(def)) 906 continue; 907 908 SMRange loc = convertLocToRange(def->getLoc().front()); 909 910 std::string cppClassName = 911 llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"), 912 def->getValueAsString("cppInterfaceName")) 913 .str(); 914 std::string codeBlock = 915 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));", 916 cppClassName) 917 .str(); 918 919 std::string desc = 920 processAndFormatDoc(def->getValueAsString("description")); 921 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 922 def->getName(), codeBlock, loc, opTy, cppClassName, desc)); 923 } 924 } 925 926 template <typename ConstraintT> 927 ast::Decl *Parser::createODSNativePDLLConstraintDecl( 928 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type, 929 StringRef nativeType, StringRef docString) { 930 // Build the single input parameter. 931 ast::DeclScope *argScope = pushDeclScope(); 932 auto *paramVar = ast::VariableDecl::create( 933 ctx, ast::Name::create(ctx, "self", loc), type, 934 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 935 argScope->add(paramVar); 936 popDeclScope(); 937 938 // Build the native constraint. 939 auto *constraintDecl = ast::UserConstraintDecl::createNative( 940 ctx, ast::Name::create(ctx, name, loc), paramVar, 941 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType); 942 constraintDecl->setDocComment(ctx, docString); 943 curDeclScope->add(constraintDecl); 944 return constraintDecl; 945 } 946 947 template <typename ConstraintT> 948 ast::Decl * 949 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 950 SMRange loc, ast::Type type, 951 StringRef nativeType) { 952 // Format the condition template. 
953 tblgen::FmtContext fmtContext; 954 fmtContext.withSelf("self"); 955 std::string codeBlock = tblgen::tgfmt( 956 "return ::mlir::success(" + constraint.getConditionTemplate() + ");", 957 &fmtContext); 958 959 // If documentation was enabled, build the doc string for the generated 960 // constraint. It would be nice to do this lazily, but TableGen information is 961 // destroyed after we finish parsing the file. 962 std::string docString; 963 if (enableDocumentation) { 964 StringRef desc = constraint.getDescription(); 965 docString = processAndFormatDoc( 966 constraint.getSummary() + 967 (desc.empty() ? "" : ("\n\n" + constraint.getDescription()))); 968 } 969 970 return createODSNativePDLLConstraintDecl<ConstraintT>( 971 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType, 972 docString); 973 } 974 975 //===----------------------------------------------------------------------===// 976 // Decls 977 978 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 979 FailureOr<ast::Decl *> decl; 980 switch (curToken.getKind()) { 981 case Token::kw_Constraint: 982 decl = parseUserConstraintDecl(); 983 break; 984 case Token::kw_Pattern: 985 decl = parsePatternDecl(); 986 break; 987 case Token::kw_Rewrite: 988 decl = parseUserRewriteDecl(); 989 break; 990 default: 991 return emitError("expected top-level declaration, such as a `Pattern`"); 992 } 993 if (failed(decl)) 994 return failure(); 995 996 // If the decl has a name, add it to the current scope. 997 if (const ast::Name *name = (*decl)->getName()) { 998 if (failed(checkDefineNamedDecl(*name))) 999 return failure(); 1000 curDeclScope->add(*decl); 1001 } 1002 return decl; 1003 } 1004 1005 FailureOr<ast::NamedAttributeDecl *> 1006 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) { 1007 // Check for name code completion. 
1008 if (curToken.is(Token::code_complete)) 1009 return codeCompleteAttributeName(parentOpName); 1010 1011 std::string attrNameStr; 1012 if (curToken.isString()) 1013 attrNameStr = curToken.getStringValue(); 1014 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 1015 attrNameStr = curToken.getSpelling().str(); 1016 else 1017 return emitError("expected identifier or string attribute name"); 1018 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 1019 consumeToken(); 1020 1021 // Check for a value of the attribute. 1022 ast::Expr *attrValue = nullptr; 1023 if (consumeIf(Token::equal)) { 1024 FailureOr<ast::Expr *> attrExpr = parseExpr(); 1025 if (failed(attrExpr)) 1026 return failure(); 1027 attrValue = *attrExpr; 1028 } else { 1029 // If there isn't a concrete value, create an expression representing a 1030 // UnitAttr. 1031 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 1032 } 1033 1034 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 1035 } 1036 1037 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 1038 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 1039 bool expectTerminalSemicolon) { 1040 consumeToken(Token::equal_arrow); 1041 1042 // Parse the single statement of the lambda body. 1043 SMLoc bodyStartLoc = curToken.getStartLoc(); 1044 pushDeclScope(); 1045 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 1046 bool failedToParse = 1047 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 1048 popDeclScope(); 1049 if (failedToParse) 1050 return failure(); 1051 1052 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 1053 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 1054 } 1055 1056 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 1057 // Ensure that the argument is named. 
1058 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 1059 return emitError("expected identifier argument name"); 1060 1061 // Parse the argument similarly to a normal variable. 1062 StringRef name = curToken.getSpelling(); 1063 SMRange nameLoc = curToken.getLoc(); 1064 consumeToken(); 1065 1066 if (failed( 1067 parseToken(Token::colon, "expected `:` before argument constraint"))) 1068 return failure(); 1069 1070 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1071 if (failed(cst)) 1072 return failure(); 1073 1074 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1075 } 1076 1077 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 1078 // Check to see if this result is named. 1079 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1080 // Check to see if this name actually refers to a Constraint. 1081 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 1082 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 1083 // If yes, and this is a Rewrite, give a nice error message as non-Core 1084 // constraints are not supported on Rewrite results. 1085 if (parserContext == ParserContext::Rewrite) { 1086 return emitError( 1087 "`Rewrite` results are only permitted to use core constraints, " 1088 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 1089 } 1090 1091 // Otherwise, parse this as an unnamed result variable. 1092 } else { 1093 // If it wasn't a constraint, parse the result similarly to a variable. If 1094 // there is already an existing decl, we will emit an error when defining 1095 // this variable later. 
1096 StringRef name = curToken.getSpelling(); 1097 SMRange nameLoc = curToken.getLoc(); 1098 consumeToken(); 1099 1100 if (failed(parseToken(Token::colon, 1101 "expected `:` before result constraint"))) 1102 return failure(); 1103 1104 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1105 if (failed(cst)) 1106 return failure(); 1107 1108 return createArgOrResultVariableDecl(name, nameLoc, *cst); 1109 } 1110 } 1111 1112 // If it isn't named, we parse the constraint directly and create an unnamed 1113 // result variable. 1114 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 1115 if (failed(cst)) 1116 return failure(); 1117 1118 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1119 } 1120 1121 FailureOr<ast::UserConstraintDecl *> 1122 Parser::parseUserConstraintDecl(bool isInline) { 1123 // Constraints and rewrites have very similar formats, dispatch to a shared 1124 // interface for parsing. 1125 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1126 [&](auto &&...args) { 1127 return this->parseUserPDLLConstraintDecl(args...); 1128 }, 1129 ParserContext::Constraint, "constraint", isInline); 1130 } 1131 1132 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1133 FailureOr<ast::UserConstraintDecl *> decl = 1134 parseUserConstraintDecl(/*isInline=*/true); 1135 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1136 return failure(); 1137 1138 curDeclScope->add(*decl); 1139 return decl; 1140 } 1141 1142 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1143 const ast::Name &name, bool isInline, 1144 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1145 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1146 // Push the argument scope back onto the list, so that the body can 1147 // reference arguments. 1148 pushDeclScope(argumentScope); 1149 1150 // Parse the body of the constraint. 
The body is either defined as a compound 1151 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1152 ast::CompoundStmt *body; 1153 if (curToken.is(Token::equal_arrow)) { 1154 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1155 [&](ast::Stmt *&stmt) -> LogicalResult { 1156 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1157 if (!stmtExpr) { 1158 return emitError(stmt->getLoc(), 1159 "expected `Constraint` lambda body to contain a " 1160 "single expression"); 1161 } 1162 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1163 return success(); 1164 }, 1165 /*expectTerminalSemicolon=*/!isInline); 1166 if (failed(bodyResult)) 1167 return failure(); 1168 body = *bodyResult; 1169 } else { 1170 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1171 if (failed(bodyResult)) 1172 return failure(); 1173 body = *bodyResult; 1174 1175 // Verify the structure of the body. 1176 auto bodyIt = body->begin(), bodyE = body->end(); 1177 for (; bodyIt != bodyE; ++bodyIt) 1178 if (isa<ast::ReturnStmt>(*bodyIt)) 1179 break; 1180 if (failed(validateUserConstraintOrRewriteReturn( 1181 "Constraint", body, bodyIt, bodyE, results, resultType))) 1182 return failure(); 1183 } 1184 popDeclScope(); 1185 1186 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1187 name, arguments, results, resultType, body); 1188 } 1189 1190 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1191 // Constraints and rewrites have very similar formats, dispatch to a shared 1192 // interface for parsing. 
1193 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1194 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1195 ParserContext::Rewrite, "rewrite", isInline); 1196 } 1197 1198 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1199 FailureOr<ast::UserRewriteDecl *> decl = 1200 parseUserRewriteDecl(/*isInline=*/true); 1201 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1202 return failure(); 1203 1204 curDeclScope->add(*decl); 1205 return decl; 1206 } 1207 1208 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1209 const ast::Name &name, bool isInline, 1210 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1211 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1212 // Push the argument scope back onto the list, so that the body can 1213 // reference arguments. 1214 curDeclScope = argumentScope; 1215 ast::CompoundStmt *body; 1216 if (curToken.is(Token::equal_arrow)) { 1217 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1218 [&](ast::Stmt *&statement) -> LogicalResult { 1219 if (isa<ast::OpRewriteStmt>(statement)) 1220 return success(); 1221 1222 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1223 if (!statementExpr) { 1224 return emitError( 1225 statement->getLoc(), 1226 "expected `Rewrite` lambda body to contain a single expression " 1227 "or an operation rewrite statement; such as `erase`, " 1228 "`replace`, or `rewrite`"); 1229 } 1230 statement = 1231 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1232 return success(); 1233 }, 1234 /*expectTerminalSemicolon=*/!isInline); 1235 if (failed(bodyResult)) 1236 return failure(); 1237 body = *bodyResult; 1238 } else { 1239 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1240 if (failed(bodyResult)) 1241 return failure(); 1242 body = *bodyResult; 1243 } 1244 popDeclScope(); 1245 1246 // Verify the structure of the body. 
1247 auto bodyIt = body->begin(), bodyE = body->end(); 1248 for (; bodyIt != bodyE; ++bodyIt) 1249 if (isa<ast::ReturnStmt>(*bodyIt)) 1250 break; 1251 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1252 bodyE, results, resultType))) 1253 return failure(); 1254 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1255 name, arguments, results, resultType, body); 1256 } 1257 1258 template <typename T, typename ParseUserPDLLDeclFnT> 1259 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1260 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1261 StringRef anonymousNamePrefix, bool isInline) { 1262 SMRange loc = curToken.getLoc(); 1263 consumeToken(); 1264 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1265 1266 // Parse the name of the decl. 1267 const ast::Name *name = nullptr; 1268 if (curToken.isNot(Token::identifier)) { 1269 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1270 // in C++, so being unnamed is fine. 1271 if (!isInline) 1272 return emitError("expected identifier name"); 1273 1274 // Create a unique anonymous name to use, as the name for this decl is not 1275 // important. 1276 std::string anonName = 1277 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1278 anonymousDeclNameCounter++) 1279 .str(); 1280 name = &ast::Name::create(ctx, anonName, loc); 1281 } else { 1282 // If a name was provided, we can use it directly. 1283 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1284 consumeToken(Token::identifier); 1285 } 1286 1287 // Parse the functional signature of the decl. 1288 SmallVector<ast::VariableDecl *> arguments, results; 1289 ast::DeclScope *argumentScope; 1290 ast::Type resultType; 1291 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1292 argumentScope, resultType))) 1293 return failure(); 1294 1295 // Check to see which type of constraint this is. 
If the constraint contains a 1296 // compound body, this is a PDLL decl. 1297 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1298 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1299 resultType); 1300 1301 // Otherwise, this is a native decl. 1302 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1303 results, resultType); 1304 } 1305 1306 template <typename T> 1307 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1308 const ast::Name &name, bool isInline, 1309 ArrayRef<ast::VariableDecl *> arguments, 1310 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1311 // If followed by a string, the native code body has also been specified. 1312 std::string codeStrStorage; 1313 Optional<StringRef> optCodeStr; 1314 if (curToken.isString()) { 1315 codeStrStorage = curToken.getStringValue(); 1316 optCodeStr = codeStrStorage; 1317 consumeToken(); 1318 } else if (isInline) { 1319 return emitError(name.getLoc(), 1320 "external declarations must be declared in global scope"); 1321 } else if (curToken.is(Token::error)) { 1322 return failure(); 1323 } 1324 if (failed(parseToken(Token::semicolon, 1325 "expected `;` after native declaration"))) 1326 return failure(); 1327 // TODO: PDL should be able to support constraint results in certain 1328 // situations, we should revise this. 1329 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1330 return emitError( 1331 "native Constraints currently do not support returning results"); 1332 } 1333 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1334 } 1335 1336 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1337 SmallVectorImpl<ast::VariableDecl *> &arguments, 1338 SmallVectorImpl<ast::VariableDecl *> &results, 1339 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1340 // Parse the argument list of the decl. 
1341 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1342 return failure(); 1343 1344 argumentScope = pushDeclScope(); 1345 if (curToken.isNot(Token::r_paren)) { 1346 do { 1347 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1348 if (failed(argument)) 1349 return failure(); 1350 arguments.emplace_back(*argument); 1351 } while (consumeIf(Token::comma)); 1352 } 1353 popDeclScope(); 1354 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1355 return failure(); 1356 1357 // Parse the results of the decl. 1358 pushDeclScope(); 1359 if (consumeIf(Token::arrow)) { 1360 auto parseResultFn = [&]() -> LogicalResult { 1361 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1362 if (failed(result)) 1363 return failure(); 1364 results.emplace_back(*result); 1365 return success(); 1366 }; 1367 1368 // Check for a list of results. 1369 if (consumeIf(Token::l_paren)) { 1370 do { 1371 if (failed(parseResultFn())) 1372 return failure(); 1373 } while (consumeIf(Token::comma)); 1374 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1375 return failure(); 1376 1377 // Otherwise, there is only one result. 1378 } else if (failed(parseResultFn())) { 1379 return failure(); 1380 } 1381 } 1382 popDeclScope(); 1383 1384 // Compute the result type of the decl. 1385 resultType = createUserConstraintRewriteResultType(results); 1386 1387 // Verify that results are only named if there are more than one. 
1388 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1389 return emitError( 1390 results.front()->getLoc(), 1391 "cannot create a single-element tuple with an element label"); 1392 } 1393 return success(); 1394 } 1395 1396 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1397 StringRef declType, ast::CompoundStmt *body, 1398 ArrayRef<ast::Stmt *>::iterator bodyIt, 1399 ArrayRef<ast::Stmt *>::iterator bodyE, 1400 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1401 // Handle if a `return` was provided. 1402 if (bodyIt != bodyE) { 1403 // Emit an error if we have trailing statements after the return. 1404 if (std::next(bodyIt) != bodyE) { 1405 return emitError( 1406 (*std::next(bodyIt))->getLoc(), 1407 llvm::formatv("`return` terminated the `{0}` body, but found " 1408 "trailing statements afterwards", 1409 declType)); 1410 } 1411 1412 // Otherwise if a return wasn't provided, check that no results are 1413 // expected. 1414 } else if (!results.empty()) { 1415 return emitError( 1416 {body->getLoc().End, body->getLoc().End}, 1417 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1418 declType, resultType)); 1419 } 1420 return success(); 1421 } 1422 1423 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1424 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1425 if (isa<ast::OpRewriteStmt>(statement)) 1426 return success(); 1427 return emitError( 1428 statement->getLoc(), 1429 "expected Pattern lambda body to contain a single operation " 1430 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1431 }); 1432 } 1433 1434 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1435 SMRange loc = curToken.getLoc(); 1436 consumeToken(Token::kw_Pattern); 1437 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1438 ParserContext::PatternMatch); 1439 1440 // Check for an optional identifier for the pattern name. 
1441 const ast::Name *name = nullptr; 1442 if (curToken.is(Token::identifier)) { 1443 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1444 consumeToken(Token::identifier); 1445 } 1446 1447 // Parse any pattern metadata. 1448 ParsedPatternMetadata metadata; 1449 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1450 return failure(); 1451 1452 // Parse the pattern body. 1453 ast::CompoundStmt *body; 1454 1455 // Handle a lambda body. 1456 if (curToken.is(Token::equal_arrow)) { 1457 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1458 if (failed(bodyResult)) 1459 return failure(); 1460 body = *bodyResult; 1461 } else { 1462 if (curToken.isNot(Token::l_brace)) 1463 return emitError("expected `{` or `=>` to start pattern body"); 1464 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1465 if (failed(bodyResult)) 1466 return failure(); 1467 body = *bodyResult; 1468 1469 // Verify the body of the pattern. 1470 auto bodyIt = body->begin(), bodyE = body->end(); 1471 for (; bodyIt != bodyE; ++bodyIt) { 1472 if (isa<ast::ReturnStmt>(*bodyIt)) { 1473 return emitError((*bodyIt)->getLoc(), 1474 "`return` statements are only permitted within a " 1475 "`Constraint` or `Rewrite` body"); 1476 } 1477 // Break when we've found the rewrite statement. 
1478 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1479 break; 1480 } 1481 if (bodyIt == bodyE) { 1482 return emitError(loc, 1483 "expected Pattern body to terminate with an operation " 1484 "rewrite statement, such as `erase`"); 1485 } 1486 if (std::next(bodyIt) != bodyE) { 1487 return emitError((*std::next(bodyIt))->getLoc(), 1488 "Pattern body was terminated by an operation " 1489 "rewrite statement, but found trailing statements"); 1490 } 1491 } 1492 1493 return createPatternDecl(loc, name, metadata, body); 1494 } 1495 1496 LogicalResult 1497 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1498 Optional<SMRange> benefitLoc; 1499 Optional<SMRange> hasBoundedRecursionLoc; 1500 1501 do { 1502 // Handle metadata code completion. 1503 if (curToken.is(Token::code_complete)) 1504 return codeCompletePatternMetadata(); 1505 1506 if (curToken.isNot(Token::identifier)) 1507 return emitError("expected pattern metadata identifier"); 1508 StringRef metadataStr = curToken.getSpelling(); 1509 SMRange metadataLoc = curToken.getLoc(); 1510 consumeToken(Token::identifier); 1511 1512 // Parse the benefit metadata: benefit(<integer-value>) 1513 if (metadataStr == "benefit") { 1514 if (benefitLoc) { 1515 return emitErrorAndNote(metadataLoc, 1516 "pattern benefit has already been specified", 1517 *benefitLoc, "see previous definition here"); 1518 } 1519 if (failed(parseToken(Token::l_paren, 1520 "expected `(` before pattern benefit"))) 1521 return failure(); 1522 1523 uint16_t benefitValue = 0; 1524 if (curToken.isNot(Token::integer)) 1525 return emitError("expected integral pattern benefit"); 1526 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1527 return emitError( 1528 "expected pattern benefit to fit within a 16-bit integer"); 1529 consumeToken(Token::integer); 1530 1531 metadata.benefit = benefitValue; 1532 benefitLoc = metadataLoc; 1533 1534 if (failed( 1535 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1536 return 
failure(); 1537 continue; 1538 } 1539 1540 // Parse the bounded recursion metadata: recursion 1541 if (metadataStr == "recursion") { 1542 if (hasBoundedRecursionLoc) { 1543 return emitErrorAndNote( 1544 metadataLoc, 1545 "pattern recursion metadata has already been specified", 1546 *hasBoundedRecursionLoc, "see previous definition here"); 1547 } 1548 metadata.hasBoundedRecursion = true; 1549 hasBoundedRecursionLoc = metadataLoc; 1550 continue; 1551 } 1552 1553 return emitError(metadataLoc, "unknown pattern metadata"); 1554 } while (consumeIf(Token::comma)); 1555 1556 return success(); 1557 } 1558 1559 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1560 consumeToken(Token::less); 1561 1562 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1563 if (failed(typeExpr) || 1564 failed(parseToken(Token::greater, 1565 "expected `>` after variable type constraint"))) 1566 return failure(); 1567 return typeExpr; 1568 } 1569 1570 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1571 assert(curDeclScope && "defining decl outside of a decl scope"); 1572 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1573 return emitErrorAndNote( 1574 name.getLoc(), "`" + name.getName() + "` has already been defined", 1575 lastDecl->getName()->getLoc(), "see previous definition here"); 1576 } 1577 return success(); 1578 } 1579 1580 FailureOr<ast::VariableDecl *> 1581 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1582 ast::Expr *initExpr, 1583 ArrayRef<ast::ConstraintRef> constraints) { 1584 assert(curDeclScope && "defining variable outside of decl scope"); 1585 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1586 1587 // If the name of the variable indicates a special variable, we don't add it 1588 // to the scope. This variable is local to the definition point. 
1589 if (name.empty() || name == "_") { 1590 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1591 constraints); 1592 } 1593 if (failed(checkDefineNamedDecl(nameDecl))) 1594 return failure(); 1595 1596 auto *varDecl = 1597 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1598 curDeclScope->add(varDecl); 1599 return varDecl; 1600 } 1601 1602 FailureOr<ast::VariableDecl *> 1603 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1604 ArrayRef<ast::ConstraintRef> constraints) { 1605 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1606 constraints); 1607 } 1608 1609 LogicalResult Parser::parseVariableDeclConstraintList( 1610 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1611 Optional<SMRange> typeConstraint; 1612 auto parseSingleConstraint = [&] { 1613 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1614 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, 1615 /*allowNonCoreConstraints=*/true); 1616 if (failed(constraint)) 1617 return failure(); 1618 constraints.push_back(*constraint); 1619 return success(); 1620 }; 1621 1622 // Check to see if this is a single constraint, or a list. 
1623 if (!consumeIf(Token::l_square)) 1624 return parseSingleConstraint(); 1625 1626 do { 1627 if (failed(parseSingleConstraint())) 1628 return failure(); 1629 } while (consumeIf(Token::comma)); 1630 return parseToken(Token::r_square, "expected `]` after constraint list"); 1631 } 1632 1633 FailureOr<ast::ConstraintRef> 1634 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1635 ArrayRef<ast::ConstraintRef> existingConstraints, 1636 bool allowInlineTypeConstraints, 1637 bool allowNonCoreConstraints) { 1638 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1639 if (!allowInlineTypeConstraints) { 1640 return emitError( 1641 curToken.getLoc(), 1642 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1643 "permitted on arguments or results"); 1644 } 1645 if (typeConstraint) 1646 return emitErrorAndNote( 1647 curToken.getLoc(), 1648 "the type of this variable has already been constrained", 1649 *typeConstraint, "see previous constraint location here"); 1650 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1651 if (failed(constraintExpr)) 1652 return failure(); 1653 typeExpr = *constraintExpr; 1654 typeConstraint = typeExpr->getLoc(); 1655 return success(); 1656 }; 1657 1658 SMRange loc = curToken.getLoc(); 1659 switch (curToken.getKind()) { 1660 case Token::kw_Attr: { 1661 consumeToken(Token::kw_Attr); 1662 1663 // Check for a type constraint. 1664 ast::Expr *typeExpr = nullptr; 1665 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1666 return failure(); 1667 return ast::ConstraintRef( 1668 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1669 } 1670 case Token::kw_Op: { 1671 consumeToken(Token::kw_Op); 1672 1673 // Parse an optional operation name. If the name isn't provided, this refers 1674 // to "any" operation. 
1675 FailureOr<ast::OpNameDecl *> opName = 1676 parseWrappedOperationName(/*allowEmptyName=*/true); 1677 if (failed(opName)) 1678 return failure(); 1679 1680 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1681 loc); 1682 } 1683 case Token::kw_Type: 1684 consumeToken(Token::kw_Type); 1685 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1686 case Token::kw_TypeRange: 1687 consumeToken(Token::kw_TypeRange); 1688 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1689 loc); 1690 case Token::kw_Value: { 1691 consumeToken(Token::kw_Value); 1692 1693 // Check for a type constraint. 1694 ast::Expr *typeExpr = nullptr; 1695 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1696 return failure(); 1697 1698 return ast::ConstraintRef( 1699 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1700 } 1701 case Token::kw_ValueRange: { 1702 consumeToken(Token::kw_ValueRange); 1703 1704 // Check for a type constraint. 1705 ast::Expr *typeExpr = nullptr; 1706 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1707 return failure(); 1708 1709 return ast::ConstraintRef( 1710 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1711 } 1712 1713 case Token::kw_Constraint: { 1714 // Handle an inline constraint. 1715 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1716 if (failed(decl)) 1717 return failure(); 1718 return ast::ConstraintRef(*decl, loc); 1719 } 1720 case Token::identifier: { 1721 StringRef constraintName = curToken.getSpelling(); 1722 consumeToken(Token::identifier); 1723 1724 // Lookup the referenced constraint. 1725 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1726 if (!cstDecl) { 1727 return emitError(loc, "unknown reference to constraint `" + 1728 constraintName + "`"); 1729 } 1730 1731 // Handle a reference to a proper constraint. 
1732 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1733 return ast::ConstraintRef(cst, loc); 1734 1735 return emitErrorAndNote( 1736 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1737 "see the definition of `" + constraintName + "` here"); 1738 } 1739 // Handle single entity constraint code completion. 1740 case Token::code_complete: { 1741 // Try to infer the current type for use by code completion. 1742 ast::Type inferredType; 1743 if (failed(validateVariableConstraints(existingConstraints, inferredType, 1744 allowNonCoreConstraints))) 1745 return failure(); 1746 1747 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, 1748 allowInlineTypeConstraints); 1749 } 1750 default: 1751 break; 1752 } 1753 return emitError(loc, "expected identifier constraint"); 1754 } 1755 1756 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1757 // Constraint arguments may apply more complex constraints via the arguments. 1758 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 1759 1760 Optional<SMRange> typeConstraint; 1761 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1762 /*allowInlineTypeConstraints=*/false, 1763 allowNonCoreConstraints); 1764 } 1765 1766 //===----------------------------------------------------------------------===// 1767 // Exprs 1768 1769 FailureOr<ast::Expr *> Parser::parseExpr() { 1770 if (curToken.is(Token::underscore)) 1771 return parseUnderscoreExpr(); 1772 1773 // Parse the LHS expression. 
1774 FailureOr<ast::Expr *> lhsExpr; 1775 switch (curToken.getKind()) { 1776 case Token::kw_attr: 1777 lhsExpr = parseAttributeExpr(); 1778 break; 1779 case Token::kw_Constraint: 1780 lhsExpr = parseInlineConstraintLambdaExpr(); 1781 break; 1782 case Token::identifier: 1783 lhsExpr = parseIdentifierExpr(); 1784 break; 1785 case Token::kw_op: 1786 lhsExpr = parseOperationExpr(); 1787 break; 1788 case Token::kw_Rewrite: 1789 lhsExpr = parseInlineRewriteLambdaExpr(); 1790 break; 1791 case Token::kw_type: 1792 lhsExpr = parseTypeExpr(); 1793 break; 1794 case Token::l_paren: 1795 lhsExpr = parseTupleExpr(); 1796 break; 1797 default: 1798 return emitError("expected expression"); 1799 } 1800 if (failed(lhsExpr)) 1801 return failure(); 1802 1803 // Check for an operator expression. 1804 while (true) { 1805 switch (curToken.getKind()) { 1806 case Token::dot: 1807 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1808 break; 1809 case Token::l_paren: 1810 lhsExpr = parseCallExpr(*lhsExpr); 1811 break; 1812 default: 1813 return lhsExpr; 1814 } 1815 if (failed(lhsExpr)) 1816 return failure(); 1817 } 1818 } 1819 1820 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1821 SMRange loc = curToken.getLoc(); 1822 consumeToken(Token::kw_attr); 1823 1824 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1825 // identifier. 
1826 if (!consumeIf(Token::less)) { 1827 resetToken(loc); 1828 return parseIdentifierExpr(); 1829 } 1830 1831 if (!curToken.isString()) 1832 return emitError("expected string literal containing MLIR attribute"); 1833 std::string attrExpr = curToken.getStringValue(); 1834 consumeToken(); 1835 1836 loc.End = curToken.getEndLoc(); 1837 if (failed( 1838 parseToken(Token::greater, "expected `>` after attribute literal"))) 1839 return failure(); 1840 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1841 } 1842 1843 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1844 consumeToken(Token::l_paren); 1845 1846 // Parse the arguments of the call. 1847 SmallVector<ast::Expr *> arguments; 1848 if (curToken.isNot(Token::r_paren)) { 1849 do { 1850 // Handle code completion for the call arguments. 1851 if (curToken.is(Token::code_complete)) { 1852 codeCompleteCallSignature(parentExpr, arguments.size()); 1853 return failure(); 1854 } 1855 1856 FailureOr<ast::Expr *> argument = parseExpr(); 1857 if (failed(argument)) 1858 return failure(); 1859 arguments.push_back(*argument); 1860 } while (consumeIf(Token::comma)); 1861 } 1862 1863 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1864 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1865 return failure(); 1866 1867 return createCallExpr(loc, parentExpr, arguments); 1868 } 1869 1870 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1871 ast::Decl *decl = curDeclScope->lookup(name); 1872 if (!decl) 1873 return emitError(loc, "undefined reference to `" + name + "`"); 1874 1875 return createDeclRefExpr(loc, decl); 1876 } 1877 1878 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1879 StringRef name = curToken.getSpelling(); 1880 SMRange nameLoc = curToken.getLoc(); 1881 consumeToken(); 1882 1883 // Check to see if this is a decl ref expression that defines a variable 1884 // inline. 
1885 if (consumeIf(Token::colon)) { 1886 SmallVector<ast::ConstraintRef> constraints; 1887 if (failed(parseVariableDeclConstraintList(constraints))) 1888 return failure(); 1889 ast::Type type; 1890 if (failed(validateVariableConstraints(constraints, type))) 1891 return failure(); 1892 return createInlineVariableExpr(type, name, nameLoc, constraints); 1893 } 1894 1895 return parseDeclRefExpr(name, nameLoc); 1896 } 1897 1898 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1899 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1900 if (failed(decl)) 1901 return failure(); 1902 1903 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1904 ast::ConstraintType::get(ctx)); 1905 } 1906 1907 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1908 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1909 if (failed(decl)) 1910 return failure(); 1911 1912 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1913 ast::RewriteType::get(ctx)); 1914 } 1915 1916 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1917 SMRange dotLoc = curToken.getLoc(); 1918 consumeToken(Token::dot); 1919 1920 // Check for code completion of the member name. 1921 if (curToken.is(Token::code_complete)) 1922 return codeCompleteMemberAccess(parentExpr); 1923 1924 // Parse the member name. 
1925 Token memberNameTok = curToken; 1926 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1927 !memberNameTok.isKeyword()) 1928 return emitError(dotLoc, "expected identifier or numeric member name"); 1929 StringRef memberName = memberNameTok.getSpelling(); 1930 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); 1931 consumeToken(); 1932 1933 return createMemberAccessExpr(parentExpr, memberName, loc); 1934 } 1935 1936 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1937 SMRange loc = curToken.getLoc(); 1938 1939 // Check for code completion for the dialect name. 1940 if (curToken.is(Token::code_complete)) 1941 return codeCompleteDialectName(); 1942 1943 // Handle the case of an no operation name. 1944 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1945 if (allowEmptyName) 1946 return ast::OpNameDecl::create(ctx, SMRange()); 1947 return emitError("expected dialect namespace"); 1948 } 1949 StringRef name = curToken.getSpelling(); 1950 consumeToken(); 1951 1952 // Otherwise, this is a literal operation name. 1953 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1954 return failure(); 1955 1956 // Check for code completion for the operation name. 
1957 if (curToken.is(Token::code_complete)) 1958 return codeCompleteOperationName(name); 1959 1960 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1961 return emitError("expected operation name after dialect namespace"); 1962 1963 name = StringRef(name.data(), name.size() + 1); 1964 do { 1965 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1966 loc.End = curToken.getEndLoc(); 1967 consumeToken(); 1968 } while (curToken.isAny(Token::identifier, Token::dot) || 1969 curToken.isKeyword()); 1970 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1971 } 1972 1973 FailureOr<ast::OpNameDecl *> 1974 Parser::parseWrappedOperationName(bool allowEmptyName) { 1975 if (!consumeIf(Token::less)) 1976 return ast::OpNameDecl::create(ctx, SMRange()); 1977 1978 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1979 if (failed(opNameDecl)) 1980 return failure(); 1981 1982 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1983 return failure(); 1984 return opNameDecl; 1985 } 1986 1987 FailureOr<ast::Expr *> 1988 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { 1989 SMRange loc = curToken.getLoc(); 1990 consumeToken(Token::kw_op); 1991 1992 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1993 // identifier. 1994 if (curToken.isNot(Token::less)) { 1995 resetToken(loc); 1996 return parseIdentifierExpr(); 1997 } 1998 1999 // Parse the operation name. The name may be elided, in which case the 2000 // operation refers to "any" operation(i.e. a difference between `MyOp` and 2001 // `Operation*`). Operation names within a rewrite context must be named. 
2002 bool allowEmptyName = parserContext != ParserContext::Rewrite; 2003 FailureOr<ast::OpNameDecl *> opNameDecl = 2004 parseWrappedOperationName(allowEmptyName); 2005 if (failed(opNameDecl)) 2006 return failure(); 2007 Optional<StringRef> opName = (*opNameDecl)->getName(); 2008 2009 // Functor used to create an implicit range variable, used for implicit "all" 2010 // operand or results variables. 2011 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 2012 FailureOr<ast::VariableDecl *> rangeVar = 2013 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 2014 assert(succeeded(rangeVar) && "expected range variable to be valid"); 2015 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 2016 }; 2017 2018 // Check for the optional list of operands. 2019 SmallVector<ast::Expr *> operands; 2020 if (!consumeIf(Token::l_paren)) { 2021 // If the operand list isn't specified and we are in a match context, define 2022 // an inplace unconstrained operand range corresponding to all of the 2023 // operands of the operation. This avoids treating zero operands the same 2024 // way as "unconstrained operands". 2025 if (parserContext != ParserContext::Rewrite) { 2026 operands.push_back(createImplicitRangeVar( 2027 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 2028 } 2029 } else if (!consumeIf(Token::r_paren)) { 2030 // If the operand list was specified and non-empty, parse the operands. 2031 do { 2032 // Check for operand signature code completion. 
2033 if (curToken.is(Token::code_complete)) { 2034 codeCompleteOperationOperandsSignature(opName, operands.size()); 2035 return failure(); 2036 } 2037 2038 FailureOr<ast::Expr *> operand = parseExpr(); 2039 if (failed(operand)) 2040 return failure(); 2041 operands.push_back(*operand); 2042 } while (consumeIf(Token::comma)); 2043 2044 if (failed(parseToken(Token::r_paren, 2045 "expected `)` after operation operand list"))) 2046 return failure(); 2047 } 2048 2049 // Check for the optional list of attributes. 2050 SmallVector<ast::NamedAttributeDecl *> attributes; 2051 if (consumeIf(Token::l_brace)) { 2052 do { 2053 FailureOr<ast::NamedAttributeDecl *> decl = 2054 parseNamedAttributeDecl(opName); 2055 if (failed(decl)) 2056 return failure(); 2057 attributes.emplace_back(*decl); 2058 } while (consumeIf(Token::comma)); 2059 2060 if (failed(parseToken(Token::r_brace, 2061 "expected `}` after operation attribute list"))) 2062 return failure(); 2063 } 2064 2065 // Handle the result types of the operation. 2066 SmallVector<ast::Expr *> resultTypes; 2067 OpResultTypeContext resultTypeContext = inputResultTypeContext; 2068 2069 // Check for an explicit list of result types. 2070 if (consumeIf(Token::arrow)) { 2071 if (failed(parseToken(Token::l_paren, 2072 "expected `(` before operation result type list"))) 2073 return failure(); 2074 2075 // If result types are provided, initially assume that the operation does 2076 // not rely on type inferrence. We don't assert that it isn't, because we 2077 // may be inferring the value of some type/type range variables, but given 2078 // that these variables may be defined in calls we can't always discern when 2079 // this is the case. 2080 resultTypeContext = OpResultTypeContext::Explicit; 2081 2082 // Handle the case of an empty result list. 2083 if (!consumeIf(Token::r_paren)) { 2084 do { 2085 // Check for result signature code completion. 
2086 if (curToken.is(Token::code_complete)) { 2087 codeCompleteOperationResultsSignature(opName, resultTypes.size()); 2088 return failure(); 2089 } 2090 2091 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 2092 if (failed(resultTypeExpr)) 2093 return failure(); 2094 resultTypes.push_back(*resultTypeExpr); 2095 } while (consumeIf(Token::comma)); 2096 2097 if (failed(parseToken(Token::r_paren, 2098 "expected `)` after operation result type list"))) 2099 return failure(); 2100 } 2101 } else if (parserContext != ParserContext::Rewrite) { 2102 // If the result list isn't specified and we are in a match context, define 2103 // an inplace unconstrained result range corresponding to all of the results 2104 // of the operation. This avoids treating zero results the same way as 2105 // "unconstrained results". 2106 resultTypes.push_back(createImplicitRangeVar( 2107 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 2108 } else if (resultTypeContext == OpResultTypeContext::Explicit) { 2109 // If the result list isn't specified and we are in a rewrite, try to infer 2110 // them at runtime instead. 2111 resultTypeContext = OpResultTypeContext::Interface; 2112 } 2113 2114 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, 2115 attributes, resultTypes); 2116 } 2117 2118 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 2119 SMRange loc = curToken.getLoc(); 2120 consumeToken(Token::l_paren); 2121 2122 DenseMap<StringRef, SMRange> usedNames; 2123 SmallVector<StringRef> elementNames; 2124 SmallVector<ast::Expr *> elements; 2125 if (curToken.isNot(Token::r_paren)) { 2126 do { 2127 // Check for the optional element name assignment before the value. 2128 StringRef elementName; 2129 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 2130 Token elementNameTok = curToken; 2131 consumeToken(); 2132 2133 // The element name is only present if followed by an `=`. 
2134 if (consumeIf(Token::equal)) { 2135 elementName = elementNameTok.getSpelling(); 2136 2137 // Check to see if this name is already used. 2138 auto elementNameIt = 2139 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 2140 if (!elementNameIt.second) { 2141 return emitErrorAndNote( 2142 elementNameTok.getLoc(), 2143 llvm::formatv("duplicate tuple element label `{0}`", 2144 elementName), 2145 elementNameIt.first->getSecond(), 2146 "see previous label use here"); 2147 } 2148 } else { 2149 // Otherwise, we treat this as part of an expression so reset the 2150 // lexer. 2151 resetToken(elementNameTok.getLoc()); 2152 } 2153 } 2154 elementNames.push_back(elementName); 2155 2156 // Parse the tuple element value. 2157 FailureOr<ast::Expr *> element = parseExpr(); 2158 if (failed(element)) 2159 return failure(); 2160 elements.push_back(*element); 2161 } while (consumeIf(Token::comma)); 2162 } 2163 loc.End = curToken.getEndLoc(); 2164 if (failed( 2165 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 2166 return failure(); 2167 return createTupleExpr(loc, elements, elementNames); 2168 } 2169 2170 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 2171 SMRange loc = curToken.getLoc(); 2172 consumeToken(Token::kw_type); 2173 2174 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 2175 // identifier. 
2176 if (!consumeIf(Token::less)) { 2177 resetToken(loc); 2178 return parseIdentifierExpr(); 2179 } 2180 2181 if (!curToken.isString()) 2182 return emitError("expected string literal containing MLIR type"); 2183 std::string attrExpr = curToken.getStringValue(); 2184 consumeToken(); 2185 2186 loc.End = curToken.getEndLoc(); 2187 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 2188 return failure(); 2189 return ast::TypeExpr::create(ctx, loc, attrExpr); 2190 } 2191 2192 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2193 StringRef name = curToken.getSpelling(); 2194 SMRange nameLoc = curToken.getLoc(); 2195 consumeToken(Token::underscore); 2196 2197 // Underscore expressions require a constraint list. 2198 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2199 return failure(); 2200 2201 // Parse the constraints for the expression. 2202 SmallVector<ast::ConstraintRef> constraints; 2203 if (failed(parseVariableDeclConstraintList(constraints))) 2204 return failure(); 2205 2206 ast::Type type; 2207 if (failed(validateVariableConstraints(constraints, type))) 2208 return failure(); 2209 return createInlineVariableExpr(type, name, nameLoc, constraints); 2210 } 2211 2212 //===----------------------------------------------------------------------===// 2213 // Stmts 2214 2215 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2216 FailureOr<ast::Stmt *> stmt; 2217 switch (curToken.getKind()) { 2218 case Token::kw_erase: 2219 stmt = parseEraseStmt(); 2220 break; 2221 case Token::kw_let: 2222 stmt = parseLetStmt(); 2223 break; 2224 case Token::kw_replace: 2225 stmt = parseReplaceStmt(); 2226 break; 2227 case Token::kw_return: 2228 stmt = parseReturnStmt(); 2229 break; 2230 case Token::kw_rewrite: 2231 stmt = parseRewriteStmt(); 2232 break; 2233 default: 2234 stmt = parseExpr(); 2235 break; 2236 } 2237 if (failed(stmt) || 2238 (expectTerminalSemicolon && 2239 failed(parseToken(Token::semicolon, 
"expected `;` after statement")))) 2240 return failure(); 2241 return stmt; 2242 } 2243 2244 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2245 SMLoc startLoc = curToken.getStartLoc(); 2246 consumeToken(Token::l_brace); 2247 2248 // Push a new block scope and parse any nested statements. 2249 pushDeclScope(); 2250 SmallVector<ast::Stmt *> statements; 2251 while (curToken.isNot(Token::r_brace)) { 2252 FailureOr<ast::Stmt *> statement = parseStmt(); 2253 if (failed(statement)) 2254 return popDeclScope(), failure(); 2255 statements.push_back(*statement); 2256 } 2257 popDeclScope(); 2258 2259 // Consume the end brace. 2260 SMRange location(startLoc, curToken.getEndLoc()); 2261 consumeToken(Token::r_brace); 2262 2263 return ast::CompoundStmt::create(ctx, location, statements); 2264 } 2265 2266 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2267 if (parserContext == ParserContext::Constraint) 2268 return emitError("`erase` cannot be used within a Constraint"); 2269 SMRange loc = curToken.getLoc(); 2270 consumeToken(Token::kw_erase); 2271 2272 // Parse the root operation expression. 2273 FailureOr<ast::Expr *> rootOp = parseExpr(); 2274 if (failed(rootOp)) 2275 return failure(); 2276 2277 return createEraseStmt(loc, *rootOp); 2278 } 2279 2280 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2281 SMRange loc = curToken.getLoc(); 2282 consumeToken(Token::kw_let); 2283 2284 // Parse the name of the new variable. 2285 SMRange varLoc = curToken.getLoc(); 2286 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2287 // `_` is a reserved variable name. 2288 if (curToken.is(Token::underscore)) { 2289 return emitError(varLoc, 2290 "`_` may only be used to define \"inline\" variables"); 2291 } 2292 return emitError(varLoc, 2293 "expected identifier after `let` to name a new variable"); 2294 } 2295 StringRef varName = curToken.getSpelling(); 2296 consumeToken(); 2297 2298 // Parse the optional set of constraints. 
2299 SmallVector<ast::ConstraintRef> constraints; 2300 if (consumeIf(Token::colon) && 2301 failed(parseVariableDeclConstraintList(constraints))) 2302 return failure(); 2303 2304 // Parse the optional initializer expression. 2305 ast::Expr *initializer = nullptr; 2306 if (consumeIf(Token::equal)) { 2307 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2308 if (failed(initOrFailure)) 2309 return failure(); 2310 initializer = *initOrFailure; 2311 2312 // Check that the constraints are compatible with having an initializer, 2313 // e.g. type constraints cannot be used with initializers. 2314 for (ast::ConstraintRef constraint : constraints) { 2315 LogicalResult result = 2316 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2317 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2318 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2319 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2320 return this->emitError( 2321 constraint.referenceLoc, 2322 "type constraints are not permitted on variables with " 2323 "initializers"); 2324 } 2325 return success(); 2326 }) 2327 .Default(success()); 2328 if (failed(result)) 2329 return failure(); 2330 } 2331 } 2332 2333 FailureOr<ast::VariableDecl *> varDecl = 2334 createVariableDecl(varName, varLoc, initializer, constraints); 2335 if (failed(varDecl)) 2336 return failure(); 2337 return ast::LetStmt::create(ctx, loc, *varDecl); 2338 } 2339 2340 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2341 if (parserContext == ParserContext::Constraint) 2342 return emitError("`replace` cannot be used within a Constraint"); 2343 SMRange loc = curToken.getLoc(); 2344 consumeToken(Token::kw_replace); 2345 2346 // Parse the root operation expression. 
2347 FailureOr<ast::Expr *> rootOp = parseExpr(); 2348 if (failed(rootOp)) 2349 return failure(); 2350 2351 if (failed( 2352 parseToken(Token::kw_with, "expected `with` after root operation"))) 2353 return failure(); 2354 2355 // The replacement portion of this statement is within a rewrite context. 2356 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2357 ParserContext::Rewrite); 2358 2359 // Parse the replacement values. 2360 SmallVector<ast::Expr *> replValues; 2361 if (consumeIf(Token::l_paren)) { 2362 if (consumeIf(Token::r_paren)) { 2363 return emitError( 2364 loc, "expected at least one replacement value, consider using " 2365 "`erase` if no replacement values are desired"); 2366 } 2367 2368 do { 2369 FailureOr<ast::Expr *> replExpr = parseExpr(); 2370 if (failed(replExpr)) 2371 return failure(); 2372 replValues.emplace_back(*replExpr); 2373 } while (consumeIf(Token::comma)); 2374 2375 if (failed(parseToken(Token::r_paren, 2376 "expected `)` after replacement values"))) 2377 return failure(); 2378 } else { 2379 // Handle replacement with an operation uniquely, as the replacement 2380 // operation supports type inferrence from the root operation. 2381 FailureOr<ast::Expr *> replExpr; 2382 if (curToken.is(Token::kw_op)) 2383 replExpr = parseOperationExpr(OpResultTypeContext::Replacement); 2384 else 2385 replExpr = parseExpr(); 2386 if (failed(replExpr)) 2387 return failure(); 2388 replValues.emplace_back(*replExpr); 2389 } 2390 2391 return createReplaceStmt(loc, *rootOp, replValues); 2392 } 2393 2394 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2395 SMRange loc = curToken.getLoc(); 2396 consumeToken(Token::kw_return); 2397 2398 // Parse the result value. 
2399 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2400 if (failed(resultExpr)) 2401 return failure(); 2402 2403 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2404 } 2405 2406 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2407 if (parserContext == ParserContext::Constraint) 2408 return emitError("`rewrite` cannot be used within a Constraint"); 2409 SMRange loc = curToken.getLoc(); 2410 consumeToken(Token::kw_rewrite); 2411 2412 // Parse the root operation. 2413 FailureOr<ast::Expr *> rootOp = parseExpr(); 2414 if (failed(rootOp)) 2415 return failure(); 2416 2417 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2418 return failure(); 2419 2420 if (curToken.isNot(Token::l_brace)) 2421 return emitError("expected `{` to start rewrite body"); 2422 2423 // The rewrite body of this statement is within a rewrite context. 2424 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2425 ParserContext::Rewrite); 2426 2427 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2428 if (failed(rewriteBody)) 2429 return failure(); 2430 2431 // Verify the rewrite body. 2432 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2433 if (isa<ast::ReturnStmt>(stmt)) { 2434 return emitError(stmt->getLoc(), 2435 "`return` statements are only permitted within a " 2436 "`Constraint` or `Rewrite` body"); 2437 } 2438 } 2439 2440 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2441 } 2442 2443 //===----------------------------------------------------------------------===// 2444 // Creation+Analysis 2445 //===----------------------------------------------------------------------===// 2446 2447 //===----------------------------------------------------------------------===// 2448 // Decls 2449 2450 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2451 // Unwrap reference expressions. 
2452 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2453 node = init->getDecl(); 2454 return dyn_cast<ast::CallableDecl>(node); 2455 } 2456 2457 FailureOr<ast::PatternDecl *> 2458 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2459 const ParsedPatternMetadata &metadata, 2460 ast::CompoundStmt *body) { 2461 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2462 metadata.hasBoundedRecursion, body); 2463 } 2464 2465 ast::Type Parser::createUserConstraintRewriteResultType( 2466 ArrayRef<ast::VariableDecl *> results) { 2467 // Single result decls use the type of the single result. 2468 if (results.size() == 1) 2469 return results[0]->getType(); 2470 2471 // Multiple results use a tuple type, with the types and names grabbed from 2472 // the result variable decls. 2473 auto resultTypes = llvm::map_range( 2474 results, [&](const auto *result) { return result->getType(); }); 2475 auto resultNames = llvm::map_range( 2476 results, [&](const auto *result) { return result->getName().getName(); }); 2477 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2478 llvm::to_vector(resultNames)); 2479 } 2480 2481 template <typename T> 2482 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2483 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2484 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2485 ast::CompoundStmt *body) { 2486 if (!body->getChildren().empty()) { 2487 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2488 ast::Expr *resultExpr = retStmt->getResultExpr(); 2489 2490 // Process the result of the decl. If no explicit signature results 2491 // were provided, check for return type inference. Otherwise, check that 2492 // the return expression can be converted to the expected type. 
// Continuation of the definition that begins before this section: when no
// results were declared, the result type is taken from `resultExpr`;
// otherwise `resultExpr` must be convertible to `resultType` before it is
// stored back into `retStmt`.
      if (results.empty())
        resultType = resultExpr->getType();
      else if (failed(convertExpressionTo(resultExpr, resultType)))
        return failure();
      else
        retStmt->setResultExpr(resultExpr);
    }
  }
  return T::createPDLL(ctx, name, arguments, results, body, resultType);
}

/// Create a variable declaration named `name` at `loc`, optionally initialized
/// by `initializer` and constrained by `constraints`.
///
/// The variable's type is inferred from the constraint list and/or the
/// initializer. An error is emitted when:
///  * a constraint is invalid or incompatible with the others,
///  * the initializer cannot be converted to the constrained type,
///  * no initializer is given and the constraints do not determine a type,
///  * the resulting type is a `Constraint` or `Rewrite` type.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  // The type of the variable, which is expected to be inferred by either a
  // constraint or an initializer expression.
  ast::Type type;
  if (failed(validateVariableConstraints(constraints, type)))
    return failure();

  if (initializer) {
    // Update the variable type based on the initializer, or try to convert the
    // initializer to the existing type.
    if (!type)
      type = initializer->getType();
    else if (ast::Type mergedType = type.refineWith(initializer->getType()))
      type = mergedType;
    else if (failed(convertExpressionTo(initializer, type)))
      return failure();

    // Otherwise, if there is no initializer check that the type has already
    // been resolved from the constraint list.
  } else if (!type) {
    return emitErrorAndNote(
        loc, "unable to infer type for variable `" + name + "`", loc,
        "the type of a variable must be inferable from the constraint "
        "list or the initializer");
  }

  // Callable types (`Constraint` and `Rewrite`) cannot be used when defining
  // variables.
  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
    return emitError(
        loc, llvm::formatv("unable to define variable of `{0}` type", type));
  }

  // Try to define a variable with the given name.
  FailureOr<ast::VariableDecl *> varDecl =
      defineVariableDecl(name, loc, type, initializer, constraints);
  if (failed(varDecl))
    return failure();

  return *varDecl;
}

/// Create the variable declaration for an argument or result (presumably of a
/// user `Constraint` or `Rewrite` -- confirm at the call sites).
///
/// The variable's type is inferred solely from `constraint`. Non-core (user
/// defined) constraints are only accepted while the parser is inside a
/// `Constraint` context.
FailureOr<ast::VariableDecl *>
Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                      const ast::ConstraintRef &constraint) {
  // Constraint arguments may apply more complex constraints via the arguments,
  // so user-defined constraints are only permitted in a `Constraint` context.
  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
  ast::Type argType;
  if (failed(validateVariableConstraint(constraint, argType,
                                        allowNonCoreConstraints)))
    return failure();
  return defineVariableDecl(name, loc, argType, constraint);
}

/// Validate every constraint in `constraints`, in order. Each one refines
/// `inferredType` with the type it implies. Returns failure at the first
/// constraint that is invalid or incompatible with the type inferred so far.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType,
                                    bool allowNonCoreConstraints) {
  for (const ast::ConstraintRef &ref : constraints)
    if (failed(validateVariableConstraint(ref, inferredType,
                                          allowNonCoreConstraints)))
      return failure();
  return success();
}

/// Compute the type implied by the constraint in `ref` and merge it into
/// `inferredType`.
///
/// The implied types are:
///   Attr       -> Attribute
///   Op         -> Operation, carrying ODS info when the name is known
///   Type       -> Type
///   TypeRange  -> TypeRange
///   Value      -> Value
///   ValueRange -> ValueRange
///   user `Constraint` -> the type of its single input
///
/// Any type expression attached to an Attr/Value/ValueRange constraint is
/// checked as well. User constraints are rejected unless
/// `allowNonCoreConstraints` is set. Fails if the implied type cannot be
/// refined with an already inferred type.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType,
                                                 bool allowNonCoreConstraints) {
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    // `Attr<type>`: the optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    constraintType = ast::OperationType::get(
        ctx, cst->getName(), lookupODSOperation(cst->getName()));
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    // `Value<type>`: the optional type expression must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // `ValueRange<types>`: the optional type expression must be a `TypeRange`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    if (!allowNonCoreConstraints) {
      return emitError(ref.referenceLoc,
                       "`Rewrite` arguments and results are only permitted to "
                       "use core constraints, such as `Attr`, `Op`, `Type`, "
                       "`TypeRange`, `Value`, `ValueRange`");
    }

    // A user constraint used as a variable constraint is applied to that one
    // variable, so it must take exactly one input; its type is the type it
    // implies.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}

/// Check that `typeExpr`, used as the type argument of a constraint, has type
/// `Type`.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  ast::Type typeExprType = typeExpr->getType();
  if (typeExprType != typeTy) {
    return emitError(typeExpr->getLoc(),
                     "expected expression of `Type` in type constraint");
  }
  return success();
}

/// Check that `typeExpr`, used as the type argument of a range constraint, has
/// type `TypeRange`.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  ast::Type typeExprType = typeExpr->getType();
  if (typeExprType != typeRangeTy) {
    return emitError(typeExpr->getLoc(),
                     "expected expression of `TypeRange` in type constraint");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Exprs

/// Build a call of `parentExpr` with `arguments`.
///
/// The callee must resolve to a callable declaration. A `Constraint` cannot be
/// invoked inside a rewrite context, and a `Rewrite` cannot be invoked outside
/// of one. The argument count must match the callable's inputs, and each
/// argument is converted to its parameter's type. The resulting expression
/// has the callable's result type.
FailureOr<ast::CallExpr *>
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
                       MutableArrayRef<ast::Expr *> arguments) {
  ast::Type parentType = parentExpr->getType();

  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
  if (!callableDecl) {
    return emitError(loc,
                     llvm::formatv("expected a reference to a callable "
                                   "`Constraint` or `Rewrite`, but got: `{0}`",
                                   parentType));
  }
  // Constraints are match-only and rewrites are rewrite-only.
  if (parserContext == ParserContext::Rewrite) {
    if (isa<ast::UserConstraintDecl>(callableDecl))
      return emitError(
          loc, "unable to invoke `Constraint` within a rewrite section");
  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
    return emitError(loc, "unable to invoke `Rewrite` within a match section");
  }

  // Verify the arguments of the call.
  // Handle size mismatch.
  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
  if (callArgs.size() != arguments.size()) {
    return emitErrorAndNote(
        loc,
        llvm::formatv("invalid number of arguments for {0} call; expected "
                      "{1}, but got {2}",
                      callableDecl->getCallableType(), callArgs.size(),
                      arguments.size()),
        callableDecl->getLoc(),
        llvm::formatv("see the definition of {0} here",
                      callableDecl->getName()->getName()));
  }

  // Handle argument type mismatch: any conversion diagnostic is augmented with
  // a note pointing at the callable's definition.
  auto attachDiagFn = [&](ast::Diagnostic &diag) {
    diag.attachNote(llvm::formatv("see the definition of `{0}` here",
                                  callableDecl->getName()->getName()),
                    callableDecl->getLoc());
  };
  for (auto it : llvm::zip(callArgs, arguments)) {
    if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
                                   attachDiagFn)))
      return failure();
  }

  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
                               callableDecl->getResultType());
}

/// Build a reference expression to `decl`.
///
/// Constraint declarations are typed as `ConstraintType`, user rewrites as
/// `RewriteType`, and variables as their declared type. Referencing any other
/// kind of declaration is an error.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
                                                        ast::Decl *decl) {
  // Check the type of decl being referenced.
  ast::Type declType;
  if (isa<ast::ConstraintDecl>(decl))
    declType = ast::ConstraintType::get(ctx);
  else if (isa<ast::UserRewriteDecl>(decl))
    declType = ast::RewriteType::get(ctx);
  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
    declType = varDecl->getType();
  else
    return emitError(loc, "invalid reference to `" +
                              decl->getName()->getName() + "`");

  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
}

/// Define a new variable named `name` of type `type` with the given
/// `constraints`, and return a reference expression to it.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> decl =
      defineVariableDecl(name, loc, type, constraints);
  if (failed(decl))
    return failure();
  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
}

/// Build the member access `parentExpr.name`, with the type computed by
/// validateMemberAccess.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               SMRange loc) {
  // Validate the member name for the given parent expression.
  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
  if (failed(memberType))
    return failure();

  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
}

/// Compute the type of member `name` on `parentExpr`, or emit an error if the
/// access is invalid.
///
/// Operation parents:
///  * the all-results member yields `ValueRange`;
///  * when ODS info is available, a numeric index or named result yields
///    `ValueRange` for variadic results and `Value` otherwise;
///  * without ODS info, any numeric member yields `Value` (unchecked).
///
/// Tuple parents: a numeric index or element name yields that element's type.
///
/// NOTE(review): `name[0]` assumes `name` is non-empty -- confirm all callers
/// guarantee this.
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name, SMRange loc) {
  ast::Type parentType = parentExpr->getType();
  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
    if (name == ast::AllResultsMemberAccessExpr::getMemberName())
      return valueRangeTy;

    // Verify member access based on the operation type.
    if (const ods::Operation *odsOp = opType.getODSOperation()) {
      auto results = odsOp->getResults();

      // Handle indexed results. `getAsInteger` returns true on failure, hence
      // the negation.
      unsigned index = 0;
      if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
          index < results.size()) {
        return results[index].isVariadic() ? valueRangeTy : valueTy;
      }

      // Handle named results.
      const auto *it = llvm::find_if(results, [&](const auto &result) {
        return result.getName() == name;
      });
      if (it != results.end())
        return it->isVariadic() ? valueRangeTy : valueTy;
    } else if (llvm::isDigit(name[0])) {
      // Allow unchecked numeric indexing of the results of unregistered
      // operations. It returns a single value.
      return valueTy;
    }
  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
    // Handle indexed elements.
    unsigned index = 0;
    if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
        index < tupleType.size()) {
      return tupleType.getElementTypes()[index];
    }

    // Handle named elements.
    auto elementNames = tupleType.getElementNames();
    const auto *it = llvm::find(elementNames, name);
    if (it != elementNames.end())
      return tupleType.getElementTypes()[it - elementNames.begin()];
  }
  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}

/// Build an operation expression for `name`.
///
/// Validation steps:
///  * operands are checked against ODS info when it is available;
///  * every attribute value must be an `Attr` expression;
///  * explicit results are validated;
///  * for interface-based result inference, a warning may be emitted if
///    inference looks unsupported.
///
/// Explicit results may only be provided in the `Explicit` result-type
/// context.
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
    SMRange loc, const ast::OpNameDecl *name,
    OpResultTypeContext resultTypeContext,
    MutableArrayRef<ast::Expr *> operands,
    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
    MutableArrayRef<ast::Expr *> results) {
  Optional<StringRef> opNameRef = name->getName();
  const ods::Operation *odsOp = lookupODSOperation(opNameRef);

  // Verify the input operands.
  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
    return failure();

  // Verify the attribute list: every value must be an `Attr` expression.
  for (ast::NamedAttributeDecl *attr : attributes) {
    ast::Type attrType = attr->getValue()->getType();
    if (!attrType.isa<ast::AttributeType>()) {
      return emitError(
          attr->getValue()->getLoc(),
          llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
    }
  }

  assert(
      (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
      "unexpected inferrence when results were explicitly specified");

  // If we aren't relying on type inference, or explicit results were provided,
  // validate them.
  if (resultTypeContext == OpResultTypeContext::Explicit) {
    if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
      return failure();

    // Validate the use of interface based type inference for this operation.
    // This only emits warnings, so it cannot fail.
  } else if (resultTypeContext == OpResultTypeContext::Interface) {
    assert(opNameRef &&
           "expected valid operation name when inferring operation results");
    checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
  }

  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
                                    attributes);
}

/// Validate the operand groups of an operation expression. Each group must be
/// a `Value` or `ValueRange`, checked against `odsOp` when it is non-null.
LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
                                  const ods::Operation *odsOp,
                                  MutableArrayRef<ast::Expr *> operands) {
  return validateOperationOperandsOrResults(
      "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
      operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
      valueRangeTy);
}

/// Validate the result type groups of an operation expression. Each group must
/// be a `Type` or `TypeRange`, checked against `odsOp` when it is non-null.
LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
                                 const ods::Operation *odsOp,
                                 MutableArrayRef<ast::Expr *> results) {
  return validateOperationOperandsOrResults(
      "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
      results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
}

/// Warn when the result types of `opName` are marked to be inferred but the
/// operation may not support inference.
///
/// A warning is emitted when:
///  * `odsOp` is null (the operation is unknown); or
///  * ODS declares at least one non-variable-length result and the operation
///    does not report result type inference.
///
/// Only warnings are emitted; nothing is returned.
void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
                                                const ods::Operation *odsOp) {
  // If the operation might not have inference support, emit a warning to the
  // user. We don't emit an error because the interface might be added to the
  // operation at runtime. It's rare, but it could still happen. We emit a
  // warning here instead.

  // Handle inference warnings for unknown operations.
  if (!odsOp) {
    ctx.getDiagEngine().emitWarning(
        loc, llvm::formatv(
                 "operation result types are marked to be inferred, but "
                 "`{0}` is unknown. Ensure that `{0}` supports zero "
                 "results or implements `InferTypeOpInterface`. Include "
                 "the ODS definition of this operation to remove this warning.",
                 opName));
    return;
  }

  // Handle inference warnings for known operations that expected at least one
  // result, but don't have inference support. An elided results list can mean
  // "zero-results", and we don't want to warn when that is the expected
  // behavior.
  bool requiresInferrence =
      llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
        return !result.isVariableLength();
      });
  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
        loc,
        llvm::formatv("operation result types are marked to be inferred, but "
                      "`{0}` does not provide an implementation of "
                      "`InferTypeOpInterface`. Ensure that `{0}` attaches "
                      "`InferTypeOpInterface` at runtime, or add support to "
                      "the ODS definition to remove this warning.",
                      opName));
    diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
                     odsOp->getLoc());
    return;
  }
}

/// Shared validation for operand or result groups (`groupName` is used in
/// diagnostics).
///
/// The checks are applied in this order:
///  1. A single group is always converted to `rangeTy`.
///  2. With ODS info (`odsOpLoc` set), the group count must match
///     `odsValues`, and each group is converted to `rangeTy` if variadic and
///     to `singleTy` otherwise.
///  3. Without ODS info, each group must already be `singleTy` or `rangeTy`.
///     When validating values, an `Op` expression is instead converted to its
///     value form.
///
/// Groups in `values` may be rewritten in place by these conversions.
LogicalResult Parser::validateOperationOperandsOrResults(
    StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
    Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
    ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
    ast::Type rangeTy) {
  // All operation types accept a single range parameter.
  if (values.size() == 1) {
    if (failed(convertExpressionTo(values[0], rangeTy)))
      return failure();
    return success();
  }

  // If the operation has ODS information, we can more accurately verify the
  // values. `name` is dereferenced below on the assumption that ODS info only
  // exists for named operations.
  if (odsOpLoc) {
    if (odsValues.size() != values.size()) {
      return emitErrorAndNote(
          loc,
          llvm::formatv("invalid number of {0} groups for `{1}`; expected "
                        "{2}, but got {3}",
                        groupName, *name, odsValues.size(), values.size()),
          *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
    }
    auto diagFn = [&](ast::Diagnostic &diag) {
      diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
                      *odsOpLoc);
    };
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
      if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
        return failure();
    }
    return success();
  }

  // Otherwise, accept the value groups as they have been defined and just
  // ensure they are one of the expected types.
  for (ast::Expr *&valueExpr : values) {
    ast::Type valueExprType = valueExpr->getType();

    // Check if this is one of the expected types.
    if (valueExprType == rangeTy || valueExprType == singleTy)
      continue;

    // If the operand is an Operation, allow converting to a Value or
    // ValueRange. This situation arises quite often with nested operation
    // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
    if (singleTy == valueTy) {
      if (valueExprType.isa<ast::OperationType>()) {
        valueExpr = convertOpToValue(valueExpr);
        continue;
      }
    }

    return emitError(
        valueExpr->getLoc(),
        llvm::formatv(
            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
            singleTy, rangeTy, valueExprType));
  }
  return success();
}

/// Build a tuple expression from `elements` (with optional `elementNames`).
/// Elements of `Constraint`, `Rewrite`, or tuple type are rejected.
FailureOr<ast::TupleExpr *>
Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
                        ArrayRef<StringRef> elementNames) {
  for (const ast::Expr *element : elements) {
    ast::Type eleTy = element->getType();
    if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
      return emitError(
          element->getLoc(),
          llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
    }
  }
  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
}

//===----------------------------------------------------------------------===//
// Stmts

/// Build an `erase` statement; `rootOp` must be an `Op` expression.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
                                                    ast::Expr *rootOp) {
  // Check that root is an Operation.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>())
    return emitError(rootOp->getLoc(), "expected `Op` expression");

  return ast::EraseStmt::create(ctx, loc, rootOp);
}

/// Build a `replace` statement.
///
/// `rootOp` must be an `Op` expression. Each replacement must be an `Op`,
/// `Value`, or `ValueRange` expression. When there is more than one
/// replacement, `Op` expressions are converted in place to their value form.
FailureOr<ast::ReplaceStmt *>
Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                          MutableArrayRef<ast::Expr *> replValues) {
  // Check that root is an Operation.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }

  // If there are multiple replacement values, we implicitly convert any Op
  // expressions to the value form.
  bool shouldConvertOpToValues = replValues.size() > 1;
  for (ast::Expr *&replExpr : replValues) {
    ast::Type replType = replExpr->getType();

    // Check that replExpr is an Operation, Value, or ValueRange.
    if (replType.isa<ast::OperationType>()) {
      if (shouldConvertOpToValues)
        replExpr = convertOpToValue(replExpr);
      continue;
    }

    if (replType != valueTy && replType != valueRangeTy) {
      return emitError(replExpr->getLoc(),
                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
                                     "expression, but got `{0}`",
                                     replType));
    }
  }

  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}

/// Build a `rewrite` statement applying `rewriteBody` to `rootOp`, which must
/// be an `Op` expression.
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                          ast::CompoundStmt *rewriteBody) {
  // Check that root is an Operation.
  ast::Type rootType = rootOp->getType();
  if (!rootType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
  }

  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
}

//===----------------------------------------------------------------------===//
// Code Completion
//===----------------------------------------------------------------------===//

// Each completion hook below forwards to `codeCompleteContext` and then
// unconditionally returns failure (presumably so that parsing stops once a
// completion point has been handled -- confirm with the lexer/caller).

/// Offer member completions for `parentExpr` when it is an operation or tuple.
LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
  ast::Type parentType = parentExpr->getType();
  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
    codeCompleteContext->codeCompleteOperationMemberAccess(opType);
  else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
    codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
  return failure();
}

/// Offer attribute-name completions; only done when the operation name is
/// known.
LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
  if (opName)
    codeCompleteContext->codeCompleteOperationAttributeName(*opName);
  return failure();
}

/// Offer constraint-name completions compatible with `inferredType`, using the
/// current declaration scope.
LogicalResult
Parser::codeCompleteConstraintName(ast::Type inferredType,
                                   bool allowNonCoreConstraints,
                                   bool allowInlineTypeConstraints) {
  codeCompleteContext->codeCompleteConstraintName(
      inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
      curDeclScope);
  return failure();
}

/// Offer dialect-name completions.
LogicalResult Parser::codeCompleteDialectName() {
  codeCompleteContext->codeCompleteDialectName();
  return failure();
}

/// Offer operation-name completions within `dialectName`.
LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
  codeCompleteContext->codeCompleteOperationName(dialectName);
  return failure();
}

/// Offer pattern-metadata completions.
LogicalResult Parser::codeCompletePatternMetadata() {
  codeCompleteContext->codeCompletePatternMetadata();
  return failure();
}
LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { 3090 codeCompleteContext->codeCompleteIncludeFilename(curPath); 3091 return failure(); 3092 } 3093 3094 void Parser::codeCompleteCallSignature(ast::Node *parent, 3095 unsigned currentNumArgs) { 3096 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); 3097 if (!callableDecl) 3098 return; 3099 3100 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs); 3101 } 3102 3103 void Parser::codeCompleteOperationOperandsSignature( 3104 Optional<StringRef> opName, unsigned currentNumOperands) { 3105 codeCompleteContext->codeCompleteOperationOperandsSignature( 3106 opName, currentNumOperands); 3107 } 3108 3109 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName, 3110 unsigned currentNumResults) { 3111 codeCompleteContext->codeCompleteOperationResultsSignature(opName, 3112 currentNumResults); 3113 } 3114 3115 //===----------------------------------------------------------------------===// 3116 // Parser 3117 //===----------------------------------------------------------------------===// 3118 3119 FailureOr<ast::Module *> 3120 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, 3121 bool enableDocumentation, 3122 CodeCompleteContext *codeCompleteContext) { 3123 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext); 3124 return parser.parseModule(); 3125 } 3126