1 //===- Parser.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/Parser/Parser.h" 10 #include "Lexer.h" 11 #include "mlir/Support/LogicalResult.h" 12 #include "mlir/TableGen/Argument.h" 13 #include "mlir/TableGen/Attribute.h" 14 #include "mlir/TableGen/Constraint.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/AST/Types.h" 21 #include "mlir/Tools/PDLL/ODS/Constraint.h" 22 #include "mlir/Tools/PDLL/ODS/Context.h" 23 #include "mlir/Tools/PDLL/ODS/Operation.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/ManagedStatic.h" 28 #include "llvm/Support/SaveAndRestore.h" 29 #include "llvm/Support/ScopedPrinter.h" 30 #include "llvm/TableGen/Error.h" 31 #include "llvm/TableGen/Parser.h" 32 #include <string> 33 34 using namespace mlir; 35 using namespace mlir::pdll; 36 37 //===----------------------------------------------------------------------===// 38 // Parser 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 class Parser { 43 public: 44 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) 45 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), 46 curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)), 47 valueRangeTy(ast::ValueRangeType::get(ctx)), 48 typeTy(ast::TypeType::get(ctx)), 49 typeRangeTy(ast::TypeRangeType::get(ctx)), 50 attrTy(ast::AttributeType::get(ctx)) {} 51 52 /// Try to parse a new module. Returns nullptr in the case of failure. 53 FailureOr<ast::Module *> parseModule(); 54 55 private: 56 /// The current context of the parser. It allows for the parser to know a bit 57 /// about the construct it is nested within during parsing. This is used 58 /// specifically to provide additional verification during parsing, e.g. to 59 /// prevent using rewrites within a match context, matcher constraints within 60 /// a rewrite section, etc. 61 enum class ParserContext { 62 /// The parser is in the global context. 63 Global, 64 /// The parser is currently within a Constraint, which disallows all types 65 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 66 Constraint, 67 /// The parser is currently within the matcher portion of a Pattern, which 68 /// is allows a terminal operation rewrite statement but no other rewrite 69 /// transformations. 70 PatternMatch, 71 /// The parser is currently within a Rewrite, which disallows calls to 72 /// constraints, requires operation expressions to have names, etc. 73 Rewrite, 74 }; 75 76 //===--------------------------------------------------------------------===// 77 // Parsing 78 //===--------------------------------------------------------------------===// 79 80 /// Push a new decl scope onto the lexer. 81 ast::DeclScope *pushDeclScope() { 82 ast::DeclScope *newScope = 83 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); 84 return (curDeclScope = newScope); 85 } 86 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } 87 88 /// Pop the last decl scope from the lexer. 89 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } 90 91 /// Parse the body of an AST module. 92 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); 93 94 /// Try to convert the given expression to `type`. Returns failure and emits 95 /// an error if a conversion is not viable. On failure, `noteAttachFn` is 96 /// invoked to attach notes to the emitted error diagnostic. On success, 97 /// `expr` is updated to the expression used to convert to `type`. 98 LogicalResult convertExpressionTo( 99 ast::Expr *&expr, ast::Type type, 100 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); 101 102 /// Given an operation expression, convert it to a Value or ValueRange 103 /// typed expression. 104 ast::Expr *convertOpToValue(const ast::Expr *opExpr); 105 106 /// Lookup ODS information for the given operation, returns nullptr if no 107 /// information is found. 108 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) { 109 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; 110 } 111 112 //===--------------------------------------------------------------------===// 113 // Directives 114 115 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); 116 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); 117 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, 118 SmallVectorImpl<ast::Decl *> &decls); 119 120 /// Process the records of a parsed tablegen include file. 121 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 122 SmallVectorImpl<ast::Decl *> &decls); 123 124 /// Create a user defined native constraint for a constraint imported from 125 /// ODS. 126 template <typename ConstraintT> 127 ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, 128 StringRef codeBlock, SMRange loc, 129 ast::Type type); 130 template <typename ConstraintT> 131 ast::Decl * 132 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 133 SMRange loc, ast::Type type); 134 135 //===--------------------------------------------------------------------===// 136 // Decls 137 138 /// This structure contains the set of pattern metadata that may be parsed. 139 struct ParsedPatternMetadata { 140 Optional<uint16_t> benefit; 141 bool hasBoundedRecursion = false; 142 }; 143 144 FailureOr<ast::Decl *> parseTopLevelDecl(); 145 FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); 146 147 /// Parse an argument variable as part of the signature of a 148 /// UserConstraintDecl or UserRewriteDecl. 149 FailureOr<ast::VariableDecl *> parseArgumentDecl(); 150 151 /// Parse a result variable as part of the signature of a UserConstraintDecl 152 /// or UserRewriteDecl. 153 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); 154 155 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being 156 /// defined in a non-global context. 157 FailureOr<ast::UserConstraintDecl *> 158 parseUserConstraintDecl(bool isInline = false); 159 160 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a 161 /// non-global context, such as within a Pattern/Constraint/etc. 162 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); 163 164 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 165 /// PDLL constructs. 166 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( 167 const ast::Name &name, bool isInline, 168 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 169 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 170 171 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being 172 /// defined in a non-global context. 173 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); 174 175 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a 176 /// non-global context, such as within a Pattern/Rewrite/etc. 177 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); 178 179 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using 180 /// PDLL constructs. 181 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( 182 const ast::Name &name, bool isInline, 183 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 184 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 185 186 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have 187 /// effectively the same syntax, and only differ on slight semantics (given 188 /// the different parsing contexts). 189 template <typename T, typename ParseUserPDLLDeclFnT> 190 FailureOr<T *> parseUserConstraintOrRewriteDecl( 191 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 192 StringRef anonymousNamePrefix, bool isInline); 193 194 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. 195 /// These decls have effectively the same syntax. 196 template <typename T> 197 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( 198 const ast::Name &name, bool isInline, 199 ArrayRef<ast::VariableDecl *> arguments, 200 ArrayRef<ast::VariableDecl *> results, ast::Type resultType); 201 202 /// Parse the functional signature (i.e. the arguments and results) of a 203 /// UserConstraintDecl or UserRewriteDecl. 204 LogicalResult parseUserConstraintOrRewriteSignature( 205 SmallVectorImpl<ast::VariableDecl *> &arguments, 206 SmallVectorImpl<ast::VariableDecl *> &results, 207 ast::DeclScope *&argumentScope, ast::Type &resultType); 208 209 /// Validate the return (which if present is specified by bodyIt) of a 210 /// UserConstraintDecl or UserRewriteDecl. 211 LogicalResult validateUserConstraintOrRewriteReturn( 212 StringRef declType, ast::CompoundStmt *body, 213 ArrayRef<ast::Stmt *>::iterator bodyIt, 214 ArrayRef<ast::Stmt *>::iterator bodyE, 215 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); 216 217 FailureOr<ast::CompoundStmt *> 218 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 219 bool expectTerminalSemicolon = true); 220 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); 221 FailureOr<ast::Decl *> parsePatternDecl(); 222 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); 223 224 /// Check to see if a decl has already been defined with the given name, if 225 /// one has emit and error and return failure. Returns success otherwise. 226 LogicalResult checkDefineNamedDecl(const ast::Name &name); 227 228 /// Try to define a variable decl with the given components, returns the 229 /// variable on success. 230 FailureOr<ast::VariableDecl *> 231 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 232 ast::Expr *initExpr, 233 ArrayRef<ast::ConstraintRef> constraints); 234 FailureOr<ast::VariableDecl *> 235 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 236 ArrayRef<ast::ConstraintRef> constraints); 237 238 /// Parse the constraint reference list for a variable decl. 239 LogicalResult parseVariableDeclConstraintList( 240 SmallVectorImpl<ast::ConstraintRef> &constraints); 241 242 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. 243 FailureOr<ast::Expr *> parseTypeConstraintExpr(); 244 245 /// Try to parse a single reference to a constraint. `typeConstraint` is the 246 /// location of a previously parsed type constraint for the entity that will 247 /// be constrained by the parsed constraint. `existingConstraints` are any 248 /// existing constraints that have already been parsed for the same entity 249 /// that will be constrained by this constraint. `allowInlineTypeConstraints` 250 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. 251 FailureOr<ast::ConstraintRef> 252 parseConstraint(Optional<SMRange> &typeConstraint, 253 ArrayRef<ast::ConstraintRef> existingConstraints, 254 bool allowInlineTypeConstraints); 255 256 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl 257 /// argument or result variable. The constraints for these variables do not 258 /// allow inline type constraints, and only permit a single constraint. 259 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); 260 261 //===--------------------------------------------------------------------===// 262 // Exprs 263 264 FailureOr<ast::Expr *> parseExpr(); 265 266 /// Identifier expressions. 267 FailureOr<ast::Expr *> parseAttributeExpr(); 268 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr); 269 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); 270 FailureOr<ast::Expr *> parseIdentifierExpr(); 271 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); 272 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); 273 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); 274 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); 275 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); 276 FailureOr<ast::Expr *> parseOperationExpr(); 277 FailureOr<ast::Expr *> parseTupleExpr(); 278 FailureOr<ast::Expr *> parseTypeExpr(); 279 FailureOr<ast::Expr *> parseUnderscoreExpr(); 280 281 //===--------------------------------------------------------------------===// 282 // Stmts 283 284 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); 285 FailureOr<ast::CompoundStmt *> parseCompoundStmt(); 286 FailureOr<ast::EraseStmt *> parseEraseStmt(); 287 FailureOr<ast::LetStmt *> parseLetStmt(); 288 FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); 289 FailureOr<ast::ReturnStmt *> parseReturnStmt(); 290 FailureOr<ast::RewriteStmt *> parseRewriteStmt(); 291 292 //===--------------------------------------------------------------------===// 293 // Creation+Analysis 294 //===--------------------------------------------------------------------===// 295 296 //===--------------------------------------------------------------------===// 297 // Decls 298 299 /// Try to extract a callable from the given AST node. Returns nullptr on 300 /// failure. 301 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); 302 303 /// Try to create a pattern decl with the given components, returning the 304 /// Pattern on success. 305 FailureOr<ast::PatternDecl *> 306 createPatternDecl(SMRange loc, const ast::Name *name, 307 const ParsedPatternMetadata &metadata, 308 ast::CompoundStmt *body); 309 310 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set 311 /// of results, defined as part of the signature. 312 ast::Type 313 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); 314 315 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 316 template <typename T> 317 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( 318 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 319 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 320 ast::CompoundStmt *body); 321 322 /// Try to create a variable decl with the given components, returning the 323 /// Variable on success. 324 FailureOr<ast::VariableDecl *> 325 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 326 ArrayRef<ast::ConstraintRef> constraints); 327 328 /// Create a variable for an argument or result defined as part of the 329 /// signature of a UserConstraintDecl/UserRewriteDecl. 330 FailureOr<ast::VariableDecl *> 331 createArgOrResultVariableDecl(StringRef name, SMRange loc, 332 const ast::ConstraintRef &constraint); 333 334 /// Validate the constraints used to constraint a variable decl. 335 /// `inferredType` is the type of the variable inferred by the constraints 336 /// within the list, and is updated to the most refined type as determined by 337 /// the constraints. Returns success if the constraint list is valid, failure 338 /// otherwise. 339 LogicalResult 340 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 341 ast::Type &inferredType); 342 /// Validate a single reference to a constraint. `inferredType` contains the 343 /// currently inferred variabled type and is refined within the type defined 344 /// by the constraint. Returns success if the constraint is valid, failure 345 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user 346 /// defined constraints) may be used with the variable. 347 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, 348 ast::Type &inferredType, 349 bool allowNonCoreConstraints = true); 350 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); 351 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); 352 353 //===--------------------------------------------------------------------===// 354 // Exprs 355 356 FailureOr<ast::CallExpr *> 357 createCallExpr(SMRange loc, ast::Expr *parentExpr, 358 MutableArrayRef<ast::Expr *> arguments); 359 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); 360 FailureOr<ast::DeclRefExpr *> 361 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 362 ArrayRef<ast::ConstraintRef> constraints); 363 FailureOr<ast::MemberAccessExpr *> 364 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); 365 366 /// Validate the member access `name` into the given parent expression. On 367 /// success, this also returns the type of the member accessed. 368 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, 369 StringRef name, SMRange loc); 370 FailureOr<ast::OperationExpr *> 371 createOperationExpr(SMRange loc, const ast::OpNameDecl *name, 372 MutableArrayRef<ast::Expr *> operands, 373 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 374 MutableArrayRef<ast::Expr *> results); 375 LogicalResult 376 validateOperationOperands(SMRange loc, Optional<StringRef> name, 377 const ods::Operation *odsOp, 378 MutableArrayRef<ast::Expr *> operands); 379 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, 380 const ods::Operation *odsOp, 381 MutableArrayRef<ast::Expr *> results); 382 LogicalResult validateOperationOperandsOrResults( 383 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 384 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 385 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 386 ast::Type rangeTy); 387 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, 388 ArrayRef<ast::Expr *> elements, 389 ArrayRef<StringRef> elementNames); 390 391 //===--------------------------------------------------------------------===// 392 // Stmts 393 394 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); 395 FailureOr<ast::ReplaceStmt *> 396 createReplaceStmt(SMRange loc, ast::Expr *rootOp, 397 MutableArrayRef<ast::Expr *> replValues); 398 FailureOr<ast::RewriteStmt *> 399 createRewriteStmt(SMRange loc, ast::Expr *rootOp, 400 ast::CompoundStmt *rewriteBody); 401 402 //===--------------------------------------------------------------------===// 403 // Lexer Utilities 404 //===--------------------------------------------------------------------===// 405 406 /// If the current token has the specified kind, consume it and return true. 407 /// If not, return false. 408 bool consumeIf(Token::Kind kind) { 409 if (curToken.isNot(kind)) 410 return false; 411 consumeToken(kind); 412 return true; 413 } 414 415 /// Advance the current lexer onto the next token. 416 void consumeToken() { 417 assert(curToken.isNot(Token::eof, Token::error) && 418 "shouldn't advance past EOF or errors"); 419 curToken = lexer.lexToken(); 420 } 421 422 /// Advance the current lexer onto the next token, asserting what the expected 423 /// current token is. This is preferred to the above method because it leads 424 /// to more self-documenting code with better checking. 425 void consumeToken(Token::Kind kind) { 426 assert(curToken.is(kind) && "consumed an unexpected token"); 427 consumeToken(); 428 } 429 430 /// Reset the lexer to the location at the given position. 431 void resetToken(SMRange tokLoc) { 432 lexer.resetPointer(tokLoc.Start.getPointer()); 433 curToken = lexer.lexToken(); 434 } 435 436 /// Consume the specified token if present and return success. On failure, 437 /// output a diagnostic and return failure. 438 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 439 if (curToken.getKind() != kind) 440 return emitError(curToken.getLoc(), msg); 441 consumeToken(); 442 return success(); 443 } 444 LogicalResult emitError(SMRange loc, const Twine &msg) { 445 lexer.emitError(loc, msg); 446 return failure(); 447 } 448 LogicalResult emitError(const Twine &msg) { 449 return emitError(curToken.getLoc(), msg); 450 } 451 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 452 const Twine ¬e) { 453 lexer.emitErrorAndNote(loc, msg, noteLoc, note); 454 return failure(); 455 } 456 457 //===--------------------------------------------------------------------===// 458 // Fields 459 //===--------------------------------------------------------------------===// 460 461 /// The owning AST context. 462 ast::Context &ctx; 463 464 /// The lexer of this parser. 465 Lexer lexer; 466 467 /// The current token within the lexer. 468 Token curToken; 469 470 /// The most recently defined decl scope. 471 ast::DeclScope *curDeclScope = nullptr; 472 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; 473 474 /// The current context of the parser. 475 ParserContext parserContext = ParserContext::Global; 476 477 /// Cached types to simplify verification and expression creation. 478 ast::Type valueTy, valueRangeTy; 479 ast::Type typeTy, typeRangeTy; 480 ast::Type attrTy; 481 482 /// A counter used when naming anonymous constraints and rewrites. 483 unsigned anonymousDeclNameCounter = 0; 484 }; 485 } // namespace 486 487 FailureOr<ast::Module *> Parser::parseModule() { 488 SMLoc moduleLoc = curToken.getStartLoc(); 489 pushDeclScope(); 490 491 // Parse the top-level decls of the module. 492 SmallVector<ast::Decl *> decls; 493 if (failed(parseModuleBody(decls))) 494 return popDeclScope(), failure(); 495 496 popDeclScope(); 497 return ast::Module::create(ctx, moduleLoc, decls); 498 } 499 500 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { 501 while (curToken.isNot(Token::eof)) { 502 if (curToken.is(Token::directive)) { 503 if (failed(parseDirective(decls))) 504 return failure(); 505 continue; 506 } 507 508 FailureOr<ast::Decl *> decl = parseTopLevelDecl(); 509 if (failed(decl)) 510 return failure(); 511 decls.push_back(*decl); 512 } 513 return success(); 514 } 515 516 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { 517 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, 518 valueRangeTy); 519 } 520 521 LogicalResult Parser::convertExpressionTo( 522 ast::Expr *&expr, ast::Type type, 523 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { 524 ast::Type exprType = expr->getType(); 525 if (exprType == type) 526 return success(); 527 528 auto emitConvertError = [&]() -> ast::InFlightDiagnostic { 529 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( 530 expr->getLoc(), llvm::formatv("unable to convert expression of type " 531 "`{0}` to the expected type of " 532 "`{1}`", 533 exprType, type)); 534 if (noteAttachFn) 535 noteAttachFn(*diag); 536 return diag; 537 }; 538 539 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) { 540 // Two operation types are compatible if they have the same name, or if the 541 // expected type is more general. 542 if (auto opType = type.dyn_cast<ast::OperationType>()) { 543 if (opType.getName()) 544 return emitConvertError(); 545 return success(); 546 } 547 548 // An operation can always convert to a ValueRange. 549 if (type == valueRangeTy) { 550 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 551 valueRangeTy); 552 return success(); 553 } 554 555 // Allow conversion to a single value by constraining the result range. 556 if (type == valueTy) { 557 // If the operation is registered, we can verify if it can ever have a 558 // single result. 559 Optional<StringRef> opName = exprOpType.getName(); 560 if (const ods::Operation *odsOp = lookupODSOperation(opName)) { 561 if (odsOp->getResults().empty()) { 562 return emitConvertError()->attachNote( 563 llvm::formatv("see the definition of `{0}`, which was defined " 564 "with zero results", 565 odsOp->getName()), 566 odsOp->getLoc()); 567 } 568 569 unsigned numSingleResults = llvm::count_if( 570 odsOp->getResults(), [](const ods::OperandOrResult &result) { 571 return result.getVariableLengthKind() == 572 ods::VariableLengthKind::Single; 573 }); 574 if (numSingleResults > 1) { 575 return emitConvertError()->attachNote( 576 llvm::formatv("see the definition of `{0}`, which was defined " 577 "with at least {1} results", 578 odsOp->getName(), numSingleResults), 579 odsOp->getLoc()); 580 } 581 } 582 583 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, 584 valueTy); 585 return success(); 586 } 587 return emitConvertError(); 588 } 589 590 // FIXME: Decide how to allow/support converting a single result to multiple, 591 // and multiple to a single result. For now, we just allow Single->Range, 592 // but this isn't something really supported in the PDL dialect. We should 593 // figure out some way to support both. 594 if ((exprType == valueTy || exprType == valueRangeTy) && 595 (type == valueTy || type == valueRangeTy)) 596 return success(); 597 if ((exprType == typeTy || exprType == typeRangeTy) && 598 (type == typeTy || type == typeRangeTy)) 599 return success(); 600 601 // Handle tuple types. 602 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) { 603 auto tupleType = type.dyn_cast<ast::TupleType>(); 604 if (!tupleType || tupleType.size() != exprTupleType.size()) 605 return emitConvertError(); 606 607 // Build a new tuple expression using each of the elements of the current 608 // tuple. 609 SmallVector<ast::Expr *> newExprs; 610 for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { 611 newExprs.push_back(ast::MemberAccessExpr::create( 612 ctx, expr->getLoc(), expr, llvm::to_string(i), 613 exprTupleType.getElementTypes()[i])); 614 615 auto diagFn = [&](ast::Diagnostic &diag) { 616 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", 617 i, exprTupleType)); 618 if (noteAttachFn) 619 noteAttachFn(diag); 620 }; 621 if (failed(convertExpressionTo(newExprs.back(), 622 tupleType.getElementTypes()[i], diagFn))) 623 return failure(); 624 } 625 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, 626 tupleType.getElementNames()); 627 return success(); 628 } 629 630 return emitConvertError(); 631 } 632 633 //===----------------------------------------------------------------------===// 634 // Directives 635 636 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { 637 StringRef directive = curToken.getSpelling(); 638 if (directive == "#include") 639 return parseInclude(decls); 640 641 return emitError("unknown directive `" + directive + "`"); 642 } 643 644 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { 645 SMRange loc = curToken.getLoc(); 646 consumeToken(Token::directive); 647 648 // Parse the file being included. 649 if (!curToken.isString()) 650 return emitError(loc, 651 "expected string file name after `include` directive"); 652 SMRange fileLoc = curToken.getLoc(); 653 std::string filenameStr = curToken.getStringValue(); 654 StringRef filename = filenameStr; 655 consumeToken(); 656 657 // Check the type of include. If ending with `.pdll`, this is another pdl file 658 // to be parsed along with the current module. 659 if (filename.endswith(".pdll")) { 660 if (failed(lexer.pushInclude(filename))) 661 return emitError(fileLoc, 662 "unable to open include file `" + filename + "`"); 663 664 // If we added the include successfully, parse it into the current module. 665 // Make sure to save the current token so that we can restore it when we 666 // finish parsing the nested file. 667 Token oldToken = curToken; 668 curToken = lexer.lexToken(); 669 LogicalResult result = parseModuleBody(decls); 670 curToken = oldToken; 671 return result; 672 } 673 674 // Otherwise, this must be a `.td` include. 675 if (filename.endswith(".td")) 676 return parseTdInclude(filename, fileLoc, decls); 677 678 return emitError(fileLoc, 679 "expected include filename to end with `.pdll` or `.td`"); 680 } 681 682 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, 683 SmallVectorImpl<ast::Decl *> &decls) { 684 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); 685 686 // This class provides a context argument for the llvm::SourceMgr diagnostic 687 // handler. 688 struct DiagHandlerContext { 689 Parser &parser; 690 StringRef filename; 691 llvm::SMRange loc; 692 } handlerContext{*this, filename, fileLoc}; 693 694 // Set the diagnostic handler for the tablegen source manager. 695 llvm::SrcMgr.setDiagHandler( 696 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { 697 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); 698 (void)ctx->parser.emitError( 699 ctx->loc, 700 llvm::formatv("error while processing include file `{0}`: {1}", 701 ctx->filename, diag.getMessage())); 702 }, 703 &handlerContext); 704 705 // Use the source manager to open the file, but don't yet add it. 706 std::string includedFile; 707 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = 708 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); 709 if (!includeBuffer) 710 return emitError(fileLoc, "unable to open include file `" + filename + "`"); 711 712 auto processFn = [&](llvm::RecordKeeper &records) { 713 processTdIncludeRecords(records, decls); 714 715 // After we are done processing, move all of the tablegen source buffers to 716 // the main parser source mgr. This allows for directly using source 717 // locations from the .td files without needing to remap them. 718 parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr); 719 return false; 720 }; 721 if (llvm::TableGenParseFile(std::move(*includeBuffer), 722 parserSrcMgr.getIncludeDirs(), processFn)) 723 return failure(); 724 725 return success(); 726 } 727 728 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, 729 SmallVectorImpl<ast::Decl *> &decls) { 730 // Return the length kind of the given value. 731 auto getLengthKind = [](const auto &value) { 732 if (value.isOptional()) 733 return ods::VariableLengthKind::Optional; 734 return value.isVariadic() ? ods::VariableLengthKind::Variadic 735 : ods::VariableLengthKind::Single; 736 }; 737 738 // Insert a type constraint into the ODS context. 739 ods::Context &odsContext = ctx.getODSContext(); 740 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) 741 -> const ods::TypeConstraint & { 742 return odsContext.insertTypeConstraint(cst.constraint.getDefName(), 743 cst.constraint.getSummary(), 744 cst.constraint.getCPPClassName()); 745 }; 746 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { 747 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; 748 }; 749 750 // Process the parsed tablegen records to build ODS information. 751 /// Operations. 752 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { 753 tblgen::Operator op(def); 754 755 bool inserted = false; 756 ods::Operation *odsOp = nullptr; 757 std::tie(odsOp, inserted) = 758 odsContext.insertOperation(op.getOperationName(), op.getSummary(), 759 op.getDescription(), op.getLoc().front()); 760 761 // Ignore operations that have already been added. 762 if (!inserted) 763 continue; 764 765 for (const tblgen::NamedAttribute &attr : op.getAttributes()) { 766 odsOp->appendAttribute( 767 attr.name, attr.attr.isOptional(), 768 odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(), 769 attr.attr.getSummary(), 770 attr.attr.getStorageType())); 771 } 772 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { 773 odsOp->appendOperand(operand.name, getLengthKind(operand), 774 addTypeConstraint(operand)); 775 } 776 for (const tblgen::NamedTypeConstraint &result : op.getResults()) { 777 odsOp->appendResult(result.name, getLengthKind(result), 778 addTypeConstraint(result)); 779 } 780 } 781 /// Attr constraints. 782 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { 783 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 784 decls.push_back( 785 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 786 tblgen::AttrConstraint(def), 787 convertLocToRange(def->getLoc().front()), attrTy)); 788 } 789 } 790 /// Type constraints. 791 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { 792 if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { 793 decls.push_back( 794 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 795 tblgen::TypeConstraint(def), 796 convertLocToRange(def->getLoc().front()), typeTy)); 797 } 798 } 799 /// Interfaces. 800 ast::Type opTy = ast::OperationType::get(ctx); 801 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { 802 StringRef name = def->getName(); 803 if (def->isAnonymous() || curDeclScope->lookup(name) || 804 def->isSubClassOf("DeclareInterfaceMethods")) 805 continue; 806 SMRange loc = convertLocToRange(def->getLoc().front()); 807 808 StringRef className = def->getValueAsString("cppClassName"); 809 StringRef cppNamespace = def->getValueAsString("cppNamespace"); 810 std::string codeBlock = 811 llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) 812 .str(); 813 814 if (def->isSubClassOf("OpInterface")) { 815 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( 816 name, codeBlock, loc, opTy)); 817 } else if (def->isSubClassOf("AttrInterface")) { 818 decls.push_back( 819 createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( 820 name, codeBlock, loc, attrTy)); 821 } else if (def->isSubClassOf("TypeInterface")) { 822 decls.push_back( 823 createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( 824 name, codeBlock, loc, typeTy)); 825 } 826 } 827 } 828 829 template <typename ConstraintT> 830 ast::Decl * 831 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, 832 SMRange loc, ast::Type type) { 833 // Build the single input parameter. 834 ast::DeclScope *argScope = pushDeclScope(); 835 auto *paramVar = ast::VariableDecl::create( 836 ctx, ast::Name::create(ctx, "self", loc), type, 837 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); 838 argScope->add(paramVar); 839 popDeclScope(); 840 841 // Build the native constraint. 842 auto *constraintDecl = ast::UserConstraintDecl::createNative( 843 ctx, ast::Name::create(ctx, name, loc), paramVar, 844 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); 845 curDeclScope->add(constraintDecl); 846 return constraintDecl; 847 } 848 849 template <typename ConstraintT> 850 ast::Decl * 851 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, 852 SMRange loc, ast::Type type) { 853 // Format the condition template. 854 tblgen::FmtContext fmtContext; 855 fmtContext.withSelf("self"); 856 std::string codeBlock = 857 tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); 858 859 return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(), 860 codeBlock, loc, type); 861 } 862 863 //===----------------------------------------------------------------------===// 864 // Decls 865 866 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { 867 FailureOr<ast::Decl *> decl; 868 switch (curToken.getKind()) { 869 case Token::kw_Constraint: 870 decl = parseUserConstraintDecl(); 871 break; 872 case Token::kw_Pattern: 873 decl = parsePatternDecl(); 874 break; 875 case Token::kw_Rewrite: 876 decl = parseUserRewriteDecl(); 877 break; 878 default: 879 return emitError("expected top-level declaration, such as a `Pattern`"); 880 } 881 if (failed(decl)) 882 return failure(); 883 884 // If the decl has a name, add it to the current scope. 885 if (const ast::Name *name = (*decl)->getName()) { 886 if (failed(checkDefineNamedDecl(*name))) 887 return failure(); 888 curDeclScope->add(*decl); 889 } 890 return decl; 891 } 892 893 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() { 894 std::string attrNameStr; 895 if (curToken.isString()) 896 attrNameStr = curToken.getStringValue(); 897 else if (curToken.is(Token::identifier) || curToken.isKeyword()) 898 attrNameStr = curToken.getSpelling().str(); 899 else 900 return emitError("expected identifier or string attribute name"); 901 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); 902 consumeToken(); 903 904 // Check for a value of the attribute. 905 ast::Expr *attrValue = nullptr; 906 if (consumeIf(Token::equal)) { 907 FailureOr<ast::Expr *> attrExpr = parseExpr(); 908 if (failed(attrExpr)) 909 return failure(); 910 attrValue = *attrExpr; 911 } else { 912 // If there isn't a concrete value, create an expression representing a 913 // UnitAttr. 914 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); 915 } 916 917 return ast::NamedAttributeDecl::create(ctx, name, attrValue); 918 } 919 920 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( 921 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, 922 bool expectTerminalSemicolon) { 923 consumeToken(Token::equal_arrow); 924 925 // Parse the single statement of the lambda body. 926 SMLoc bodyStartLoc = curToken.getStartLoc(); 927 pushDeclScope(); 928 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); 929 bool failedToParse = 930 failed(singleStatement) || failed(processStatementFn(*singleStatement)); 931 popDeclScope(); 932 if (failedToParse) 933 return failure(); 934 935 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); 936 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); 937 } 938 939 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { 940 // Ensure that the argument is named. 941 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) 942 return emitError("expected identifier argument name"); 943 944 // Parse the argument similarly to a normal variable. 945 StringRef name = curToken.getSpelling(); 946 SMRange nameLoc = curToken.getLoc(); 947 consumeToken(); 948 949 if (failed( 950 parseToken(Token::colon, "expected `:` before argument constraint"))) 951 return failure(); 952 953 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 954 if (failed(cst)) 955 return failure(); 956 957 return createArgOrResultVariableDecl(name, nameLoc, *cst); 958 } 959 960 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { 961 // Check to see if this result is named. 962 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 963 // Check to see if this name actually refers to a Constraint. 964 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); 965 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) { 966 // If yes, and this is a Rewrite, give a nice error message as non-Core 967 // constraints are not supported on Rewrite results. 968 if (parserContext == ParserContext::Rewrite) { 969 return emitError( 970 "`Rewrite` results are only permitted to use core constraints, " 971 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); 972 } 973 974 // Otherwise, parse this as an unnamed result variable. 975 } else { 976 // If it wasn't a constraint, parse the result similarly to a variable. If 977 // there is already an existing decl, we will emit an error when defining 978 // this variable later. 979 StringRef name = curToken.getSpelling(); 980 SMRange nameLoc = curToken.getLoc(); 981 consumeToken(); 982 983 if (failed(parseToken(Token::colon, 984 "expected `:` before result constraint"))) 985 return failure(); 986 987 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 988 if (failed(cst)) 989 return failure(); 990 991 return createArgOrResultVariableDecl(name, nameLoc, *cst); 992 } 993 } 994 995 // If it isn't named, we parse the constraint directly and create an unnamed 996 // result variable. 997 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); 998 if (failed(cst)) 999 return failure(); 1000 1001 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); 1002 } 1003 1004 FailureOr<ast::UserConstraintDecl *> 1005 Parser::parseUserConstraintDecl(bool isInline) { 1006 // Constraints and rewrites have very similar formats, dispatch to a shared 1007 // interface for parsing. 1008 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1009 [&](auto &&...args) { 1010 return this->parseUserPDLLConstraintDecl(args...); 1011 }, 1012 ParserContext::Constraint, "constraint", isInline); 1013 } 1014 1015 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { 1016 FailureOr<ast::UserConstraintDecl *> decl = 1017 parseUserConstraintDecl(/*isInline=*/true); 1018 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1019 return failure(); 1020 1021 curDeclScope->add(*decl); 1022 return decl; 1023 } 1024 1025 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( 1026 const ast::Name &name, bool isInline, 1027 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1028 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1029 // Push the argument scope back onto the list, so that the body can 1030 // reference arguments. 1031 pushDeclScope(argumentScope); 1032 1033 // Parse the body of the constraint. The body is either defined as a compound 1034 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. 1035 ast::CompoundStmt *body; 1036 if (curToken.is(Token::equal_arrow)) { 1037 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1038 [&](ast::Stmt *&stmt) -> LogicalResult { 1039 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt); 1040 if (!stmtExpr) { 1041 return emitError(stmt->getLoc(), 1042 "expected `Constraint` lambda body to contain a " 1043 "single expression"); 1044 } 1045 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); 1046 return success(); 1047 }, 1048 /*expectTerminalSemicolon=*/!isInline); 1049 if (failed(bodyResult)) 1050 return failure(); 1051 body = *bodyResult; 1052 } else { 1053 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1054 if (failed(bodyResult)) 1055 return failure(); 1056 body = *bodyResult; 1057 1058 // Verify the structure of the body. 1059 auto bodyIt = body->begin(), bodyE = body->end(); 1060 for (; bodyIt != bodyE; ++bodyIt) 1061 if (isa<ast::ReturnStmt>(*bodyIt)) 1062 break; 1063 if (failed(validateUserConstraintOrRewriteReturn( 1064 "Constraint", body, bodyIt, bodyE, results, resultType))) 1065 return failure(); 1066 } 1067 popDeclScope(); 1068 1069 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( 1070 name, arguments, results, resultType, body); 1071 } 1072 1073 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { 1074 // Constraints and rewrites have very similar formats, dispatch to a shared 1075 // interface for parsing. 1076 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1077 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); }, 1078 ParserContext::Rewrite, "rewrite", isInline); 1079 } 1080 1081 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { 1082 FailureOr<ast::UserRewriteDecl *> decl = 1083 parseUserRewriteDecl(/*isInline=*/true); 1084 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) 1085 return failure(); 1086 1087 curDeclScope->add(*decl); 1088 return decl; 1089 } 1090 1091 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( 1092 const ast::Name &name, bool isInline, 1093 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, 1094 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1095 // Push the argument scope back onto the list, so that the body can 1096 // reference arguments. 1097 curDeclScope = argumentScope; 1098 ast::CompoundStmt *body; 1099 if (curToken.is(Token::equal_arrow)) { 1100 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( 1101 [&](ast::Stmt *&statement) -> LogicalResult { 1102 if (isa<ast::OpRewriteStmt>(statement)) 1103 return success(); 1104 1105 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement); 1106 if (!statementExpr) { 1107 return emitError( 1108 statement->getLoc(), 1109 "expected `Rewrite` lambda body to contain a single expression " 1110 "or an operation rewrite statement; such as `erase`, " 1111 "`replace`, or `rewrite`"); 1112 } 1113 statement = 1114 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); 1115 return success(); 1116 }, 1117 /*expectTerminalSemicolon=*/!isInline); 1118 if (failed(bodyResult)) 1119 return failure(); 1120 body = *bodyResult; 1121 } else { 1122 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1123 if (failed(bodyResult)) 1124 return failure(); 1125 body = *bodyResult; 1126 } 1127 popDeclScope(); 1128 1129 // Verify the structure of the body. 1130 auto bodyIt = body->begin(), bodyE = body->end(); 1131 for (; bodyIt != bodyE; ++bodyIt) 1132 if (isa<ast::ReturnStmt>(*bodyIt)) 1133 break; 1134 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, 1135 bodyE, results, resultType))) 1136 return failure(); 1137 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( 1138 name, arguments, results, resultType, body); 1139 } 1140 1141 template <typename T, typename ParseUserPDLLDeclFnT> 1142 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( 1143 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, 1144 StringRef anonymousNamePrefix, bool isInline) { 1145 SMRange loc = curToken.getLoc(); 1146 consumeToken(); 1147 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext); 1148 1149 // Parse the name of the decl. 1150 const ast::Name *name = nullptr; 1151 if (curToken.isNot(Token::identifier)) { 1152 // Only inline decls can be un-named. Inline decls are similar to "lambdas" 1153 // in C++, so being unnamed is fine. 1154 if (!isInline) 1155 return emitError("expected identifier name"); 1156 1157 // Create a unique anonymous name to use, as the name for this decl is not 1158 // important. 1159 std::string anonName = 1160 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix, 1161 anonymousDeclNameCounter++) 1162 .str(); 1163 name = &ast::Name::create(ctx, anonName, loc); 1164 } else { 1165 // If a name was provided, we can use it directly. 1166 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1167 consumeToken(Token::identifier); 1168 } 1169 1170 // Parse the functional signature of the decl. 1171 SmallVector<ast::VariableDecl *> arguments, results; 1172 ast::DeclScope *argumentScope; 1173 ast::Type resultType; 1174 if (failed(parseUserConstraintOrRewriteSignature(arguments, results, 1175 argumentScope, resultType))) 1176 return failure(); 1177 1178 // Check to see which type of constraint this is. If the constraint contains a 1179 // compound body, this is a PDLL decl. 1180 if (curToken.isAny(Token::l_brace, Token::equal_arrow)) 1181 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, 1182 resultType); 1183 1184 // Otherwise, this is a native decl. 1185 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, 1186 results, resultType); 1187 } 1188 1189 template <typename T> 1190 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( 1191 const ast::Name &name, bool isInline, 1192 ArrayRef<ast::VariableDecl *> arguments, 1193 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { 1194 // If followed by a string, the native code body has also been specified. 1195 std::string codeStrStorage; 1196 Optional<StringRef> optCodeStr; 1197 if (curToken.isString()) { 1198 codeStrStorage = curToken.getStringValue(); 1199 optCodeStr = codeStrStorage; 1200 consumeToken(); 1201 } else if (isInline) { 1202 return emitError(name.getLoc(), 1203 "external declarations must be declared in global scope"); 1204 } 1205 if (failed(parseToken(Token::semicolon, 1206 "expected `;` after native declaration"))) 1207 return failure(); 1208 // TODO: PDL should be able to support constraint results in certain 1209 // situations, we should revise this. 1210 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { 1211 return emitError( 1212 "native Constraints currently do not support returning results"); 1213 } 1214 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); 1215 } 1216 1217 LogicalResult Parser::parseUserConstraintOrRewriteSignature( 1218 SmallVectorImpl<ast::VariableDecl *> &arguments, 1219 SmallVectorImpl<ast::VariableDecl *> &results, 1220 ast::DeclScope *&argumentScope, ast::Type &resultType) { 1221 // Parse the argument list of the decl. 1222 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) 1223 return failure(); 1224 1225 argumentScope = pushDeclScope(); 1226 if (curToken.isNot(Token::r_paren)) { 1227 do { 1228 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); 1229 if (failed(argument)) 1230 return failure(); 1231 arguments.emplace_back(*argument); 1232 } while (consumeIf(Token::comma)); 1233 } 1234 popDeclScope(); 1235 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) 1236 return failure(); 1237 1238 // Parse the results of the decl. 1239 pushDeclScope(); 1240 if (consumeIf(Token::arrow)) { 1241 auto parseResultFn = [&]() -> LogicalResult { 1242 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size()); 1243 if (failed(result)) 1244 return failure(); 1245 results.emplace_back(*result); 1246 return success(); 1247 }; 1248 1249 // Check for a list of results. 1250 if (consumeIf(Token::l_paren)) { 1251 do { 1252 if (failed(parseResultFn())) 1253 return failure(); 1254 } while (consumeIf(Token::comma)); 1255 if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) 1256 return failure(); 1257 1258 // Otherwise, there is only one result. 1259 } else if (failed(parseResultFn())) { 1260 return failure(); 1261 } 1262 } 1263 popDeclScope(); 1264 1265 // Compute the result type of the decl. 1266 resultType = createUserConstraintRewriteResultType(results); 1267 1268 // Verify that results are only named if there are more than one. 1269 if (results.size() == 1 && !results.front()->getName().getName().empty()) { 1270 return emitError( 1271 results.front()->getLoc(), 1272 "cannot create a single-element tuple with an element label"); 1273 } 1274 return success(); 1275 } 1276 1277 LogicalResult Parser::validateUserConstraintOrRewriteReturn( 1278 StringRef declType, ast::CompoundStmt *body, 1279 ArrayRef<ast::Stmt *>::iterator bodyIt, 1280 ArrayRef<ast::Stmt *>::iterator bodyE, 1281 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { 1282 // Handle if a `return` was provided. 1283 if (bodyIt != bodyE) { 1284 // Emit an error if we have trailing statements after the return. 1285 if (std::next(bodyIt) != bodyE) { 1286 return emitError( 1287 (*std::next(bodyIt))->getLoc(), 1288 llvm::formatv("`return` terminated the `{0}` body, but found " 1289 "trailing statements afterwards", 1290 declType)); 1291 } 1292 1293 // Otherwise if a return wasn't provided, check that no results are 1294 // expected. 1295 } else if (!results.empty()) { 1296 return emitError( 1297 {body->getLoc().End, body->getLoc().End}, 1298 llvm::formatv("missing return in a `{0}` expected to return `{1}`", 1299 declType, resultType)); 1300 } 1301 return success(); 1302 } 1303 1304 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { 1305 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { 1306 if (isa<ast::OpRewriteStmt>(statement)) 1307 return success(); 1308 return emitError( 1309 statement->getLoc(), 1310 "expected Pattern lambda body to contain a single operation " 1311 "rewrite statement, such as `erase`, `replace`, or `rewrite`"); 1312 }); 1313 } 1314 1315 FailureOr<ast::Decl *> Parser::parsePatternDecl() { 1316 SMRange loc = curToken.getLoc(); 1317 consumeToken(Token::kw_Pattern); 1318 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 1319 ParserContext::PatternMatch); 1320 1321 // Check for an optional identifier for the pattern name. 1322 const ast::Name *name = nullptr; 1323 if (curToken.is(Token::identifier)) { 1324 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); 1325 consumeToken(Token::identifier); 1326 } 1327 1328 // Parse any pattern metadata. 1329 ParsedPatternMetadata metadata; 1330 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) 1331 return failure(); 1332 1333 // Parse the pattern body. 1334 ast::CompoundStmt *body; 1335 1336 // Handle a lambda body. 1337 if (curToken.is(Token::equal_arrow)) { 1338 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); 1339 if (failed(bodyResult)) 1340 return failure(); 1341 body = *bodyResult; 1342 } else { 1343 if (curToken.isNot(Token::l_brace)) 1344 return emitError("expected `{` or `=>` to start pattern body"); 1345 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); 1346 if (failed(bodyResult)) 1347 return failure(); 1348 body = *bodyResult; 1349 1350 // Verify the body of the pattern. 1351 auto bodyIt = body->begin(), bodyE = body->end(); 1352 for (; bodyIt != bodyE; ++bodyIt) { 1353 if (isa<ast::ReturnStmt>(*bodyIt)) { 1354 return emitError((*bodyIt)->getLoc(), 1355 "`return` statements are only permitted within a " 1356 "`Constraint` or `Rewrite` body"); 1357 } 1358 // Break when we've found the rewrite statement. 1359 if (isa<ast::OpRewriteStmt>(*bodyIt)) 1360 break; 1361 } 1362 if (bodyIt == bodyE) { 1363 return emitError(loc, 1364 "expected Pattern body to terminate with an operation " 1365 "rewrite statement, such as `erase`"); 1366 } 1367 if (std::next(bodyIt) != bodyE) { 1368 return emitError((*std::next(bodyIt))->getLoc(), 1369 "Pattern body was terminated by an operation " 1370 "rewrite statement, but found trailing statements"); 1371 } 1372 } 1373 1374 return createPatternDecl(loc, name, metadata, body); 1375 } 1376 1377 LogicalResult 1378 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { 1379 Optional<SMRange> benefitLoc; 1380 Optional<SMRange> hasBoundedRecursionLoc; 1381 1382 do { 1383 if (curToken.isNot(Token::identifier)) 1384 return emitError("expected pattern metadata identifier"); 1385 StringRef metadataStr = curToken.getSpelling(); 1386 SMRange metadataLoc = curToken.getLoc(); 1387 consumeToken(Token::identifier); 1388 1389 // Parse the benefit metadata: benefit(<integer-value>) 1390 if (metadataStr == "benefit") { 1391 if (benefitLoc) { 1392 return emitErrorAndNote(metadataLoc, 1393 "pattern benefit has already been specified", 1394 *benefitLoc, "see previous definition here"); 1395 } 1396 if (failed(parseToken(Token::l_paren, 1397 "expected `(` before pattern benefit"))) 1398 return failure(); 1399 1400 uint16_t benefitValue = 0; 1401 if (curToken.isNot(Token::integer)) 1402 return emitError("expected integral pattern benefit"); 1403 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) 1404 return emitError( 1405 "expected pattern benefit to fit within a 16-bit integer"); 1406 consumeToken(Token::integer); 1407 1408 metadata.benefit = benefitValue; 1409 benefitLoc = metadataLoc; 1410 1411 if (failed( 1412 parseToken(Token::r_paren, "expected `)` after pattern benefit"))) 1413 return failure(); 1414 continue; 1415 } 1416 1417 // Parse the bounded recursion metadata: recursion 1418 if (metadataStr == "recursion") { 1419 if (hasBoundedRecursionLoc) { 1420 return emitErrorAndNote( 1421 metadataLoc, 1422 "pattern recursion metadata has already been specified", 1423 *hasBoundedRecursionLoc, "see previous definition here"); 1424 } 1425 metadata.hasBoundedRecursion = true; 1426 hasBoundedRecursionLoc = metadataLoc; 1427 continue; 1428 } 1429 1430 return emitError(metadataLoc, "unknown pattern metadata"); 1431 } while (consumeIf(Token::comma)); 1432 1433 return success(); 1434 } 1435 1436 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { 1437 consumeToken(Token::less); 1438 1439 FailureOr<ast::Expr *> typeExpr = parseExpr(); 1440 if (failed(typeExpr) || 1441 failed(parseToken(Token::greater, 1442 "expected `>` after variable type constraint"))) 1443 return failure(); 1444 return typeExpr; 1445 } 1446 1447 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { 1448 assert(curDeclScope && "defining decl outside of a decl scope"); 1449 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { 1450 return emitErrorAndNote( 1451 name.getLoc(), "`" + name.getName() + "` has already been defined", 1452 lastDecl->getName()->getLoc(), "see previous definition here"); 1453 } 1454 return success(); 1455 } 1456 1457 FailureOr<ast::VariableDecl *> 1458 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1459 ast::Expr *initExpr, 1460 ArrayRef<ast::ConstraintRef> constraints) { 1461 assert(curDeclScope && "defining variable outside of decl scope"); 1462 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); 1463 1464 // If the name of the variable indicates a special variable, we don't add it 1465 // to the scope. This variable is local to the definition point. 1466 if (name.empty() || name == "_") { 1467 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, 1468 constraints); 1469 } 1470 if (failed(checkDefineNamedDecl(nameDecl))) 1471 return failure(); 1472 1473 auto *varDecl = 1474 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); 1475 curDeclScope->add(varDecl); 1476 return varDecl; 1477 } 1478 1479 FailureOr<ast::VariableDecl *> 1480 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, 1481 ArrayRef<ast::ConstraintRef> constraints) { 1482 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, 1483 constraints); 1484 } 1485 1486 LogicalResult Parser::parseVariableDeclConstraintList( 1487 SmallVectorImpl<ast::ConstraintRef> &constraints) { 1488 Optional<SMRange> typeConstraint; 1489 auto parseSingleConstraint = [&] { 1490 FailureOr<ast::ConstraintRef> constraint = parseConstraint( 1491 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); 1492 if (failed(constraint)) 1493 return failure(); 1494 constraints.push_back(*constraint); 1495 return success(); 1496 }; 1497 1498 // Check to see if this is a single constraint, or a list. 1499 if (!consumeIf(Token::l_square)) 1500 return parseSingleConstraint(); 1501 1502 do { 1503 if (failed(parseSingleConstraint())) 1504 return failure(); 1505 } while (consumeIf(Token::comma)); 1506 return parseToken(Token::r_square, "expected `]` after constraint list"); 1507 } 1508 1509 FailureOr<ast::ConstraintRef> 1510 Parser::parseConstraint(Optional<SMRange> &typeConstraint, 1511 ArrayRef<ast::ConstraintRef> existingConstraints, 1512 bool allowInlineTypeConstraints) { 1513 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { 1514 if (!allowInlineTypeConstraints) { 1515 return emitError( 1516 curToken.getLoc(), 1517 "inline `Attr`, `Value`, and `ValueRange` type constraints are not " 1518 "permitted on arguments or results"); 1519 } 1520 if (typeConstraint) 1521 return emitErrorAndNote( 1522 curToken.getLoc(), 1523 "the type of this variable has already been constrained", 1524 *typeConstraint, "see previous constraint location here"); 1525 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); 1526 if (failed(constraintExpr)) 1527 return failure(); 1528 typeExpr = *constraintExpr; 1529 typeConstraint = typeExpr->getLoc(); 1530 return success(); 1531 }; 1532 1533 SMRange loc = curToken.getLoc(); 1534 switch (curToken.getKind()) { 1535 case Token::kw_Attr: { 1536 consumeToken(Token::kw_Attr); 1537 1538 // Check for a type constraint. 1539 ast::Expr *typeExpr = nullptr; 1540 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1541 return failure(); 1542 return ast::ConstraintRef( 1543 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); 1544 } 1545 case Token::kw_Op: { 1546 consumeToken(Token::kw_Op); 1547 1548 // Parse an optional operation name. If the name isn't provided, this refers 1549 // to "any" operation. 1550 FailureOr<ast::OpNameDecl *> opName = 1551 parseWrappedOperationName(/*allowEmptyName=*/true); 1552 if (failed(opName)) 1553 return failure(); 1554 1555 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), 1556 loc); 1557 } 1558 case Token::kw_Type: 1559 consumeToken(Token::kw_Type); 1560 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); 1561 case Token::kw_TypeRange: 1562 consumeToken(Token::kw_TypeRange); 1563 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), 1564 loc); 1565 case Token::kw_Value: { 1566 consumeToken(Token::kw_Value); 1567 1568 // Check for a type constraint. 1569 ast::Expr *typeExpr = nullptr; 1570 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1571 return failure(); 1572 1573 return ast::ConstraintRef( 1574 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); 1575 } 1576 case Token::kw_ValueRange: { 1577 consumeToken(Token::kw_ValueRange); 1578 1579 // Check for a type constraint. 1580 ast::Expr *typeExpr = nullptr; 1581 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) 1582 return failure(); 1583 1584 return ast::ConstraintRef( 1585 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); 1586 } 1587 1588 case Token::kw_Constraint: { 1589 // Handle an inline constraint. 1590 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1591 if (failed(decl)) 1592 return failure(); 1593 return ast::ConstraintRef(*decl, loc); 1594 } 1595 case Token::identifier: { 1596 StringRef constraintName = curToken.getSpelling(); 1597 consumeToken(Token::identifier); 1598 1599 // Lookup the referenced constraint. 1600 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName); 1601 if (!cstDecl) { 1602 return emitError(loc, "unknown reference to constraint `" + 1603 constraintName + "`"); 1604 } 1605 1606 // Handle a reference to a proper constraint. 1607 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl)) 1608 return ast::ConstraintRef(cst, loc); 1609 1610 return emitErrorAndNote( 1611 loc, "invalid reference to non-constraint", cstDecl->getLoc(), 1612 "see the definition of `" + constraintName + "` here"); 1613 } 1614 default: 1615 break; 1616 } 1617 return emitError(loc, "expected identifier constraint"); 1618 } 1619 1620 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { 1621 Optional<SMRange> typeConstraint; 1622 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, 1623 /*allowInlineTypeConstraints=*/false); 1624 } 1625 1626 //===----------------------------------------------------------------------===// 1627 // Exprs 1628 1629 FailureOr<ast::Expr *> Parser::parseExpr() { 1630 if (curToken.is(Token::underscore)) 1631 return parseUnderscoreExpr(); 1632 1633 // Parse the LHS expression. 1634 FailureOr<ast::Expr *> lhsExpr; 1635 switch (curToken.getKind()) { 1636 case Token::kw_attr: 1637 lhsExpr = parseAttributeExpr(); 1638 break; 1639 case Token::kw_Constraint: 1640 lhsExpr = parseInlineConstraintLambdaExpr(); 1641 break; 1642 case Token::identifier: 1643 lhsExpr = parseIdentifierExpr(); 1644 break; 1645 case Token::kw_op: 1646 lhsExpr = parseOperationExpr(); 1647 break; 1648 case Token::kw_Rewrite: 1649 lhsExpr = parseInlineRewriteLambdaExpr(); 1650 break; 1651 case Token::kw_type: 1652 lhsExpr = parseTypeExpr(); 1653 break; 1654 case Token::l_paren: 1655 lhsExpr = parseTupleExpr(); 1656 break; 1657 default: 1658 return emitError("expected expression"); 1659 } 1660 if (failed(lhsExpr)) 1661 return failure(); 1662 1663 // Check for an operator expression. 1664 while (true) { 1665 switch (curToken.getKind()) { 1666 case Token::dot: 1667 lhsExpr = parseMemberAccessExpr(*lhsExpr); 1668 break; 1669 case Token::l_paren: 1670 lhsExpr = parseCallExpr(*lhsExpr); 1671 break; 1672 default: 1673 return lhsExpr; 1674 } 1675 if (failed(lhsExpr)) 1676 return failure(); 1677 } 1678 } 1679 1680 FailureOr<ast::Expr *> Parser::parseAttributeExpr() { 1681 SMRange loc = curToken.getLoc(); 1682 consumeToken(Token::kw_attr); 1683 1684 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal 1685 // identifier. 1686 if (!consumeIf(Token::less)) { 1687 resetToken(loc); 1688 return parseIdentifierExpr(); 1689 } 1690 1691 if (!curToken.isString()) 1692 return emitError("expected string literal containing MLIR attribute"); 1693 std::string attrExpr = curToken.getStringValue(); 1694 consumeToken(); 1695 1696 if (failed( 1697 parseToken(Token::greater, "expected `>` after attribute literal"))) 1698 return failure(); 1699 return ast::AttributeExpr::create(ctx, loc, attrExpr); 1700 } 1701 1702 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) { 1703 SMRange loc = curToken.getLoc(); 1704 consumeToken(Token::l_paren); 1705 1706 // Parse the arguments of the call. 1707 SmallVector<ast::Expr *> arguments; 1708 if (curToken.isNot(Token::r_paren)) { 1709 do { 1710 FailureOr<ast::Expr *> argument = parseExpr(); 1711 if (failed(argument)) 1712 return failure(); 1713 arguments.push_back(*argument); 1714 } while (consumeIf(Token::comma)); 1715 } 1716 loc.End = curToken.getEndLoc(); 1717 if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) 1718 return failure(); 1719 1720 return createCallExpr(loc, parentExpr, arguments); 1721 } 1722 1723 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { 1724 ast::Decl *decl = curDeclScope->lookup(name); 1725 if (!decl) 1726 return emitError(loc, "undefined reference to `" + name + "`"); 1727 1728 return createDeclRefExpr(loc, decl); 1729 } 1730 1731 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { 1732 StringRef name = curToken.getSpelling(); 1733 SMRange nameLoc = curToken.getLoc(); 1734 consumeToken(); 1735 1736 // Check to see if this is a decl ref expression that defines a variable 1737 // inline. 1738 if (consumeIf(Token::colon)) { 1739 SmallVector<ast::ConstraintRef> constraints; 1740 if (failed(parseVariableDeclConstraintList(constraints))) 1741 return failure(); 1742 ast::Type type; 1743 if (failed(validateVariableConstraints(constraints, type))) 1744 return failure(); 1745 return createInlineVariableExpr(type, name, nameLoc, constraints); 1746 } 1747 1748 return parseDeclRefExpr(name, nameLoc); 1749 } 1750 1751 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { 1752 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); 1753 if (failed(decl)) 1754 return failure(); 1755 1756 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1757 ast::ConstraintType::get(ctx)); 1758 } 1759 1760 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { 1761 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); 1762 if (failed(decl)) 1763 return failure(); 1764 1765 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, 1766 ast::RewriteType::get(ctx)); 1767 } 1768 1769 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { 1770 SMRange loc = curToken.getLoc(); 1771 consumeToken(Token::dot); 1772 1773 // Parse the member name. 1774 Token memberNameTok = curToken; 1775 if (memberNameTok.isNot(Token::identifier, Token::integer) && 1776 !memberNameTok.isKeyword()) 1777 return emitError(loc, "expected identifier or numeric member name"); 1778 StringRef memberName = memberNameTok.getSpelling(); 1779 consumeToken(); 1780 1781 return createMemberAccessExpr(parentExpr, memberName, loc); 1782 } 1783 1784 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { 1785 SMRange loc = curToken.getLoc(); 1786 1787 // Handle the case of an no operation name. 1788 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { 1789 if (allowEmptyName) 1790 return ast::OpNameDecl::create(ctx, SMRange()); 1791 return emitError("expected dialect namespace"); 1792 } 1793 StringRef name = curToken.getSpelling(); 1794 consumeToken(); 1795 1796 // Otherwise, this is a literal operation name. 1797 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) 1798 return failure(); 1799 1800 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) 1801 return emitError("expected operation name after dialect namespace"); 1802 1803 name = StringRef(name.data(), name.size() + 1); 1804 do { 1805 name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); 1806 loc.End = curToken.getEndLoc(); 1807 consumeToken(); 1808 } while (curToken.isAny(Token::identifier, Token::dot) || 1809 curToken.isKeyword()); 1810 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); 1811 } 1812 1813 FailureOr<ast::OpNameDecl *> 1814 Parser::parseWrappedOperationName(bool allowEmptyName) { 1815 if (!consumeIf(Token::less)) 1816 return ast::OpNameDecl::create(ctx, SMRange()); 1817 1818 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); 1819 if (failed(opNameDecl)) 1820 return failure(); 1821 1822 if (failed(parseToken(Token::greater, "expected `>` after operation name"))) 1823 return failure(); 1824 return opNameDecl; 1825 } 1826 1827 FailureOr<ast::Expr *> Parser::parseOperationExpr() { 1828 SMRange loc = curToken.getLoc(); 1829 consumeToken(Token::kw_op); 1830 1831 // If it isn't followed by a `<`, the `op` keyword is treated as a normal 1832 // identifier. 1833 if (curToken.isNot(Token::less)) { 1834 resetToken(loc); 1835 return parseIdentifierExpr(); 1836 } 1837 1838 // Parse the operation name. The name may be elided, in which case the 1839 // operation refers to "any" operation(i.e. a difference between `MyOp` and 1840 // `Operation*`). Operation names within a rewrite context must be named. 1841 bool allowEmptyName = parserContext != ParserContext::Rewrite; 1842 FailureOr<ast::OpNameDecl *> opNameDecl = 1843 parseWrappedOperationName(allowEmptyName); 1844 if (failed(opNameDecl)) 1845 return failure(); 1846 1847 // Functor used to create an implicit range variable, used for implicit "all" 1848 // operand or results variables. 1849 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { 1850 FailureOr<ast::VariableDecl *> rangeVar = 1851 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc)); 1852 assert(succeeded(rangeVar) && "expected range variable to be valid"); 1853 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type); 1854 }; 1855 1856 // Check for the optional list of operands. 1857 SmallVector<ast::Expr *> operands; 1858 if (!consumeIf(Token::l_paren)) { 1859 // If the operand list isn't specified and we are in a match context, define 1860 // an inplace unconstrained operand range corresponding to all of the 1861 // operands of the operation. This avoids treating zero operands the same 1862 // way as "unconstrained operands". 1863 if (parserContext != ParserContext::Rewrite) { 1864 operands.push_back(createImplicitRangeVar( 1865 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); 1866 } 1867 } else if (!consumeIf(Token::r_paren)) { 1868 // If the operand list was specified and non-empty, parse the operands. 1869 do { 1870 FailureOr<ast::Expr *> operand = parseExpr(); 1871 if (failed(operand)) 1872 return failure(); 1873 operands.push_back(*operand); 1874 } while (consumeIf(Token::comma)); 1875 1876 if (failed(parseToken(Token::r_paren, 1877 "expected `)` after operation operand list"))) 1878 return failure(); 1879 } 1880 1881 // Check for the optional list of attributes. 1882 SmallVector<ast::NamedAttributeDecl *> attributes; 1883 if (consumeIf(Token::l_brace)) { 1884 do { 1885 FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl(); 1886 if (failed(decl)) 1887 return failure(); 1888 attributes.emplace_back(*decl); 1889 } while (consumeIf(Token::comma)); 1890 1891 if (failed(parseToken(Token::r_brace, 1892 "expected `}` after operation attribute list"))) 1893 return failure(); 1894 } 1895 1896 // Check for the optional list of result types. 1897 SmallVector<ast::Expr *> resultTypes; 1898 if (consumeIf(Token::arrow)) { 1899 if (failed(parseToken(Token::l_paren, 1900 "expected `(` before operation result type list"))) 1901 return failure(); 1902 1903 // Handle the case of an empty result list. 1904 if (!consumeIf(Token::r_paren)) { 1905 do { 1906 FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); 1907 if (failed(resultTypeExpr)) 1908 return failure(); 1909 resultTypes.push_back(*resultTypeExpr); 1910 } while (consumeIf(Token::comma)); 1911 1912 if (failed(parseToken(Token::r_paren, 1913 "expected `)` after operation result type list"))) 1914 return failure(); 1915 } 1916 } else if (parserContext != ParserContext::Rewrite) { 1917 // If the result list isn't specified and we are in a match context, define 1918 // an inplace unconstrained result range corresponding to all of the results 1919 // of the operation. This avoids treating zero results the same way as 1920 // "unconstrained results". 1921 resultTypes.push_back(createImplicitRangeVar( 1922 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); 1923 } 1924 1925 return createOperationExpr(loc, *opNameDecl, operands, attributes, 1926 resultTypes); 1927 } 1928 1929 FailureOr<ast::Expr *> Parser::parseTupleExpr() { 1930 SMRange loc = curToken.getLoc(); 1931 consumeToken(Token::l_paren); 1932 1933 DenseMap<StringRef, SMRange> usedNames; 1934 SmallVector<StringRef> elementNames; 1935 SmallVector<ast::Expr *> elements; 1936 if (curToken.isNot(Token::r_paren)) { 1937 do { 1938 // Check for the optional element name assignment before the value. 1939 StringRef elementName; 1940 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { 1941 Token elementNameTok = curToken; 1942 consumeToken(); 1943 1944 // The element name is only present if followed by an `=`. 1945 if (consumeIf(Token::equal)) { 1946 elementName = elementNameTok.getSpelling(); 1947 1948 // Check to see if this name is already used. 1949 auto elementNameIt = 1950 usedNames.try_emplace(elementName, elementNameTok.getLoc()); 1951 if (!elementNameIt.second) { 1952 return emitErrorAndNote( 1953 elementNameTok.getLoc(), 1954 llvm::formatv("duplicate tuple element label `{0}`", 1955 elementName), 1956 elementNameIt.first->getSecond(), 1957 "see previous label use here"); 1958 } 1959 } else { 1960 // Otherwise, we treat this as part of an expression so reset the 1961 // lexer. 1962 resetToken(elementNameTok.getLoc()); 1963 } 1964 } 1965 elementNames.push_back(elementName); 1966 1967 // Parse the tuple element value. 1968 FailureOr<ast::Expr *> element = parseExpr(); 1969 if (failed(element)) 1970 return failure(); 1971 elements.push_back(*element); 1972 } while (consumeIf(Token::comma)); 1973 } 1974 loc.End = curToken.getEndLoc(); 1975 if (failed( 1976 parseToken(Token::r_paren, "expected `)` after tuple element list"))) 1977 return failure(); 1978 return createTupleExpr(loc, elements, elementNames); 1979 } 1980 1981 FailureOr<ast::Expr *> Parser::parseTypeExpr() { 1982 SMRange loc = curToken.getLoc(); 1983 consumeToken(Token::kw_type); 1984 1985 // If we aren't followed by a `<`, the `type` keyword is treated as a normal 1986 // identifier. 1987 if (!consumeIf(Token::less)) { 1988 resetToken(loc); 1989 return parseIdentifierExpr(); 1990 } 1991 1992 if (!curToken.isString()) 1993 return emitError("expected string literal containing MLIR type"); 1994 std::string attrExpr = curToken.getStringValue(); 1995 consumeToken(); 1996 1997 if (failed(parseToken(Token::greater, "expected `>` after type literal"))) 1998 return failure(); 1999 return ast::TypeExpr::create(ctx, loc, attrExpr); 2000 } 2001 2002 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { 2003 StringRef name = curToken.getSpelling(); 2004 SMRange nameLoc = curToken.getLoc(); 2005 consumeToken(Token::underscore); 2006 2007 // Underscore expressions require a constraint list. 2008 if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) 2009 return failure(); 2010 2011 // Parse the constraints for the expression. 2012 SmallVector<ast::ConstraintRef> constraints; 2013 if (failed(parseVariableDeclConstraintList(constraints))) 2014 return failure(); 2015 2016 ast::Type type; 2017 if (failed(validateVariableConstraints(constraints, type))) 2018 return failure(); 2019 return createInlineVariableExpr(type, name, nameLoc, constraints); 2020 } 2021 2022 //===----------------------------------------------------------------------===// 2023 // Stmts 2024 2025 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { 2026 FailureOr<ast::Stmt *> stmt; 2027 switch (curToken.getKind()) { 2028 case Token::kw_erase: 2029 stmt = parseEraseStmt(); 2030 break; 2031 case Token::kw_let: 2032 stmt = parseLetStmt(); 2033 break; 2034 case Token::kw_replace: 2035 stmt = parseReplaceStmt(); 2036 break; 2037 case Token::kw_return: 2038 stmt = parseReturnStmt(); 2039 break; 2040 case Token::kw_rewrite: 2041 stmt = parseRewriteStmt(); 2042 break; 2043 default: 2044 stmt = parseExpr(); 2045 break; 2046 } 2047 if (failed(stmt) || 2048 (expectTerminalSemicolon && 2049 failed(parseToken(Token::semicolon, "expected `;` after statement")))) 2050 return failure(); 2051 return stmt; 2052 } 2053 2054 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { 2055 SMLoc startLoc = curToken.getStartLoc(); 2056 consumeToken(Token::l_brace); 2057 2058 // Push a new block scope and parse any nested statements. 2059 pushDeclScope(); 2060 SmallVector<ast::Stmt *> statements; 2061 while (curToken.isNot(Token::r_brace)) { 2062 FailureOr<ast::Stmt *> statement = parseStmt(); 2063 if (failed(statement)) 2064 return popDeclScope(), failure(); 2065 statements.push_back(*statement); 2066 } 2067 popDeclScope(); 2068 2069 // Consume the end brace. 2070 SMRange location(startLoc, curToken.getEndLoc()); 2071 consumeToken(Token::r_brace); 2072 2073 return ast::CompoundStmt::create(ctx, location, statements); 2074 } 2075 2076 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { 2077 if (parserContext == ParserContext::Constraint) 2078 return emitError("`erase` cannot be used within a Constraint"); 2079 SMRange loc = curToken.getLoc(); 2080 consumeToken(Token::kw_erase); 2081 2082 // Parse the root operation expression. 2083 FailureOr<ast::Expr *> rootOp = parseExpr(); 2084 if (failed(rootOp)) 2085 return failure(); 2086 2087 return createEraseStmt(loc, *rootOp); 2088 } 2089 2090 FailureOr<ast::LetStmt *> Parser::parseLetStmt() { 2091 SMRange loc = curToken.getLoc(); 2092 consumeToken(Token::kw_let); 2093 2094 // Parse the name of the new variable. 2095 SMRange varLoc = curToken.getLoc(); 2096 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { 2097 // `_` is a reserved variable name. 2098 if (curToken.is(Token::underscore)) { 2099 return emitError(varLoc, 2100 "`_` may only be used to define \"inline\" variables"); 2101 } 2102 return emitError(varLoc, 2103 "expected identifier after `let` to name a new variable"); 2104 } 2105 StringRef varName = curToken.getSpelling(); 2106 consumeToken(); 2107 2108 // Parse the optional set of constraints. 2109 SmallVector<ast::ConstraintRef> constraints; 2110 if (consumeIf(Token::colon) && 2111 failed(parseVariableDeclConstraintList(constraints))) 2112 return failure(); 2113 2114 // Parse the optional initializer expression. 2115 ast::Expr *initializer = nullptr; 2116 if (consumeIf(Token::equal)) { 2117 FailureOr<ast::Expr *> initOrFailure = parseExpr(); 2118 if (failed(initOrFailure)) 2119 return failure(); 2120 initializer = *initOrFailure; 2121 2122 // Check that the constraints are compatible with having an initializer, 2123 // e.g. type constraints cannot be used with initializers. 2124 for (ast::ConstraintRef constraint : constraints) { 2125 LogicalResult result = 2126 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) 2127 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 2128 ast::ValueRangeConstraintDecl>([&](const auto *cst) { 2129 if (auto *typeConstraintExpr = cst->getTypeExpr()) { 2130 return this->emitError( 2131 constraint.referenceLoc, 2132 "type constraints are not permitted on variables with " 2133 "initializers"); 2134 } 2135 return success(); 2136 }) 2137 .Default(success()); 2138 if (failed(result)) 2139 return failure(); 2140 } 2141 } 2142 2143 FailureOr<ast::VariableDecl *> varDecl = 2144 createVariableDecl(varName, varLoc, initializer, constraints); 2145 if (failed(varDecl)) 2146 return failure(); 2147 return ast::LetStmt::create(ctx, loc, *varDecl); 2148 } 2149 2150 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { 2151 if (parserContext == ParserContext::Constraint) 2152 return emitError("`replace` cannot be used within a Constraint"); 2153 SMRange loc = curToken.getLoc(); 2154 consumeToken(Token::kw_replace); 2155 2156 // Parse the root operation expression. 2157 FailureOr<ast::Expr *> rootOp = parseExpr(); 2158 if (failed(rootOp)) 2159 return failure(); 2160 2161 if (failed( 2162 parseToken(Token::kw_with, "expected `with` after root operation"))) 2163 return failure(); 2164 2165 // The replacement portion of this statement is within a rewrite context. 2166 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2167 ParserContext::Rewrite); 2168 2169 // Parse the replacement values. 2170 SmallVector<ast::Expr *> replValues; 2171 if (consumeIf(Token::l_paren)) { 2172 if (consumeIf(Token::r_paren)) { 2173 return emitError( 2174 loc, "expected at least one replacement value, consider using " 2175 "`erase` if no replacement values are desired"); 2176 } 2177 2178 do { 2179 FailureOr<ast::Expr *> replExpr = parseExpr(); 2180 if (failed(replExpr)) 2181 return failure(); 2182 replValues.emplace_back(*replExpr); 2183 } while (consumeIf(Token::comma)); 2184 2185 if (failed(parseToken(Token::r_paren, 2186 "expected `)` after replacement values"))) 2187 return failure(); 2188 } else { 2189 FailureOr<ast::Expr *> replExpr = parseExpr(); 2190 if (failed(replExpr)) 2191 return failure(); 2192 replValues.emplace_back(*replExpr); 2193 } 2194 2195 return createReplaceStmt(loc, *rootOp, replValues); 2196 } 2197 2198 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { 2199 SMRange loc = curToken.getLoc(); 2200 consumeToken(Token::kw_return); 2201 2202 // Parse the result value. 2203 FailureOr<ast::Expr *> resultExpr = parseExpr(); 2204 if (failed(resultExpr)) 2205 return failure(); 2206 2207 return ast::ReturnStmt::create(ctx, loc, *resultExpr); 2208 } 2209 2210 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { 2211 if (parserContext == ParserContext::Constraint) 2212 return emitError("`rewrite` cannot be used within a Constraint"); 2213 SMRange loc = curToken.getLoc(); 2214 consumeToken(Token::kw_rewrite); 2215 2216 // Parse the root operation. 2217 FailureOr<ast::Expr *> rootOp = parseExpr(); 2218 if (failed(rootOp)) 2219 return failure(); 2220 2221 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) 2222 return failure(); 2223 2224 if (curToken.isNot(Token::l_brace)) 2225 return emitError("expected `{` to start rewrite body"); 2226 2227 // The rewrite body of this statement is within a rewrite context. 2228 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, 2229 ParserContext::Rewrite); 2230 2231 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); 2232 if (failed(rewriteBody)) 2233 return failure(); 2234 2235 // Verify the rewrite body. 2236 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { 2237 if (isa<ast::ReturnStmt>(stmt)) { 2238 return emitError(stmt->getLoc(), 2239 "`return` statements are only permitted within a " 2240 "`Constraint` or `Rewrite` body"); 2241 } 2242 } 2243 2244 return createRewriteStmt(loc, *rootOp, *rewriteBody); 2245 } 2246 2247 //===----------------------------------------------------------------------===// 2248 // Creation+Analysis 2249 //===----------------------------------------------------------------------===// 2250 2251 //===----------------------------------------------------------------------===// 2252 // Decls 2253 2254 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { 2255 // Unwrap reference expressions. 2256 if (auto *init = dyn_cast<ast::DeclRefExpr>(node)) 2257 node = init->getDecl(); 2258 return dyn_cast<ast::CallableDecl>(node); 2259 } 2260 2261 FailureOr<ast::PatternDecl *> 2262 Parser::createPatternDecl(SMRange loc, const ast::Name *name, 2263 const ParsedPatternMetadata &metadata, 2264 ast::CompoundStmt *body) { 2265 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, 2266 metadata.hasBoundedRecursion, body); 2267 } 2268 2269 ast::Type Parser::createUserConstraintRewriteResultType( 2270 ArrayRef<ast::VariableDecl *> results) { 2271 // Single result decls use the type of the single result. 2272 if (results.size() == 1) 2273 return results[0]->getType(); 2274 2275 // Multiple results use a tuple type, with the types and names grabbed from 2276 // the result variable decls. 2277 auto resultTypes = llvm::map_range( 2278 results, [&](const auto *result) { return result->getType(); }); 2279 auto resultNames = llvm::map_range( 2280 results, [&](const auto *result) { return result->getName().getName(); }); 2281 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), 2282 llvm::to_vector(resultNames)); 2283 } 2284 2285 template <typename T> 2286 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( 2287 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, 2288 ArrayRef<ast::VariableDecl *> results, ast::Type resultType, 2289 ast::CompoundStmt *body) { 2290 if (!body->getChildren().empty()) { 2291 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) { 2292 ast::Expr *resultExpr = retStmt->getResultExpr(); 2293 2294 // Process the result of the decl. If no explicit signature results 2295 // were provided, check for return type inference. Otherwise, check that 2296 // the return expression can be converted to the expected type. 2297 if (results.empty()) 2298 resultType = resultExpr->getType(); 2299 else if (failed(convertExpressionTo(resultExpr, resultType))) 2300 return failure(); 2301 else 2302 retStmt->setResultExpr(resultExpr); 2303 } 2304 } 2305 return T::createPDLL(ctx, name, arguments, results, body, resultType); 2306 } 2307 2308 FailureOr<ast::VariableDecl *> 2309 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, 2310 ArrayRef<ast::ConstraintRef> constraints) { 2311 // The type of the variable, which is expected to be inferred by either a 2312 // constraint or an initializer expression. 2313 ast::Type type; 2314 if (failed(validateVariableConstraints(constraints, type))) 2315 return failure(); 2316 2317 if (initializer) { 2318 // Update the variable type based on the initializer, or try to convert the 2319 // initializer to the existing type. 2320 if (!type) 2321 type = initializer->getType(); 2322 else if (ast::Type mergedType = type.refineWith(initializer->getType())) 2323 type = mergedType; 2324 else if (failed(convertExpressionTo(initializer, type))) 2325 return failure(); 2326 2327 // Otherwise, if there is no initializer check that the type has already 2328 // been resolved from the constraint list. 2329 } else if (!type) { 2330 return emitErrorAndNote( 2331 loc, "unable to infer type for variable `" + name + "`", loc, 2332 "the type of a variable must be inferable from the constraint " 2333 "list or the initializer"); 2334 } 2335 2336 // Constraint types cannot be used when defining variables. 2337 if (type.isa<ast::ConstraintType, ast::RewriteType>()) { 2338 return emitError( 2339 loc, llvm::formatv("unable to define variable of `{0}` type", type)); 2340 } 2341 2342 // Try to define a variable with the given name. 2343 FailureOr<ast::VariableDecl *> varDecl = 2344 defineVariableDecl(name, loc, type, initializer, constraints); 2345 if (failed(varDecl)) 2346 return failure(); 2347 2348 return *varDecl; 2349 } 2350 2351 FailureOr<ast::VariableDecl *> 2352 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, 2353 const ast::ConstraintRef &constraint) { 2354 // Constraint arguments may apply more complex constraints via the arguments. 2355 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; 2356 ast::Type argType; 2357 if (failed(validateVariableConstraint(constraint, argType, 2358 allowNonCoreConstraints))) 2359 return failure(); 2360 return defineVariableDecl(name, loc, argType, constraint); 2361 } 2362 2363 LogicalResult 2364 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, 2365 ast::Type &inferredType) { 2366 for (const ast::ConstraintRef &ref : constraints) 2367 if (failed(validateVariableConstraint(ref, inferredType))) 2368 return failure(); 2369 return success(); 2370 } 2371 2372 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, 2373 ast::Type &inferredType, 2374 bool allowNonCoreConstraints) { 2375 ast::Type constraintType; 2376 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) { 2377 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2378 if (failed(validateTypeConstraintExpr(typeExpr))) 2379 return failure(); 2380 } 2381 constraintType = ast::AttributeType::get(ctx); 2382 } else if (const auto *cst = 2383 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) { 2384 constraintType = ast::OperationType::get(ctx, cst->getName()); 2385 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) { 2386 constraintType = typeTy; 2387 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) { 2388 constraintType = typeRangeTy; 2389 } else if (const auto *cst = 2390 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) { 2391 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2392 if (failed(validateTypeConstraintExpr(typeExpr))) 2393 return failure(); 2394 } 2395 constraintType = valueTy; 2396 } else if (const auto *cst = 2397 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) { 2398 if (const ast::Expr *typeExpr = cst->getTypeExpr()) { 2399 if (failed(validateTypeRangeConstraintExpr(typeExpr))) 2400 return failure(); 2401 } 2402 constraintType = valueRangeTy; 2403 } else if (const auto *cst = 2404 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) { 2405 if (!allowNonCoreConstraints) { 2406 return emitError(ref.referenceLoc, 2407 "`Rewrite` arguments and results are only permitted to " 2408 "use core constraints, such as `Attr`, `Op`, `Type`, " 2409 "`TypeRange`, `Value`, `ValueRange`"); 2410 } 2411 2412 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); 2413 if (inputs.size() != 1) { 2414 return emitErrorAndNote(ref.referenceLoc, 2415 "`Constraint`s applied via a variable constraint " 2416 "list must take a single input, but got " + 2417 Twine(inputs.size()), 2418 cst->getLoc(), 2419 "see definition of constraint here"); 2420 } 2421 constraintType = inputs.front()->getType(); 2422 } else { 2423 llvm_unreachable("unknown constraint type"); 2424 } 2425 2426 // Check that the constraint type is compatible with the current inferred 2427 // type. 2428 if (!inferredType) { 2429 inferredType = constraintType; 2430 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { 2431 inferredType = mergedTy; 2432 } else { 2433 return emitError(ref.referenceLoc, 2434 llvm::formatv("constraint type `{0}` is incompatible " 2435 "with the previously inferred type `{1}`", 2436 constraintType, inferredType)); 2437 } 2438 return success(); 2439 } 2440 2441 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { 2442 ast::Type typeExprType = typeExpr->getType(); 2443 if (typeExprType != typeTy) { 2444 return emitError(typeExpr->getLoc(), 2445 "expected expression of `Type` in type constraint"); 2446 } 2447 return success(); 2448 } 2449 2450 LogicalResult 2451 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { 2452 ast::Type typeExprType = typeExpr->getType(); 2453 if (typeExprType != typeRangeTy) { 2454 return emitError(typeExpr->getLoc(), 2455 "expected expression of `TypeRange` in type constraint"); 2456 } 2457 return success(); 2458 } 2459 2460 //===----------------------------------------------------------------------===// 2461 // Exprs 2462 2463 FailureOr<ast::CallExpr *> 2464 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, 2465 MutableArrayRef<ast::Expr *> arguments) { 2466 ast::Type parentType = parentExpr->getType(); 2467 2468 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); 2469 if (!callableDecl) { 2470 return emitError(loc, 2471 llvm::formatv("expected a reference to a callable " 2472 "`Constraint` or `Rewrite`, but got: `{0}`", 2473 parentType)); 2474 } 2475 if (parserContext == ParserContext::Rewrite) { 2476 if (isa<ast::UserConstraintDecl>(callableDecl)) 2477 return emitError( 2478 loc, "unable to invoke `Constraint` within a rewrite section"); 2479 } else if (isa<ast::UserRewriteDecl>(callableDecl)) { 2480 return emitError(loc, "unable to invoke `Rewrite` within a match section"); 2481 } 2482 2483 // Verify the arguments of the call. 2484 /// Handle size mismatch. 2485 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); 2486 if (callArgs.size() != arguments.size()) { 2487 return emitErrorAndNote( 2488 loc, 2489 llvm::formatv("invalid number of arguments for {0} call; expected " 2490 "{1}, but got {2}", 2491 callableDecl->getCallableType(), callArgs.size(), 2492 arguments.size()), 2493 callableDecl->getLoc(), 2494 llvm::formatv("see the definition of {0} here", 2495 callableDecl->getName()->getName())); 2496 } 2497 2498 /// Handle argument type mismatch. 2499 auto attachDiagFn = [&](ast::Diagnostic &diag) { 2500 diag.attachNote(llvm::formatv("see the definition of `{0}` here", 2501 callableDecl->getName()->getName()), 2502 callableDecl->getLoc()); 2503 }; 2504 for (auto it : llvm::zip(callArgs, arguments)) { 2505 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), 2506 attachDiagFn))) 2507 return failure(); 2508 } 2509 2510 return ast::CallExpr::create(ctx, loc, parentExpr, arguments, 2511 callableDecl->getResultType()); 2512 } 2513 2514 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, 2515 ast::Decl *decl) { 2516 // Check the type of decl being referenced. 2517 ast::Type declType; 2518 if (isa<ast::ConstraintDecl>(decl)) 2519 declType = ast::ConstraintType::get(ctx); 2520 else if (isa<ast::UserRewriteDecl>(decl)) 2521 declType = ast::RewriteType::get(ctx); 2522 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) 2523 declType = varDecl->getType(); 2524 else 2525 return emitError(loc, "invalid reference to `" + 2526 decl->getName()->getName() + "`"); 2527 2528 return ast::DeclRefExpr::create(ctx, loc, decl, declType); 2529 } 2530 2531 FailureOr<ast::DeclRefExpr *> 2532 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, 2533 ArrayRef<ast::ConstraintRef> constraints) { 2534 FailureOr<ast::VariableDecl *> decl = 2535 defineVariableDecl(name, loc, type, constraints); 2536 if (failed(decl)) 2537 return failure(); 2538 return ast::DeclRefExpr::create(ctx, loc, *decl, type); 2539 } 2540 2541 FailureOr<ast::MemberAccessExpr *> 2542 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, 2543 SMRange loc) { 2544 // Validate the member name for the given parent expression. 2545 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); 2546 if (failed(memberType)) 2547 return failure(); 2548 2549 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); 2550 } 2551 2552 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, 2553 StringRef name, SMRange loc) { 2554 ast::Type parentType = parentExpr->getType(); 2555 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 2556 if (name == ast::AllResultsMemberAccessExpr::getMemberName()) 2557 return valueRangeTy; 2558 2559 // Verify member access based on the operation type. 2560 if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { 2561 auto results = odsOp->getResults(); 2562 2563 // Handle indexed results. 2564 unsigned index = 0; 2565 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2566 index < results.size()) { 2567 return results[index].isVariadic() ? valueRangeTy : valueTy; 2568 } 2569 2570 // Handle named results. 2571 const auto *it = llvm::find_if(results, [&](const auto &result) { 2572 return result.getName() == name; 2573 }); 2574 if (it != results.end()) 2575 return it->isVariadic() ? valueRangeTy : valueTy; 2576 } 2577 2578 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 2579 // Handle indexed results. 2580 unsigned index = 0; 2581 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && 2582 index < tupleType.size()) { 2583 return tupleType.getElementTypes()[index]; 2584 } 2585 2586 // Handle named results. 2587 auto elementNames = tupleType.getElementNames(); 2588 const auto *it = llvm::find(elementNames, name); 2589 if (it != elementNames.end()) 2590 return tupleType.getElementTypes()[it - elementNames.begin()]; 2591 } 2592 return emitError( 2593 loc, 2594 llvm::formatv("invalid member access `{0}` on expression of type `{1}`", 2595 name, parentType)); 2596 } 2597 2598 FailureOr<ast::OperationExpr *> Parser::createOperationExpr( 2599 SMRange loc, const ast::OpNameDecl *name, 2600 MutableArrayRef<ast::Expr *> operands, 2601 MutableArrayRef<ast::NamedAttributeDecl *> attributes, 2602 MutableArrayRef<ast::Expr *> results) { 2603 Optional<StringRef> opNameRef = name->getName(); 2604 const ods::Operation *odsOp = lookupODSOperation(opNameRef); 2605 2606 // Verify the inputs operands. 2607 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) 2608 return failure(); 2609 2610 // Verify the attribute list. 2611 for (ast::NamedAttributeDecl *attr : attributes) { 2612 // Check for an attribute type, or a type awaiting resolution. 2613 ast::Type attrType = attr->getValue()->getType(); 2614 if (!attrType.isa<ast::AttributeType>()) { 2615 return emitError( 2616 attr->getValue()->getLoc(), 2617 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); 2618 } 2619 } 2620 2621 // Verify the result types. 2622 if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) 2623 return failure(); 2624 2625 return ast::OperationExpr::create(ctx, loc, name, operands, results, 2626 attributes); 2627 } 2628 2629 LogicalResult 2630 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, 2631 const ods::Operation *odsOp, 2632 MutableArrayRef<ast::Expr *> operands) { 2633 return validateOperationOperandsOrResults( 2634 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2635 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, 2636 valueRangeTy); 2637 } 2638 2639 LogicalResult 2640 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, 2641 const ods::Operation *odsOp, 2642 MutableArrayRef<ast::Expr *> results) { 2643 return validateOperationOperandsOrResults( 2644 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name, 2645 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); 2646 } 2647 2648 LogicalResult Parser::validateOperationOperandsOrResults( 2649 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc, 2650 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, 2651 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, 2652 ast::Type rangeTy) { 2653 // All operation types accept a single range parameter. 2654 if (values.size() == 1) { 2655 if (failed(convertExpressionTo(values[0], rangeTy))) 2656 return failure(); 2657 return success(); 2658 } 2659 2660 /// If the operation has ODS information, we can more accurately verify the 2661 /// values. 2662 if (odsOpLoc) { 2663 if (odsValues.size() != values.size()) { 2664 return emitErrorAndNote( 2665 loc, 2666 llvm::formatv("invalid number of {0} groups for `{1}`; expected " 2667 "{2}, but got {3}", 2668 groupName, *name, odsValues.size(), values.size()), 2669 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); 2670 } 2671 auto diagFn = [&](ast::Diagnostic &diag) { 2672 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), 2673 *odsOpLoc); 2674 }; 2675 for (unsigned i = 0, e = values.size(); i < e; ++i) { 2676 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; 2677 if (failed(convertExpressionTo(values[i], expectedType, diagFn))) 2678 return failure(); 2679 } 2680 return success(); 2681 } 2682 2683 // Otherwise, accept the value groups as they have been defined and just 2684 // ensure they are one of the expected types. 2685 for (ast::Expr *&valueExpr : values) { 2686 ast::Type valueExprType = valueExpr->getType(); 2687 2688 // Check if this is one of the expected types. 2689 if (valueExprType == rangeTy || valueExprType == singleTy) 2690 continue; 2691 2692 // If the operand is an Operation, allow converting to a Value or 2693 // ValueRange. This situations arises quite often with nested operation 2694 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` 2695 if (singleTy == valueTy) { 2696 if (valueExprType.isa<ast::OperationType>()) { 2697 valueExpr = convertOpToValue(valueExpr); 2698 continue; 2699 } 2700 } 2701 2702 return emitError( 2703 valueExpr->getLoc(), 2704 llvm::formatv( 2705 "expected `{0}` or `{1}` convertible expression, but got `{2}`", 2706 singleTy, rangeTy, valueExprType)); 2707 } 2708 return success(); 2709 } 2710 2711 FailureOr<ast::TupleExpr *> 2712 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, 2713 ArrayRef<StringRef> elementNames) { 2714 for (const ast::Expr *element : elements) { 2715 ast::Type eleTy = element->getType(); 2716 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { 2717 return emitError( 2718 element->getLoc(), 2719 llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); 2720 } 2721 } 2722 return ast::TupleExpr::create(ctx, loc, elements, elementNames); 2723 } 2724 2725 //===----------------------------------------------------------------------===// 2726 // Stmts 2727 2728 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, 2729 ast::Expr *rootOp) { 2730 // Check that root is an Operation. 2731 ast::Type rootType = rootOp->getType(); 2732 if (!rootType.isa<ast::OperationType>()) 2733 return emitError(rootOp->getLoc(), "expected `Op` expression"); 2734 2735 return ast::EraseStmt::create(ctx, loc, rootOp); 2736 } 2737 2738 FailureOr<ast::ReplaceStmt *> 2739 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, 2740 MutableArrayRef<ast::Expr *> replValues) { 2741 // Check that root is an Operation. 2742 ast::Type rootType = rootOp->getType(); 2743 if (!rootType.isa<ast::OperationType>()) { 2744 return emitError( 2745 rootOp->getLoc(), 2746 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2747 } 2748 2749 // If there are multiple replacement values, we implicitly convert any Op 2750 // expressions to the value form. 2751 bool shouldConvertOpToValues = replValues.size() > 1; 2752 for (ast::Expr *&replExpr : replValues) { 2753 ast::Type replType = replExpr->getType(); 2754 2755 // Check that replExpr is an Operation, Value, or ValueRange. 2756 if (replType.isa<ast::OperationType>()) { 2757 if (shouldConvertOpToValues) 2758 replExpr = convertOpToValue(replExpr); 2759 continue; 2760 } 2761 2762 if (replType != valueTy && replType != valueRangeTy) { 2763 return emitError(replExpr->getLoc(), 2764 llvm::formatv("expected `Op`, `Value` or `ValueRange` " 2765 "expression, but got `{0}`", 2766 replType)); 2767 } 2768 } 2769 2770 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); 2771 } 2772 2773 FailureOr<ast::RewriteStmt *> 2774 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, 2775 ast::CompoundStmt *rewriteBody) { 2776 // Check that root is an Operation. 2777 ast::Type rootType = rootOp->getType(); 2778 if (!rootType.isa<ast::OperationType>()) { 2779 return emitError( 2780 rootOp->getLoc(), 2781 llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); 2782 } 2783 2784 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); 2785 } 2786 2787 //===----------------------------------------------------------------------===// 2788 // Parser 2789 //===----------------------------------------------------------------------===// 2790 2791 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx, 2792 llvm::SourceMgr &sourceMgr) { 2793 Parser parser(ctx, sourceMgr); 2794 return parser.parseModule(); 2795 } 2796