1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 class Parser {
32 public:
33   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
34       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
35         curToken(lexer.lexToken()), curDeclScope(nullptr),
36         valueTy(ast::ValueType::get(ctx)),
37         valueRangeTy(ast::ValueRangeType::get(ctx)),
38         typeTy(ast::TypeType::get(ctx)),
39         typeRangeTy(ast::TypeRangeType::get(ctx)) {}
40 
41   /// Try to parse a new module. Returns nullptr in the case of failure.
42   FailureOr<ast::Module *> parseModule();
43 
44 private:
45   /// The current context of the parser. It allows for the parser to know a bit
46   /// about the construct it is nested within during parsing. This is used
47   /// specifically to provide additional verification during parsing, e.g. to
48   /// prevent using rewrites within a match context, matcher constraints within
49   /// a rewrite section, etc.
50   enum class ParserContext {
51     /// The parser is in the global context.
52     Global,
53     /// The parser is currently within the matcher portion of a Pattern, which
54     /// is allows a terminal operation rewrite statement but no other rewrite
55     /// transformations.
56     PatternMatch,
57     /// The parser is currently within a Rewrite, which disallows calls to
58     /// constraints, requires operation expressions to have names, etc.
59     Rewrite,
60   };
61 
62   //===--------------------------------------------------------------------===//
63   // Parsing
64   //===--------------------------------------------------------------------===//
65 
66   /// Push a new decl scope onto the lexer.
67   ast::DeclScope *pushDeclScope() {
68     ast::DeclScope *newScope =
69         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
70     return (curDeclScope = newScope);
71   }
72   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
73 
74   /// Pop the last decl scope from the lexer.
75   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
76 
77   /// Parse the body of an AST module.
78   LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
79 
80   /// Try to convert the given expression to `type`. Returns failure and emits
81   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
82   /// invoked to attach notes to the emitted error diagnostic. On success,
83   /// `expr` is updated to the expression used to convert to `type`.
84   LogicalResult convertExpressionTo(
85       ast::Expr *&expr, ast::Type type,
86       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
87 
88   /// Given an operation expression, convert it to a Value or ValueRange
89   /// typed expression.
90   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
91 
92   //===--------------------------------------------------------------------===//
93   // Directives
94 
95   LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
96   LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
97 
98   //===--------------------------------------------------------------------===//
99   // Decls
100 
101   /// This structure contains the set of pattern metadata that may be parsed.
102   struct ParsedPatternMetadata {
103     Optional<uint16_t> benefit;
104     bool hasBoundedRecursion = false;
105   };
106 
107   FailureOr<ast::Decl *> parseTopLevelDecl();
108   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
109   FailureOr<ast::Decl *> parsePatternDecl();
110   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
111 
112   /// Check to see if a decl has already been defined with the given name, if
113   /// one has emit and error and return failure. Returns success otherwise.
114   LogicalResult checkDefineNamedDecl(const ast::Name &name);
115 
116   /// Try to define a variable decl with the given components, returns the
117   /// variable on success.
118   FailureOr<ast::VariableDecl *>
119   defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
120                      ast::Expr *initExpr,
121                      ArrayRef<ast::ConstraintRef> constraints);
122   FailureOr<ast::VariableDecl *>
123   defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
124                      ArrayRef<ast::ConstraintRef> constraints);
125 
126   /// Parse the constraint reference list for a variable decl.
127   LogicalResult parseVariableDeclConstraintList(
128       SmallVectorImpl<ast::ConstraintRef> &constraints);
129 
130   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
131   FailureOr<ast::Expr *> parseTypeConstraintExpr();
132 
133   /// Try to parse a single reference to a constraint. `typeConstraint` is the
134   /// location of a previously parsed type constraint for the entity that will
135   /// be constrained by the parsed constraint. `existingConstraints` are any
136   /// existing constraints that have already been parsed for the same entity
137   /// that will be constrained by this constraint.
138   FailureOr<ast::ConstraintRef>
139   parseConstraint(Optional<llvm::SMRange> &typeConstraint,
140                   ArrayRef<ast::ConstraintRef> existingConstraints);
141 
142   //===--------------------------------------------------------------------===//
143   // Exprs
144 
145   FailureOr<ast::Expr *> parseExpr();
146 
147   /// Identifier expressions.
148   FailureOr<ast::Expr *> parseAttributeExpr();
149   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
150   FailureOr<ast::Expr *> parseIdentifierExpr();
151   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
152   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
153   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
154   FailureOr<ast::Expr *> parseOperationExpr();
155   FailureOr<ast::Expr *> parseTupleExpr();
156   FailureOr<ast::Expr *> parseTypeExpr();
157   FailureOr<ast::Expr *> parseUnderscoreExpr();
158 
159   //===--------------------------------------------------------------------===//
160   // Stmts
161 
162   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
163   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
164   FailureOr<ast::EraseStmt *> parseEraseStmt();
165   FailureOr<ast::LetStmt *> parseLetStmt();
166   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
167   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
168 
169   //===--------------------------------------------------------------------===//
170   // Creation+Analysis
171   //===--------------------------------------------------------------------===//
172 
173   //===--------------------------------------------------------------------===//
174   // Decls
175 
176   /// Try to create a pattern decl with the given components, returning the
177   /// Pattern on success.
178   FailureOr<ast::PatternDecl *>
179   createPatternDecl(llvm::SMRange loc, const ast::Name *name,
180                     const ParsedPatternMetadata &metadata,
181                     ast::CompoundStmt *body);
182 
183   /// Try to create a variable decl with the given components, returning the
184   /// Variable on success.
185   FailureOr<ast::VariableDecl *>
186   createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer,
187                      ArrayRef<ast::ConstraintRef> constraints);
188 
189   /// Validate the constraints used to constraint a variable decl.
190   /// `inferredType` is the type of the variable inferred by the constraints
191   /// within the list, and is updated to the most refined type as determined by
192   /// the constraints. Returns success if the constraint list is valid, failure
193   /// otherwise.
194   LogicalResult
195   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
196                               ast::Type &inferredType);
197   /// Validate a single reference to a constraint. `inferredType` contains the
198   /// currently inferred variabled type and is refined within the type defined
199   /// by the constraint. Returns success if the constraint is valid, failure
200   /// otherwise.
201   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
202                                            ast::Type &inferredType);
203   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
204   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
205 
206   //===--------------------------------------------------------------------===//
207   // Exprs
208 
209   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc,
210                                                   ast::Decl *decl);
211   FailureOr<ast::DeclRefExpr *>
212   createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc,
213                            ArrayRef<ast::ConstraintRef> constraints);
214   FailureOr<ast::MemberAccessExpr *>
215   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
216                          llvm::SMRange loc);
217 
218   /// Validate the member access `name` into the given parent expression. On
219   /// success, this also returns the type of the member accessed.
220   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
221                                             StringRef name, llvm::SMRange loc);
222   FailureOr<ast::OperationExpr *>
223   createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
224                       MutableArrayRef<ast::Expr *> operands,
225                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
226                       MutableArrayRef<ast::Expr *> results);
227   LogicalResult
228   validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
229                             MutableArrayRef<ast::Expr *> operands);
230   LogicalResult validateOperationResults(llvm::SMRange loc,
231                                          Optional<StringRef> name,
232                                          MutableArrayRef<ast::Expr *> results);
233   LogicalResult
234   validateOperationOperandsOrResults(llvm::SMRange loc,
235                                      Optional<StringRef> name,
236                                      MutableArrayRef<ast::Expr *> values,
237                                      ast::Type singleTy, ast::Type rangeTy);
238   FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc,
239                                               ArrayRef<ast::Expr *> elements,
240                                               ArrayRef<StringRef> elementNames);
241 
242   //===--------------------------------------------------------------------===//
243   // Stmts
244 
245   FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
246                                               ast::Expr *rootOp);
247   FailureOr<ast::ReplaceStmt *>
248   createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
249                     MutableArrayRef<ast::Expr *> replValues);
250   FailureOr<ast::RewriteStmt *>
251   createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
252                     ast::CompoundStmt *rewriteBody);
253 
254   //===--------------------------------------------------------------------===//
255   // Lexer Utilities
256   //===--------------------------------------------------------------------===//
257 
258   /// If the current token has the specified kind, consume it and return true.
259   /// If not, return false.
260   bool consumeIf(Token::Kind kind) {
261     if (curToken.isNot(kind))
262       return false;
263     consumeToken(kind);
264     return true;
265   }
266 
267   /// Advance the current lexer onto the next token.
268   void consumeToken() {
269     assert(curToken.isNot(Token::eof, Token::error) &&
270            "shouldn't advance past EOF or errors");
271     curToken = lexer.lexToken();
272   }
273 
274   /// Advance the current lexer onto the next token, asserting what the expected
275   /// current token is. This is preferred to the above method because it leads
276   /// to more self-documenting code with better checking.
277   void consumeToken(Token::Kind kind) {
278     assert(curToken.is(kind) && "consumed an unexpected token");
279     consumeToken();
280   }
281 
282   /// Reset the lexer to the location at the given position.
283   void resetToken(llvm::SMRange tokLoc) {
284     lexer.resetPointer(tokLoc.Start.getPointer());
285     curToken = lexer.lexToken();
286   }
287 
288   /// Consume the specified token if present and return success. On failure,
289   /// output a diagnostic and return failure.
290   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
291     if (curToken.getKind() != kind)
292       return emitError(curToken.getLoc(), msg);
293     consumeToken();
294     return success();
295   }
296   LogicalResult emitError(llvm::SMRange loc, const Twine &msg) {
297     lexer.emitError(loc, msg);
298     return failure();
299   }
300   LogicalResult emitError(const Twine &msg) {
301     return emitError(curToken.getLoc(), msg);
302   }
303   LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
304                                  llvm::SMRange noteLoc, const Twine &note) {
305     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
306     return failure();
307   }
308 
309   //===--------------------------------------------------------------------===//
310   // Fields
311   //===--------------------------------------------------------------------===//
312 
313   /// The owning AST context.
314   ast::Context &ctx;
315 
316   /// The lexer of this parser.
317   Lexer lexer;
318 
319   /// The current token within the lexer.
320   Token curToken;
321 
322   /// The most recently defined decl scope.
323   ast::DeclScope *curDeclScope;
324   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
325 
326   /// The current context of the parser.
327   ParserContext parserContext = ParserContext::Global;
328 
329   /// Cached types to simplify verification and expression creation.
330   ast::Type valueTy, valueRangeTy;
331   ast::Type typeTy, typeRangeTy;
332 };
333 } // namespace
334 
335 FailureOr<ast::Module *> Parser::parseModule() {
336   llvm::SMLoc moduleLoc = curToken.getStartLoc();
337   pushDeclScope();
338 
339   // Parse the top-level decls of the module.
340   SmallVector<ast::Decl *> decls;
341   if (failed(parseModuleBody(decls)))
342     return popDeclScope(), failure();
343 
344   popDeclScope();
345   return ast::Module::create(ctx, moduleLoc, decls);
346 }
347 
348 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
349   while (curToken.isNot(Token::eof)) {
350     if (curToken.is(Token::directive)) {
351       if (failed(parseDirective(decls)))
352         return failure();
353       continue;
354     }
355 
356     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
357     if (failed(decl))
358       return failure();
359     decls.push_back(*decl);
360   }
361   return success();
362 }
363 
364 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
365   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
366                                                  valueRangeTy);
367 }
368 
369 LogicalResult Parser::convertExpressionTo(
370     ast::Expr *&expr, ast::Type type,
371     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
372   ast::Type exprType = expr->getType();
373   if (exprType == type)
374     return success();
375 
376   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
377     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
378         expr->getLoc(), llvm::formatv("unable to convert expression of type "
379                                       "`{0}` to the expected type of "
380                                       "`{1}`",
381                                       exprType, type));
382     if (noteAttachFn)
383       noteAttachFn(*diag);
384     return diag;
385   };
386 
387   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
388     // Two operation types are compatible if they have the same name, or if the
389     // expected type is more general.
390     if (auto opType = type.dyn_cast<ast::OperationType>()) {
391       if (opType.getName())
392         return emitConvertError();
393       return success();
394     }
395 
396     // An operation can always convert to a ValueRange.
397     if (type == valueRangeTy) {
398       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
399                                                      valueRangeTy);
400       return success();
401     }
402 
403     // Allow conversion to a single value by constraining the result range.
404     if (type == valueTy) {
405       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
406                                                      valueTy);
407       return success();
408     }
409     return emitConvertError();
410   }
411 
412   // FIXME: Decide how to allow/support converting a single result to multiple,
413   // and multiple to a single result. For now, we just allow Single->Range,
414   // but this isn't something really supported in the PDL dialect. We should
415   // figure out some way to support both.
416   if ((exprType == valueTy || exprType == valueRangeTy) &&
417       (type == valueTy || type == valueRangeTy))
418     return success();
419   if ((exprType == typeTy || exprType == typeRangeTy) &&
420       (type == typeTy || type == typeRangeTy))
421     return success();
422 
423   // Handle tuple types.
424   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
425     auto tupleType = type.dyn_cast<ast::TupleType>();
426     if (!tupleType || tupleType.size() != exprTupleType.size())
427       return emitConvertError();
428 
429     // Build a new tuple expression using each of the elements of the current
430     // tuple.
431     SmallVector<ast::Expr *> newExprs;
432     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
433       newExprs.push_back(ast::MemberAccessExpr::create(
434           ctx, expr->getLoc(), expr, llvm::to_string(i),
435           exprTupleType.getElementTypes()[i]));
436 
437       auto diagFn = [&](ast::Diagnostic &diag) {
438         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
439                                       i, exprTupleType));
440         if (noteAttachFn)
441           noteAttachFn(diag);
442       };
443       if (failed(convertExpressionTo(newExprs.back(),
444                                      tupleType.getElementTypes()[i], diagFn)))
445         return failure();
446     }
447     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
448                                   tupleType.getElementNames());
449     return success();
450   }
451 
452   return emitConvertError();
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // Directives
457 
458 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
459   StringRef directive = curToken.getSpelling();
460   if (directive == "#include")
461     return parseInclude(decls);
462 
463   return emitError("unknown directive `" + directive + "`");
464 }
465 
466 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
467   llvm::SMRange loc = curToken.getLoc();
468   consumeToken(Token::directive);
469 
470   // Parse the file being included.
471   if (!curToken.isString())
472     return emitError(loc,
473                      "expected string file name after `include` directive");
474   llvm::SMRange fileLoc = curToken.getLoc();
475   std::string filenameStr = curToken.getStringValue();
476   StringRef filename = filenameStr;
477   consumeToken();
478 
479   // Check the type of include. If ending with `.pdll`, this is another pdl file
480   // to be parsed along with the current module.
481   if (filename.endswith(".pdll")) {
482     if (failed(lexer.pushInclude(filename)))
483       return emitError(fileLoc,
484                        "unable to open include file `" + filename + "`");
485 
486     // If we added the include successfully, parse it into the current module.
487     // Make sure to save the current token so that we can restore it when we
488     // finish parsing the nested file.
489     Token oldToken = curToken;
490     curToken = lexer.lexToken();
491     LogicalResult result = parseModuleBody(decls);
492     curToken = oldToken;
493     return result;
494   }
495 
496   return emitError(fileLoc, "expected include filename to end with `.pdll`");
497 }
498 
499 //===----------------------------------------------------------------------===//
500 // Decls
501 
502 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
503   FailureOr<ast::Decl *> decl;
504   switch (curToken.getKind()) {
505   case Token::kw_Pattern:
506     decl = parsePatternDecl();
507     break;
508   default:
509     return emitError("expected top-level declaration, such as a `Pattern`");
510   }
511   if (failed(decl))
512     return failure();
513 
514   // If the decl has a name, add it to the current scope.
515   if (const ast::Name *name = (*decl)->getName()) {
516     if (failed(checkDefineNamedDecl(*name)))
517       return failure();
518     curDeclScope->add(*decl);
519   }
520   return decl;
521 }
522 
523 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
524   std::string attrNameStr;
525   if (curToken.isString())
526     attrNameStr = curToken.getStringValue();
527   else if (curToken.is(Token::identifier) || curToken.isKeyword())
528     attrNameStr = curToken.getSpelling().str();
529   else
530     return emitError("expected identifier or string attribute name");
531   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
532   consumeToken();
533 
534   // Check for a value of the attribute.
535   ast::Expr *attrValue = nullptr;
536   if (consumeIf(Token::equal)) {
537     FailureOr<ast::Expr *> attrExpr = parseExpr();
538     if (failed(attrExpr))
539       return failure();
540     attrValue = *attrExpr;
541   } else {
542     // If there isn't a concrete value, create an expression representing a
543     // UnitAttr.
544     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
545   }
546 
547   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
548 }
549 
550 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
551   llvm::SMRange loc = curToken.getLoc();
552   consumeToken(Token::kw_Pattern);
553   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
554                                               ParserContext::PatternMatch);
555 
556   // Check for an optional identifier for the pattern name.
557   const ast::Name *name = nullptr;
558   if (curToken.is(Token::identifier)) {
559     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
560     consumeToken(Token::identifier);
561   }
562 
563   // Parse any pattern metadata.
564   ParsedPatternMetadata metadata;
565   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
566     return failure();
567 
568   // Parse the pattern body.
569   ast::CompoundStmt *body;
570 
571   if (curToken.isNot(Token::l_brace))
572     return emitError("expected `{` to start pattern body");
573   FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
574   if (failed(bodyResult))
575     return failure();
576   body = *bodyResult;
577 
578   // Verify the body of the pattern.
579   auto bodyIt = body->begin(), bodyE = body->end();
580   for (; bodyIt != bodyE; ++bodyIt) {
581     // Break when we've found the rewrite statement.
582     if (isa<ast::OpRewriteStmt>(*bodyIt))
583       break;
584   }
585   if (bodyIt == bodyE) {
586     return emitError(loc,
587                      "expected Pattern body to terminate with an operation "
588                      "rewrite statement, such as `erase`");
589   }
590   if (std::next(bodyIt) != bodyE) {
591     return emitError((*std::next(bodyIt))->getLoc(),
592                      "Pattern body was terminated by an operation "
593                      "rewrite statement, but found trailing statements");
594   }
595 
596   return createPatternDecl(loc, name, metadata, body);
597 }
598 
599 LogicalResult
600 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
601   Optional<llvm::SMRange> benefitLoc;
602   Optional<llvm::SMRange> hasBoundedRecursionLoc;
603 
604   do {
605     if (curToken.isNot(Token::identifier))
606       return emitError("expected pattern metadata identifier");
607     StringRef metadataStr = curToken.getSpelling();
608     llvm::SMRange metadataLoc = curToken.getLoc();
609     consumeToken(Token::identifier);
610 
611     // Parse the benefit metadata: benefit(<integer-value>)
612     if (metadataStr == "benefit") {
613       if (benefitLoc) {
614         return emitErrorAndNote(metadataLoc,
615                                 "pattern benefit has already been specified",
616                                 *benefitLoc, "see previous definition here");
617       }
618       if (failed(parseToken(Token::l_paren,
619                             "expected `(` before pattern benefit")))
620         return failure();
621 
622       uint16_t benefitValue = 0;
623       if (curToken.isNot(Token::integer))
624         return emitError("expected integral pattern benefit");
625       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
626         return emitError(
627             "expected pattern benefit to fit within a 16-bit integer");
628       consumeToken(Token::integer);
629 
630       metadata.benefit = benefitValue;
631       benefitLoc = metadataLoc;
632 
633       if (failed(
634               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
635         return failure();
636       continue;
637     }
638 
639     // Parse the bounded recursion metadata: recursion
640     if (metadataStr == "recursion") {
641       if (hasBoundedRecursionLoc) {
642         return emitErrorAndNote(
643             metadataLoc,
644             "pattern recursion metadata has already been specified",
645             *hasBoundedRecursionLoc, "see previous definition here");
646       }
647       metadata.hasBoundedRecursion = true;
648       hasBoundedRecursionLoc = metadataLoc;
649       continue;
650     }
651 
652     return emitError(metadataLoc, "unknown pattern metadata");
653   } while (consumeIf(Token::comma));
654 
655   return success();
656 }
657 
658 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
659   consumeToken(Token::less);
660 
661   FailureOr<ast::Expr *> typeExpr = parseExpr();
662   if (failed(typeExpr) ||
663       failed(parseToken(Token::greater,
664                         "expected `>` after variable type constraint")))
665     return failure();
666   return typeExpr;
667 }
668 
669 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
670   assert(curDeclScope && "defining decl outside of a decl scope");
671   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
672     return emitErrorAndNote(
673         name.getLoc(), "`" + name.getName() + "` has already been defined",
674         lastDecl->getName()->getLoc(), "see previous definition here");
675   }
676   return success();
677 }
678 
679 FailureOr<ast::VariableDecl *>
680 Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
681                            ast::Type type, ast::Expr *initExpr,
682                            ArrayRef<ast::ConstraintRef> constraints) {
683   assert(curDeclScope && "defining variable outside of decl scope");
684   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
685 
686   // If the name of the variable indicates a special variable, we don't add it
687   // to the scope. This variable is local to the definition point.
688   if (name.empty() || name == "_") {
689     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
690                                      constraints);
691   }
692   if (failed(checkDefineNamedDecl(nameDecl)))
693     return failure();
694 
695   auto *varDecl =
696       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
697   curDeclScope->add(varDecl);
698   return varDecl;
699 }
700 
701 FailureOr<ast::VariableDecl *>
702 Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
703                            ast::Type type,
704                            ArrayRef<ast::ConstraintRef> constraints) {
705   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
706                             constraints);
707 }
708 
709 LogicalResult Parser::parseVariableDeclConstraintList(
710     SmallVectorImpl<ast::ConstraintRef> &constraints) {
711   Optional<llvm::SMRange> typeConstraint;
712   auto parseSingleConstraint = [&] {
713     FailureOr<ast::ConstraintRef> constraint =
714         parseConstraint(typeConstraint, constraints);
715     if (failed(constraint))
716       return failure();
717     constraints.push_back(*constraint);
718     return success();
719   };
720 
721   // Check to see if this is a single constraint, or a list.
722   if (!consumeIf(Token::l_square))
723     return parseSingleConstraint();
724 
725   do {
726     if (failed(parseSingleConstraint()))
727       return failure();
728   } while (consumeIf(Token::comma));
729   return parseToken(Token::r_square, "expected `]` after constraint list");
730 }
731 
732 FailureOr<ast::ConstraintRef>
733 Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint,
734                         ArrayRef<ast::ConstraintRef> existingConstraints) {
735   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
736     if (typeConstraint)
737       return emitErrorAndNote(
738           curToken.getLoc(),
739           "the type of this variable has already been constrained",
740           *typeConstraint, "see previous constraint location here");
741     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
742     if (failed(constraintExpr))
743       return failure();
744     typeExpr = *constraintExpr;
745     typeConstraint = typeExpr->getLoc();
746     return success();
747   };
748 
749   llvm::SMRange loc = curToken.getLoc();
750   switch (curToken.getKind()) {
751   case Token::kw_Attr: {
752     consumeToken(Token::kw_Attr);
753 
754     // Check for a type constraint.
755     ast::Expr *typeExpr = nullptr;
756     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
757       return failure();
758     return ast::ConstraintRef(
759         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
760   }
761   case Token::kw_Op: {
762     consumeToken(Token::kw_Op);
763 
764     // Parse an optional operation name. If the name isn't provided, this refers
765     // to "any" operation.
766     FailureOr<ast::OpNameDecl *> opName =
767         parseWrappedOperationName(/*allowEmptyName=*/true);
768     if (failed(opName))
769       return failure();
770 
771     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
772                               loc);
773   }
774   case Token::kw_Type:
775     consumeToken(Token::kw_Type);
776     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
777   case Token::kw_TypeRange:
778     consumeToken(Token::kw_TypeRange);
779     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
780                               loc);
781   case Token::kw_Value: {
782     consumeToken(Token::kw_Value);
783 
784     // Check for a type constraint.
785     ast::Expr *typeExpr = nullptr;
786     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
787       return failure();
788 
789     return ast::ConstraintRef(
790         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
791   }
792   case Token::kw_ValueRange: {
793     consumeToken(Token::kw_ValueRange);
794 
795     // Check for a type constraint.
796     ast::Expr *typeExpr = nullptr;
797     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
798       return failure();
799 
800     return ast::ConstraintRef(
801         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
802   }
803   case Token::identifier: {
804     StringRef constraintName = curToken.getSpelling();
805     consumeToken(Token::identifier);
806 
807     // Lookup the referenced constraint.
808     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
809     if (!cstDecl) {
810       return emitError(loc, "unknown reference to constraint `" +
811                                 constraintName + "`");
812     }
813 
814     // Handle a reference to a proper constraint.
815     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
816       return ast::ConstraintRef(cst, loc);
817 
818     return emitErrorAndNote(
819         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
820         "see the definition of `" + constraintName + "` here");
821   }
822   default:
823     break;
824   }
825   return emitError(loc, "expected identifier constraint");
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // Exprs
830 
831 FailureOr<ast::Expr *> Parser::parseExpr() {
832   if (curToken.is(Token::underscore))
833     return parseUnderscoreExpr();
834 
835   // Parse the LHS expression.
836   FailureOr<ast::Expr *> lhsExpr;
837   switch (curToken.getKind()) {
838   case Token::kw_attr:
839     lhsExpr = parseAttributeExpr();
840     break;
841   case Token::identifier:
842     lhsExpr = parseIdentifierExpr();
843     break;
844   case Token::kw_op:
845     lhsExpr = parseOperationExpr();
846     break;
847   case Token::kw_type:
848     lhsExpr = parseTypeExpr();
849     break;
850   case Token::l_paren:
851     lhsExpr = parseTupleExpr();
852     break;
853   default:
854     return emitError("expected expression");
855   }
856   if (failed(lhsExpr))
857     return failure();
858 
859   // Check for an operator expression.
860   while (true) {
861     switch (curToken.getKind()) {
862     case Token::dot:
863       lhsExpr = parseMemberAccessExpr(*lhsExpr);
864       break;
865     default:
866       return lhsExpr;
867     }
868     if (failed(lhsExpr))
869       return failure();
870   }
871 }
872 
873 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
874   llvm::SMRange loc = curToken.getLoc();
875   consumeToken(Token::kw_attr);
876 
877   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
878   // identifier.
879   if (!consumeIf(Token::less)) {
880     resetToken(loc);
881     return parseIdentifierExpr();
882   }
883 
884   if (!curToken.isString())
885     return emitError("expected string literal containing MLIR attribute");
886   std::string attrExpr = curToken.getStringValue();
887   consumeToken();
888 
889   if (failed(
890           parseToken(Token::greater, "expected `>` after attribute literal")))
891     return failure();
892   return ast::AttributeExpr::create(ctx, loc, attrExpr);
893 }
894 
895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
896                                                 llvm::SMRange loc) {
897   ast::Decl *decl = curDeclScope->lookup(name);
898   if (!decl)
899     return emitError(loc, "undefined reference to `" + name + "`");
900 
901   return createDeclRefExpr(loc, decl);
902 }
903 
904 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
905   StringRef name = curToken.getSpelling();
906   llvm::SMRange nameLoc = curToken.getLoc();
907   consumeToken();
908 
909   // Check to see if this is a decl ref expression that defines a variable
910   // inline.
911   if (consumeIf(Token::colon)) {
912     SmallVector<ast::ConstraintRef> constraints;
913     if (failed(parseVariableDeclConstraintList(constraints)))
914       return failure();
915     ast::Type type;
916     if (failed(validateVariableConstraints(constraints, type)))
917       return failure();
918     return createInlineVariableExpr(type, name, nameLoc, constraints);
919   }
920 
921   return parseDeclRefExpr(name, nameLoc);
922 }
923 
924 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
925   llvm::SMRange loc = curToken.getLoc();
926   consumeToken(Token::dot);
927 
928   // Parse the member name.
929   Token memberNameTok = curToken;
930   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
931       !memberNameTok.isKeyword())
932     return emitError(loc, "expected identifier or numeric member name");
933   StringRef memberName = memberNameTok.getSpelling();
934   consumeToken();
935 
936   return createMemberAccessExpr(parentExpr, memberName, loc);
937 }
938 
939 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
940   llvm::SMRange loc = curToken.getLoc();
941 
942   // Handle the case of an no operation name.
943   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
944     if (allowEmptyName)
945       return ast::OpNameDecl::create(ctx, llvm::SMRange());
946     return emitError("expected dialect namespace");
947   }
948   StringRef name = curToken.getSpelling();
949   consumeToken();
950 
951   // Otherwise, this is a literal operation name.
952   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
953     return failure();
954 
955   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
956     return emitError("expected operation name after dialect namespace");
957 
958   name = StringRef(name.data(), name.size() + 1);
959   do {
960     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
961     loc.End = curToken.getEndLoc();
962     consumeToken();
963   } while (curToken.isAny(Token::identifier, Token::dot) ||
964            curToken.isKeyword());
965   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
966 }
967 
968 FailureOr<ast::OpNameDecl *>
969 Parser::parseWrappedOperationName(bool allowEmptyName) {
970   if (!consumeIf(Token::less))
971     return ast::OpNameDecl::create(ctx, llvm::SMRange());
972 
973   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
974   if (failed(opNameDecl))
975     return failure();
976 
977   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
978     return failure();
979   return opNameDecl;
980 }
981 
982 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
983   llvm::SMRange loc = curToken.getLoc();
984   consumeToken(Token::kw_op);
985 
986   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
987   // identifier.
988   if (curToken.isNot(Token::less)) {
989     resetToken(loc);
990     return parseIdentifierExpr();
991   }
992 
993   // Parse the operation name. The name may be elided, in which case the
994   // operation refers to "any" operation(i.e. a difference between `MyOp` and
995   // `Operation*`). Operation names within a rewrite context must be named.
996   bool allowEmptyName = parserContext != ParserContext::Rewrite;
997   FailureOr<ast::OpNameDecl *> opNameDecl =
998       parseWrappedOperationName(allowEmptyName);
999   if (failed(opNameDecl))
1000     return failure();
1001 
1002   // Check for the optional list of operands.
1003   SmallVector<ast::Expr *> operands;
1004   if (consumeIf(Token::l_paren)) {
1005     do {
1006       FailureOr<ast::Expr *> operand = parseExpr();
1007       if (failed(operand))
1008         return failure();
1009       operands.push_back(*operand);
1010     } while (consumeIf(Token::comma));
1011 
1012     if (failed(parseToken(Token::r_paren,
1013                           "expected `)` after operation operand list")))
1014       return failure();
1015   }
1016 
1017   // Check for the optional list of attributes.
1018   SmallVector<ast::NamedAttributeDecl *> attributes;
1019   if (consumeIf(Token::l_brace)) {
1020     do {
1021       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1022       if (failed(decl))
1023         return failure();
1024       attributes.emplace_back(*decl);
1025     } while (consumeIf(Token::comma));
1026 
1027     if (failed(parseToken(Token::r_brace,
1028                           "expected `}` after operation attribute list")))
1029       return failure();
1030   }
1031 
1032   // Check for the optional list of result types.
1033   SmallVector<ast::Expr *> resultTypes;
1034   if (consumeIf(Token::arrow)) {
1035     if (failed(parseToken(Token::l_paren,
1036                           "expected `(` before operation result type list")))
1037       return failure();
1038 
1039     do {
1040       FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1041       if (failed(resultTypeExpr))
1042         return failure();
1043       resultTypes.push_back(*resultTypeExpr);
1044     } while (consumeIf(Token::comma));
1045 
1046     if (failed(parseToken(Token::r_paren,
1047                           "expected `)` after operation result type list")))
1048       return failure();
1049   }
1050 
1051   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1052                              resultTypes);
1053 }
1054 
1055 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
1056   llvm::SMRange loc = curToken.getLoc();
1057   consumeToken(Token::l_paren);
1058 
1059   DenseMap<StringRef, llvm::SMRange> usedNames;
1060   SmallVector<StringRef> elementNames;
1061   SmallVector<ast::Expr *> elements;
1062   if (curToken.isNot(Token::r_paren)) {
1063     do {
1064       // Check for the optional element name assignment before the value.
1065       StringRef elementName;
1066       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1067         Token elementNameTok = curToken;
1068         consumeToken();
1069 
1070         // The element name is only present if followed by an `=`.
1071         if (consumeIf(Token::equal)) {
1072           elementName = elementNameTok.getSpelling();
1073 
1074           // Check to see if this name is already used.
1075           auto elementNameIt =
1076               usedNames.try_emplace(elementName, elementNameTok.getLoc());
1077           if (!elementNameIt.second) {
1078             return emitErrorAndNote(
1079                 elementNameTok.getLoc(),
1080                 llvm::formatv("duplicate tuple element label `{0}`",
1081                               elementName),
1082                 elementNameIt.first->getSecond(),
1083                 "see previous label use here");
1084           }
1085         } else {
1086           // Otherwise, we treat this as part of an expression so reset the
1087           // lexer.
1088           resetToken(elementNameTok.getLoc());
1089         }
1090       }
1091       elementNames.push_back(elementName);
1092 
1093       // Parse the tuple element value.
1094       FailureOr<ast::Expr *> element = parseExpr();
1095       if (failed(element))
1096         return failure();
1097       elements.push_back(*element);
1098     } while (consumeIf(Token::comma));
1099   }
1100   loc.End = curToken.getEndLoc();
1101   if (failed(
1102           parseToken(Token::r_paren, "expected `)` after tuple element list")))
1103     return failure();
1104   return createTupleExpr(loc, elements, elementNames);
1105 }
1106 
1107 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1108   llvm::SMRange loc = curToken.getLoc();
1109   consumeToken(Token::kw_type);
1110 
1111   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1112   // identifier.
1113   if (!consumeIf(Token::less)) {
1114     resetToken(loc);
1115     return parseIdentifierExpr();
1116   }
1117 
1118   if (!curToken.isString())
1119     return emitError("expected string literal containing MLIR type");
1120   std::string attrExpr = curToken.getStringValue();
1121   consumeToken();
1122 
1123   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1124     return failure();
1125   return ast::TypeExpr::create(ctx, loc, attrExpr);
1126 }
1127 
1128 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1129   StringRef name = curToken.getSpelling();
1130   llvm::SMRange nameLoc = curToken.getLoc();
1131   consumeToken(Token::underscore);
1132 
1133   // Underscore expressions require a constraint list.
1134   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1135     return failure();
1136 
1137   // Parse the constraints for the expression.
1138   SmallVector<ast::ConstraintRef> constraints;
1139   if (failed(parseVariableDeclConstraintList(constraints)))
1140     return failure();
1141 
1142   ast::Type type;
1143   if (failed(validateVariableConstraints(constraints, type)))
1144     return failure();
1145   return createInlineVariableExpr(type, name, nameLoc, constraints);
1146 }
1147 
1148 //===----------------------------------------------------------------------===//
1149 // Stmts
1150 
1151 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1152   FailureOr<ast::Stmt *> stmt;
1153   switch (curToken.getKind()) {
1154   case Token::kw_erase:
1155     stmt = parseEraseStmt();
1156     break;
1157   case Token::kw_let:
1158     stmt = parseLetStmt();
1159     break;
1160   case Token::kw_replace:
1161     stmt = parseReplaceStmt();
1162     break;
1163   case Token::kw_rewrite:
1164     stmt = parseRewriteStmt();
1165     break;
1166   default:
1167     stmt = parseExpr();
1168     break;
1169   }
1170   if (failed(stmt) ||
1171       (expectTerminalSemicolon &&
1172        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1173     return failure();
1174   return stmt;
1175 }
1176 
1177 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1178   llvm::SMLoc startLoc = curToken.getStartLoc();
1179   consumeToken(Token::l_brace);
1180 
1181   // Push a new block scope and parse any nested statements.
1182   pushDeclScope();
1183   SmallVector<ast::Stmt *> statements;
1184   while (curToken.isNot(Token::r_brace)) {
1185     FailureOr<ast::Stmt *> statement = parseStmt();
1186     if (failed(statement))
1187       return popDeclScope(), failure();
1188     statements.push_back(*statement);
1189   }
1190   popDeclScope();
1191 
1192   // Consume the end brace.
1193   llvm::SMRange location(startLoc, curToken.getEndLoc());
1194   consumeToken(Token::r_brace);
1195 
1196   return ast::CompoundStmt::create(ctx, location, statements);
1197 }
1198 
1199 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1200   llvm::SMRange loc = curToken.getLoc();
1201   consumeToken(Token::kw_erase);
1202 
1203   // Parse the root operation expression.
1204   FailureOr<ast::Expr *> rootOp = parseExpr();
1205   if (failed(rootOp))
1206     return failure();
1207 
1208   return createEraseStmt(loc, *rootOp);
1209 }
1210 
1211 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1212   llvm::SMRange loc = curToken.getLoc();
1213   consumeToken(Token::kw_let);
1214 
1215   // Parse the name of the new variable.
1216   llvm::SMRange varLoc = curToken.getLoc();
1217   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1218     // `_` is a reserved variable name.
1219     if (curToken.is(Token::underscore)) {
1220       return emitError(varLoc,
1221                        "`_` may only be used to define \"inline\" variables");
1222     }
1223     return emitError(varLoc,
1224                      "expected identifier after `let` to name a new variable");
1225   }
1226   StringRef varName = curToken.getSpelling();
1227   consumeToken();
1228 
1229   // Parse the optional set of constraints.
1230   SmallVector<ast::ConstraintRef> constraints;
1231   if (consumeIf(Token::colon) &&
1232       failed(parseVariableDeclConstraintList(constraints)))
1233     return failure();
1234 
1235   // Parse the optional initializer expression.
1236   ast::Expr *initializer = nullptr;
1237   if (consumeIf(Token::equal)) {
1238     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1239     if (failed(initOrFailure))
1240       return failure();
1241     initializer = *initOrFailure;
1242 
1243     // Check that the constraints are compatible with having an initializer,
1244     // e.g. type constraints cannot be used with initializers.
1245     for (ast::ConstraintRef constraint : constraints) {
1246       LogicalResult result =
1247           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1248               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1249                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1250                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1251                   return this->emitError(
1252                       constraint.referenceLoc,
1253                       "type constraints are not permitted on variables with "
1254                       "initializers");
1255                 }
1256                 return success();
1257               })
1258               .Default(success());
1259       if (failed(result))
1260         return failure();
1261     }
1262   }
1263 
1264   FailureOr<ast::VariableDecl *> varDecl =
1265       createVariableDecl(varName, varLoc, initializer, constraints);
1266   if (failed(varDecl))
1267     return failure();
1268   return ast::LetStmt::create(ctx, loc, *varDecl);
1269 }
1270 
1271 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1272   llvm::SMRange loc = curToken.getLoc();
1273   consumeToken(Token::kw_replace);
1274 
1275   // Parse the root operation expression.
1276   FailureOr<ast::Expr *> rootOp = parseExpr();
1277   if (failed(rootOp))
1278     return failure();
1279 
1280   if (failed(
1281           parseToken(Token::kw_with, "expected `with` after root operation")))
1282     return failure();
1283 
1284   // The replacement portion of this statement is within a rewrite context.
1285   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1286                                               ParserContext::Rewrite);
1287 
1288   // Parse the replacement values.
1289   SmallVector<ast::Expr *> replValues;
1290   if (consumeIf(Token::l_paren)) {
1291     if (consumeIf(Token::r_paren)) {
1292       return emitError(
1293           loc, "expected at least one replacement value, consider using "
1294                "`erase` if no replacement values are desired");
1295     }
1296 
1297     do {
1298       FailureOr<ast::Expr *> replExpr = parseExpr();
1299       if (failed(replExpr))
1300         return failure();
1301       replValues.emplace_back(*replExpr);
1302     } while (consumeIf(Token::comma));
1303 
1304     if (failed(parseToken(Token::r_paren,
1305                           "expected `)` after replacement values")))
1306       return failure();
1307   } else {
1308     FailureOr<ast::Expr *> replExpr = parseExpr();
1309     if (failed(replExpr))
1310       return failure();
1311     replValues.emplace_back(*replExpr);
1312   }
1313 
1314   return createReplaceStmt(loc, *rootOp, replValues);
1315 }
1316 
1317 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1318   llvm::SMRange loc = curToken.getLoc();
1319   consumeToken(Token::kw_rewrite);
1320 
1321   // Parse the root operation.
1322   FailureOr<ast::Expr *> rootOp = parseExpr();
1323   if (failed(rootOp))
1324     return failure();
1325 
1326   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1327     return failure();
1328 
1329   if (curToken.isNot(Token::l_brace))
1330     return emitError("expected `{` to start rewrite body");
1331 
1332   // The rewrite body of this statement is within a rewrite context.
1333   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1334                                               ParserContext::Rewrite);
1335 
1336   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1337   if (failed(rewriteBody))
1338     return failure();
1339 
1340   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // Creation+Analysis
1345 //===----------------------------------------------------------------------===//
1346 
1347 //===----------------------------------------------------------------------===//
1348 // Decls
1349 
1350 FailureOr<ast::PatternDecl *>
1351 Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
1352                           const ParsedPatternMetadata &metadata,
1353                           ast::CompoundStmt *body) {
1354   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
1355                                   metadata.hasBoundedRecursion, body);
1356 }
1357 
1358 FailureOr<ast::VariableDecl *>
1359 Parser::createVariableDecl(StringRef name, llvm::SMRange loc,
1360                            ast::Expr *initializer,
1361                            ArrayRef<ast::ConstraintRef> constraints) {
1362   // The type of the variable, which is expected to be inferred by either a
1363   // constraint or an initializer expression.
1364   ast::Type type;
1365   if (failed(validateVariableConstraints(constraints, type)))
1366     return failure();
1367 
1368   if (initializer) {
1369     // Update the variable type based on the initializer, or try to convert the
1370     // initializer to the existing type.
1371     if (!type)
1372       type = initializer->getType();
1373     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
1374       type = mergedType;
1375     else if (failed(convertExpressionTo(initializer, type)))
1376       return failure();
1377 
1378     // Otherwise, if there is no initializer check that the type has already
1379     // been resolved from the constraint list.
1380   } else if (!type) {
1381     return emitErrorAndNote(
1382         loc, "unable to infer type for variable `" + name + "`", loc,
1383         "the type of a variable must be inferable from the constraint "
1384         "list or the initializer");
1385   }
1386 
1387   // Try to define a variable with the given name.
1388   FailureOr<ast::VariableDecl *> varDecl =
1389       defineVariableDecl(name, loc, type, initializer, constraints);
1390   if (failed(varDecl))
1391     return failure();
1392 
1393   return *varDecl;
1394 }
1395 
1396 LogicalResult
1397 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
1398                                     ast::Type &inferredType) {
1399   for (const ast::ConstraintRef &ref : constraints)
1400     if (failed(validateVariableConstraint(ref, inferredType)))
1401       return failure();
1402   return success();
1403 }
1404 
1405 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
1406                                                  ast::Type &inferredType) {
1407   ast::Type constraintType;
1408   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
1409     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1410       if (failed(validateTypeConstraintExpr(typeExpr)))
1411         return failure();
1412     }
1413     constraintType = ast::AttributeType::get(ctx);
1414   } else if (const auto *cst =
1415                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
1416     constraintType = ast::OperationType::get(ctx, cst->getName());
1417   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
1418     constraintType = typeTy;
1419   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
1420     constraintType = typeRangeTy;
1421   } else if (const auto *cst =
1422                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
1423     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1424       if (failed(validateTypeConstraintExpr(typeExpr)))
1425         return failure();
1426     }
1427     constraintType = valueTy;
1428   } else if (const auto *cst =
1429                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
1430     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1431       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
1432         return failure();
1433     }
1434     constraintType = valueRangeTy;
1435   } else {
1436     llvm_unreachable("unknown constraint type");
1437   }
1438 
1439   // Check that the constraint type is compatible with the current inferred
1440   // type.
1441   if (!inferredType) {
1442     inferredType = constraintType;
1443   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
1444     inferredType = mergedTy;
1445   } else {
1446     return emitError(ref.referenceLoc,
1447                      llvm::formatv("constraint type `{0}` is incompatible "
1448                                    "with the previously inferred type `{1}`",
1449                                    constraintType, inferredType));
1450   }
1451   return success();
1452 }
1453 
1454 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
1455   ast::Type typeExprType = typeExpr->getType();
1456   if (typeExprType != typeTy) {
1457     return emitError(typeExpr->getLoc(),
1458                      "expected expression of `Type` in type constraint");
1459   }
1460   return success();
1461 }
1462 
1463 LogicalResult
1464 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
1465   ast::Type typeExprType = typeExpr->getType();
1466   if (typeExprType != typeRangeTy) {
1467     return emitError(typeExpr->getLoc(),
1468                      "expected expression of `TypeRange` in type constraint");
1469   }
1470   return success();
1471 }
1472 
1473 //===----------------------------------------------------------------------===//
1474 // Exprs
1475 
1476 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc,
1477                                                         ast::Decl *decl) {
1478   // Check the type of decl being referenced.
1479   ast::Type declType;
1480   if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
1481     declType = varDecl->getType();
1482   else
1483     return emitError(loc, "invalid reference to `" +
1484                               decl->getName()->getName() + "`");
1485 
1486   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
1487 }
1488 
1489 FailureOr<ast::DeclRefExpr *>
1490 Parser::createInlineVariableExpr(ast::Type type, StringRef name,
1491                                  llvm::SMRange loc,
1492                                  ArrayRef<ast::ConstraintRef> constraints) {
1493   FailureOr<ast::VariableDecl *> decl =
1494       defineVariableDecl(name, loc, type, constraints);
1495   if (failed(decl))
1496     return failure();
1497   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
1498 }
1499 
1500 FailureOr<ast::MemberAccessExpr *>
1501 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
1502                                llvm::SMRange loc) {
1503   // Validate the member name for the given parent expression.
1504   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
1505   if (failed(memberType))
1506     return failure();
1507 
1508   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
1509 }
1510 
1511 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
1512                                                   StringRef name,
1513                                                   llvm::SMRange loc) {
1514   ast::Type parentType = parentExpr->getType();
1515   if (parentType.isa<ast::OperationType>()) {
1516     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
1517       return valueRangeTy;
1518   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
1519     // Handle indexed results.
1520     unsigned index = 0;
1521     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
1522         index < tupleType.size()) {
1523       return tupleType.getElementTypes()[index];
1524     }
1525 
1526     // Handle named results.
1527     auto elementNames = tupleType.getElementNames();
1528     const auto *it = llvm::find(elementNames, name);
1529     if (it != elementNames.end())
1530       return tupleType.getElementTypes()[it - elementNames.begin()];
1531   }
1532   return emitError(
1533       loc,
1534       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
1535                     name, parentType));
1536 }
1537 
1538 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
1539     llvm::SMRange loc, const ast::OpNameDecl *name,
1540     MutableArrayRef<ast::Expr *> operands,
1541     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
1542     MutableArrayRef<ast::Expr *> results) {
1543   Optional<StringRef> opNameRef = name->getName();
1544 
1545   // Verify the inputs operands.
1546   if (failed(validateOperationOperands(loc, opNameRef, operands)))
1547     return failure();
1548 
1549   // Verify the attribute list.
1550   for (ast::NamedAttributeDecl *attr : attributes) {
1551     // Check for an attribute type, or a type awaiting resolution.
1552     ast::Type attrType = attr->getValue()->getType();
1553     if (!attrType.isa<ast::AttributeType>()) {
1554       return emitError(
1555           attr->getValue()->getLoc(),
1556           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
1557     }
1558   }
1559 
1560   // Verify the result types.
1561   if (failed(validateOperationResults(loc, opNameRef, results)))
1562     return failure();
1563 
1564   return ast::OperationExpr::create(ctx, loc, name, operands, results,
1565                                     attributes);
1566 }
1567 
1568 LogicalResult
1569 Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
1570                                   MutableArrayRef<ast::Expr *> operands) {
1571   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
1572                                             valueRangeTy);
1573 }
1574 
1575 LogicalResult
1576 Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
1577                                  MutableArrayRef<ast::Expr *> results) {
1578   return validateOperationOperandsOrResults(loc, name, results, typeTy,
1579                                             typeRangeTy);
1580 }
1581 
1582 LogicalResult Parser::validateOperationOperandsOrResults(
1583     llvm::SMRange loc, Optional<StringRef> name,
1584     MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
1585     ast::Type rangeTy) {
1586   // All operation types accept a single range parameter.
1587   if (values.size() == 1) {
1588     if (failed(convertExpressionTo(values[0], rangeTy)))
1589       return failure();
1590     return success();
1591   }
1592 
1593   // Otherwise, accept the value groups as they have been defined and just
1594   // ensure they are one of the expected types.
1595   for (ast::Expr *&valueExpr : values) {
1596     ast::Type valueExprType = valueExpr->getType();
1597 
1598     // Check if this is one of the expected types.
1599     if (valueExprType == rangeTy || valueExprType == singleTy)
1600       continue;
1601 
1602     // If the operand is an Operation, allow converting to a Value or
1603     // ValueRange. This situations arises quite often with nested operation
1604     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
1605     if (singleTy == valueTy) {
1606       if (valueExprType.isa<ast::OperationType>()) {
1607         valueExpr = convertOpToValue(valueExpr);
1608         continue;
1609       }
1610     }
1611 
1612     return emitError(
1613         valueExpr->getLoc(),
1614         llvm::formatv(
1615             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
1616             singleTy, rangeTy, valueExprType));
1617   }
1618   return success();
1619 }
1620 
1621 FailureOr<ast::TupleExpr *>
1622 Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements,
1623                         ArrayRef<StringRef> elementNames) {
1624   for (const ast::Expr *element : elements) {
1625     ast::Type eleTy = element->getType();
1626     if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
1627       return emitError(
1628           element->getLoc(),
1629           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
1630     }
1631   }
1632   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
1633 }
1634 
1635 //===----------------------------------------------------------------------===//
1636 // Stmts
1637 
1638 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc,
1639                                                     ast::Expr *rootOp) {
1640   // Check that root is an Operation.
1641   ast::Type rootType = rootOp->getType();
1642   if (!rootType.isa<ast::OperationType>())
1643     return emitError(rootOp->getLoc(), "expected `Op` expression");
1644 
1645   return ast::EraseStmt::create(ctx, loc, rootOp);
1646 }
1647 
1648 FailureOr<ast::ReplaceStmt *>
1649 Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
1650                           MutableArrayRef<ast::Expr *> replValues) {
1651   // Check that root is an Operation.
1652   ast::Type rootType = rootOp->getType();
1653   if (!rootType.isa<ast::OperationType>()) {
1654     return emitError(
1655         rootOp->getLoc(),
1656         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
1657   }
1658 
1659   // If there are multiple replacement values, we implicitly convert any Op
1660   // expressions to the value form.
1661   bool shouldConvertOpToValues = replValues.size() > 1;
1662   for (ast::Expr *&replExpr : replValues) {
1663     ast::Type replType = replExpr->getType();
1664 
1665     // Check that replExpr is an Operation, Value, or ValueRange.
1666     if (replType.isa<ast::OperationType>()) {
1667       if (shouldConvertOpToValues)
1668         replExpr = convertOpToValue(replExpr);
1669       continue;
1670     }
1671 
1672     if (replType != valueTy && replType != valueRangeTy) {
1673       return emitError(replExpr->getLoc(),
1674                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
1675                                      "expression, but got `{0}`",
1676                                      replType));
1677     }
1678   }
1679 
1680   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
1681 }
1682 
1683 FailureOr<ast::RewriteStmt *>
1684 Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
1685                           ast::CompoundStmt *rewriteBody) {
1686   // Check that root is an Operation.
1687   ast::Type rootType = rootOp->getType();
1688   if (!rootType.isa<ast::OperationType>()) {
1689     return emitError(
1690         rootOp->getLoc(),
1691         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
1692   }
1693 
1694   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
1695 }
1696 
1697 //===----------------------------------------------------------------------===//
1698 // Parser
1699 //===----------------------------------------------------------------------===//
1700 
1701 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
1702                                                  llvm::SourceMgr &sourceMgr) {
1703   Parser parser(ctx, sourceMgr);
1704   return parser.parseModule();
1705 }
1706