1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 class Parser {
32 public:
33   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
34       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
35         curToken(lexer.lexToken()), curDeclScope(nullptr),
36         valueTy(ast::ValueType::get(ctx)),
37         valueRangeTy(ast::ValueRangeType::get(ctx)),
38         typeTy(ast::TypeType::get(ctx)),
39         typeRangeTy(ast::TypeRangeType::get(ctx)) {}
40 
41   /// Try to parse a new module. Returns nullptr in the case of failure.
42   FailureOr<ast::Module *> parseModule();
43 
44 private:
45   /// The current context of the parser. It allows for the parser to know a bit
46   /// about the construct it is nested within during parsing. This is used
47   /// specifically to provide additional verification during parsing, e.g. to
48   /// prevent using rewrites within a match context, matcher constraints within
49   /// a rewrite section, etc.
50   enum class ParserContext {
51     /// The parser is in the global context.
52     Global,
53     /// The parser is currently within a Constraint, which disallows all types
54     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
55     Constraint,
56     /// The parser is currently within the matcher portion of a Pattern, which
57     /// is allows a terminal operation rewrite statement but no other rewrite
58     /// transformations.
59     PatternMatch,
60     /// The parser is currently within a Rewrite, which disallows calls to
61     /// constraints, requires operation expressions to have names, etc.
62     Rewrite,
63   };
64 
65   //===--------------------------------------------------------------------===//
66   // Parsing
67   //===--------------------------------------------------------------------===//
68 
69   /// Push a new decl scope onto the lexer.
70   ast::DeclScope *pushDeclScope() {
71     ast::DeclScope *newScope =
72         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
73     return (curDeclScope = newScope);
74   }
75   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
76 
77   /// Pop the last decl scope from the lexer.
78   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
79 
80   /// Parse the body of an AST module.
81   LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
82 
83   /// Try to convert the given expression to `type`. Returns failure and emits
84   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
85   /// invoked to attach notes to the emitted error diagnostic. On success,
86   /// `expr` is updated to the expression used to convert to `type`.
87   LogicalResult convertExpressionTo(
88       ast::Expr *&expr, ast::Type type,
89       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
90 
91   /// Given an operation expression, convert it to a Value or ValueRange
92   /// typed expression.
93   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
94 
95   //===--------------------------------------------------------------------===//
96   // Directives
97 
98   LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
99   LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
100 
101   //===--------------------------------------------------------------------===//
102   // Decls
103 
104   /// This structure contains the set of pattern metadata that may be parsed.
105   struct ParsedPatternMetadata {
106     Optional<uint16_t> benefit;
107     bool hasBoundedRecursion = false;
108   };
109 
110   FailureOr<ast::Decl *> parseTopLevelDecl();
111   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
112 
113   /// Parse an argument variable as part of the signature of a
114   /// UserConstraintDecl or UserRewriteDecl.
115   FailureOr<ast::VariableDecl *> parseArgumentDecl();
116 
117   /// Parse a result variable as part of the signature of a UserConstraintDecl
118   /// or UserRewriteDecl.
119   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
120 
121   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
122   /// defined in a non-global context.
123   FailureOr<ast::UserConstraintDecl *>
124   parseUserConstraintDecl(bool isInline = false);
125 
126   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
127   /// non-global context, such as within a Pattern/Constraint/etc.
128   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
129 
130   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
131   /// PDLL constructs.
132   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
133       const ast::Name &name, bool isInline,
134       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
135       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
136 
137   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
138   /// defined in a non-global context.
139   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
140 
141   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
142   /// non-global context, such as within a Pattern/Rewrite/etc.
143   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
144 
145   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
146   /// PDLL constructs.
147   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
148       const ast::Name &name, bool isInline,
149       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
150       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
151 
152   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
153   /// effectively the same syntax, and only differ on slight semantics (given
154   /// the different parsing contexts).
155   template <typename T, typename ParseUserPDLLDeclFnT>
156   FailureOr<T *> parseUserConstraintOrRewriteDecl(
157       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
158       StringRef anonymousNamePrefix, bool isInline);
159 
160   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
161   /// These decls have effectively the same syntax.
162   template <typename T>
163   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
164       const ast::Name &name, bool isInline,
165       ArrayRef<ast::VariableDecl *> arguments,
166       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
167 
168   /// Parse the functional signature (i.e. the arguments and results) of a
169   /// UserConstraintDecl or UserRewriteDecl.
170   LogicalResult parseUserConstraintOrRewriteSignature(
171       SmallVectorImpl<ast::VariableDecl *> &arguments,
172       SmallVectorImpl<ast::VariableDecl *> &results,
173       ast::DeclScope *&argumentScope, ast::Type &resultType);
174 
175   /// Validate the return (which if present is specified by bodyIt) of a
176   /// UserConstraintDecl or UserRewriteDecl.
177   LogicalResult validateUserConstraintOrRewriteReturn(
178       StringRef declType, ast::CompoundStmt *body,
179       ArrayRef<ast::Stmt *>::iterator bodyIt,
180       ArrayRef<ast::Stmt *>::iterator bodyE,
181       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
182 
183   FailureOr<ast::CompoundStmt *>
184   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
185                   bool expectTerminalSemicolon = true);
186   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
187   FailureOr<ast::Decl *> parsePatternDecl();
188   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
189 
190   /// Check to see if a decl has already been defined with the given name, if
191   /// one has emit and error and return failure. Returns success otherwise.
192   LogicalResult checkDefineNamedDecl(const ast::Name &name);
193 
194   /// Try to define a variable decl with the given components, returns the
195   /// variable on success.
196   FailureOr<ast::VariableDecl *>
197   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
198                      ast::Expr *initExpr,
199                      ArrayRef<ast::ConstraintRef> constraints);
200   FailureOr<ast::VariableDecl *>
201   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
202                      ArrayRef<ast::ConstraintRef> constraints);
203 
204   /// Parse the constraint reference list for a variable decl.
205   LogicalResult parseVariableDeclConstraintList(
206       SmallVectorImpl<ast::ConstraintRef> &constraints);
207 
208   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
209   FailureOr<ast::Expr *> parseTypeConstraintExpr();
210 
211   /// Try to parse a single reference to a constraint. `typeConstraint` is the
212   /// location of a previously parsed type constraint for the entity that will
213   /// be constrained by the parsed constraint. `existingConstraints` are any
214   /// existing constraints that have already been parsed for the same entity
215   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
216   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
217   FailureOr<ast::ConstraintRef>
218   parseConstraint(Optional<SMRange> &typeConstraint,
219                   ArrayRef<ast::ConstraintRef> existingConstraints,
220                   bool allowInlineTypeConstraints);
221 
222   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
223   /// argument or result variable. The constraints for these variables do not
224   /// allow inline type constraints, and only permit a single constraint.
225   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
226 
227   //===--------------------------------------------------------------------===//
228   // Exprs
229 
230   FailureOr<ast::Expr *> parseExpr();
231 
232   /// Identifier expressions.
233   FailureOr<ast::Expr *> parseAttributeExpr();
234   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
235   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
236   FailureOr<ast::Expr *> parseIdentifierExpr();
237   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
238   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
239   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
240   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
241   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
242   FailureOr<ast::Expr *> parseOperationExpr();
243   FailureOr<ast::Expr *> parseTupleExpr();
244   FailureOr<ast::Expr *> parseTypeExpr();
245   FailureOr<ast::Expr *> parseUnderscoreExpr();
246 
247   //===--------------------------------------------------------------------===//
248   // Stmts
249 
250   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
251   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
252   FailureOr<ast::EraseStmt *> parseEraseStmt();
253   FailureOr<ast::LetStmt *> parseLetStmt();
254   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
255   FailureOr<ast::ReturnStmt *> parseReturnStmt();
256   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
257 
258   //===--------------------------------------------------------------------===//
259   // Creation+Analysis
260   //===--------------------------------------------------------------------===//
261 
262   //===--------------------------------------------------------------------===//
263   // Decls
264 
265   /// Try to extract a callable from the given AST node. Returns nullptr on
266   /// failure.
267   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
268 
269   /// Try to create a pattern decl with the given components, returning the
270   /// Pattern on success.
271   FailureOr<ast::PatternDecl *>
272   createPatternDecl(SMRange loc, const ast::Name *name,
273                     const ParsedPatternMetadata &metadata,
274                     ast::CompoundStmt *body);
275 
276   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
277   /// of results, defined as part of the signature.
278   ast::Type
279   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
280 
281   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
282   template <typename T>
283   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
284       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
285       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
286       ast::CompoundStmt *body);
287 
288   /// Try to create a variable decl with the given components, returning the
289   /// Variable on success.
290   FailureOr<ast::VariableDecl *>
291   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
292                      ArrayRef<ast::ConstraintRef> constraints);
293 
294   /// Create a variable for an argument or result defined as part of the
295   /// signature of a UserConstraintDecl/UserRewriteDecl.
296   FailureOr<ast::VariableDecl *>
297   createArgOrResultVariableDecl(StringRef name, SMRange loc,
298                                 const ast::ConstraintRef &constraint);
299 
300   /// Validate the constraints used to constraint a variable decl.
301   /// `inferredType` is the type of the variable inferred by the constraints
302   /// within the list, and is updated to the most refined type as determined by
303   /// the constraints. Returns success if the constraint list is valid, failure
304   /// otherwise.
305   LogicalResult
306   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
307                               ast::Type &inferredType);
308   /// Validate a single reference to a constraint. `inferredType` contains the
309   /// currently inferred variabled type and is refined within the type defined
310   /// by the constraint. Returns success if the constraint is valid, failure
311   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
312   /// defined constraints) may be used with the variable.
313   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
314                                            ast::Type &inferredType,
315                                            bool allowNonCoreConstraints = true);
316   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
317   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
318 
319   //===--------------------------------------------------------------------===//
320   // Exprs
321 
322   FailureOr<ast::CallExpr *>
323   createCallExpr(SMRange loc, ast::Expr *parentExpr,
324                  MutableArrayRef<ast::Expr *> arguments);
325   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
326   FailureOr<ast::DeclRefExpr *>
327   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
328                            ArrayRef<ast::ConstraintRef> constraints);
329   FailureOr<ast::MemberAccessExpr *>
330   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
331 
332   /// Validate the member access `name` into the given parent expression. On
333   /// success, this also returns the type of the member accessed.
334   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
335                                             StringRef name, SMRange loc);
336   FailureOr<ast::OperationExpr *>
337   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
338                       MutableArrayRef<ast::Expr *> operands,
339                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
340                       MutableArrayRef<ast::Expr *> results);
341   LogicalResult
342   validateOperationOperands(SMRange loc, Optional<StringRef> name,
343                             MutableArrayRef<ast::Expr *> operands);
344   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
345                                          MutableArrayRef<ast::Expr *> results);
346   LogicalResult
347   validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
348                                      MutableArrayRef<ast::Expr *> values,
349                                      ast::Type singleTy, ast::Type rangeTy);
350   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
351                                               ArrayRef<ast::Expr *> elements,
352                                               ArrayRef<StringRef> elementNames);
353 
354   //===--------------------------------------------------------------------===//
355   // Stmts
356 
357   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
358   FailureOr<ast::ReplaceStmt *>
359   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
360                     MutableArrayRef<ast::Expr *> replValues);
361   FailureOr<ast::RewriteStmt *>
362   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
363                     ast::CompoundStmt *rewriteBody);
364 
365   //===--------------------------------------------------------------------===//
366   // Lexer Utilities
367   //===--------------------------------------------------------------------===//
368 
369   /// If the current token has the specified kind, consume it and return true.
370   /// If not, return false.
371   bool consumeIf(Token::Kind kind) {
372     if (curToken.isNot(kind))
373       return false;
374     consumeToken(kind);
375     return true;
376   }
377 
378   /// Advance the current lexer onto the next token.
379   void consumeToken() {
380     assert(curToken.isNot(Token::eof, Token::error) &&
381            "shouldn't advance past EOF or errors");
382     curToken = lexer.lexToken();
383   }
384 
385   /// Advance the current lexer onto the next token, asserting what the expected
386   /// current token is. This is preferred to the above method because it leads
387   /// to more self-documenting code with better checking.
388   void consumeToken(Token::Kind kind) {
389     assert(curToken.is(kind) && "consumed an unexpected token");
390     consumeToken();
391   }
392 
393   /// Reset the lexer to the location at the given position.
394   void resetToken(SMRange tokLoc) {
395     lexer.resetPointer(tokLoc.Start.getPointer());
396     curToken = lexer.lexToken();
397   }
398 
399   /// Consume the specified token if present and return success. On failure,
400   /// output a diagnostic and return failure.
401   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
402     if (curToken.getKind() != kind)
403       return emitError(curToken.getLoc(), msg);
404     consumeToken();
405     return success();
406   }
407   LogicalResult emitError(SMRange loc, const Twine &msg) {
408     lexer.emitError(loc, msg);
409     return failure();
410   }
411   LogicalResult emitError(const Twine &msg) {
412     return emitError(curToken.getLoc(), msg);
413   }
414   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
415                                  const Twine &note) {
416     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
417     return failure();
418   }
419 
420   //===--------------------------------------------------------------------===//
421   // Fields
422   //===--------------------------------------------------------------------===//
423 
424   /// The owning AST context.
425   ast::Context &ctx;
426 
427   /// The lexer of this parser.
428   Lexer lexer;
429 
430   /// The current token within the lexer.
431   Token curToken;
432 
433   /// The most recently defined decl scope.
434   ast::DeclScope *curDeclScope;
435   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
436 
437   /// The current context of the parser.
438   ParserContext parserContext = ParserContext::Global;
439 
440   /// Cached types to simplify verification and expression creation.
441   ast::Type valueTy, valueRangeTy;
442   ast::Type typeTy, typeRangeTy;
443 
444   /// A counter used when naming anonymous constraints and rewrites.
445   unsigned anonymousDeclNameCounter = 0;
446 };
447 } // namespace
448 
449 FailureOr<ast::Module *> Parser::parseModule() {
450   SMLoc moduleLoc = curToken.getStartLoc();
451   pushDeclScope();
452 
453   // Parse the top-level decls of the module.
454   SmallVector<ast::Decl *> decls;
455   if (failed(parseModuleBody(decls)))
456     return popDeclScope(), failure();
457 
458   popDeclScope();
459   return ast::Module::create(ctx, moduleLoc, decls);
460 }
461 
462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
463   while (curToken.isNot(Token::eof)) {
464     if (curToken.is(Token::directive)) {
465       if (failed(parseDirective(decls)))
466         return failure();
467       continue;
468     }
469 
470     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
471     if (failed(decl))
472       return failure();
473     decls.push_back(*decl);
474   }
475   return success();
476 }
477 
478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
479   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
480                                                  valueRangeTy);
481 }
482 
483 LogicalResult Parser::convertExpressionTo(
484     ast::Expr *&expr, ast::Type type,
485     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
486   ast::Type exprType = expr->getType();
487   if (exprType == type)
488     return success();
489 
490   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
491     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
492         expr->getLoc(), llvm::formatv("unable to convert expression of type "
493                                       "`{0}` to the expected type of "
494                                       "`{1}`",
495                                       exprType, type));
496     if (noteAttachFn)
497       noteAttachFn(*diag);
498     return diag;
499   };
500 
501   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
502     // Two operation types are compatible if they have the same name, or if the
503     // expected type is more general.
504     if (auto opType = type.dyn_cast<ast::OperationType>()) {
505       if (opType.getName())
506         return emitConvertError();
507       return success();
508     }
509 
510     // An operation can always convert to a ValueRange.
511     if (type == valueRangeTy) {
512       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
513                                                      valueRangeTy);
514       return success();
515     }
516 
517     // Allow conversion to a single value by constraining the result range.
518     if (type == valueTy) {
519       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
520                                                      valueTy);
521       return success();
522     }
523     return emitConvertError();
524   }
525 
526   // FIXME: Decide how to allow/support converting a single result to multiple,
527   // and multiple to a single result. For now, we just allow Single->Range,
528   // but this isn't something really supported in the PDL dialect. We should
529   // figure out some way to support both.
530   if ((exprType == valueTy || exprType == valueRangeTy) &&
531       (type == valueTy || type == valueRangeTy))
532     return success();
533   if ((exprType == typeTy || exprType == typeRangeTy) &&
534       (type == typeTy || type == typeRangeTy))
535     return success();
536 
537   // Handle tuple types.
538   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
539     auto tupleType = type.dyn_cast<ast::TupleType>();
540     if (!tupleType || tupleType.size() != exprTupleType.size())
541       return emitConvertError();
542 
543     // Build a new tuple expression using each of the elements of the current
544     // tuple.
545     SmallVector<ast::Expr *> newExprs;
546     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
547       newExprs.push_back(ast::MemberAccessExpr::create(
548           ctx, expr->getLoc(), expr, llvm::to_string(i),
549           exprTupleType.getElementTypes()[i]));
550 
551       auto diagFn = [&](ast::Diagnostic &diag) {
552         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
553                                       i, exprTupleType));
554         if (noteAttachFn)
555           noteAttachFn(diag);
556       };
557       if (failed(convertExpressionTo(newExprs.back(),
558                                      tupleType.getElementTypes()[i], diagFn)))
559         return failure();
560     }
561     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
562                                   tupleType.getElementNames());
563     return success();
564   }
565 
566   return emitConvertError();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Directives
571 
572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
573   StringRef directive = curToken.getSpelling();
574   if (directive == "#include")
575     return parseInclude(decls);
576 
577   return emitError("unknown directive `" + directive + "`");
578 }
579 
580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
581   SMRange loc = curToken.getLoc();
582   consumeToken(Token::directive);
583 
584   // Parse the file being included.
585   if (!curToken.isString())
586     return emitError(loc,
587                      "expected string file name after `include` directive");
588   SMRange fileLoc = curToken.getLoc();
589   std::string filenameStr = curToken.getStringValue();
590   StringRef filename = filenameStr;
591   consumeToken();
592 
593   // Check the type of include. If ending with `.pdll`, this is another pdl file
594   // to be parsed along with the current module.
595   if (filename.endswith(".pdll")) {
596     if (failed(lexer.pushInclude(filename)))
597       return emitError(fileLoc,
598                        "unable to open include file `" + filename + "`");
599 
600     // If we added the include successfully, parse it into the current module.
601     // Make sure to save the current token so that we can restore it when we
602     // finish parsing the nested file.
603     Token oldToken = curToken;
604     curToken = lexer.lexToken();
605     LogicalResult result = parseModuleBody(decls);
606     curToken = oldToken;
607     return result;
608   }
609 
610   return emitError(fileLoc, "expected include filename to end with `.pdll`");
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Decls
615 
616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
617   FailureOr<ast::Decl *> decl;
618   switch (curToken.getKind()) {
619   case Token::kw_Constraint:
620     decl = parseUserConstraintDecl();
621     break;
622   case Token::kw_Pattern:
623     decl = parsePatternDecl();
624     break;
625   case Token::kw_Rewrite:
626     decl = parseUserRewriteDecl();
627     break;
628   default:
629     return emitError("expected top-level declaration, such as a `Pattern`");
630   }
631   if (failed(decl))
632     return failure();
633 
634   // If the decl has a name, add it to the current scope.
635   if (const ast::Name *name = (*decl)->getName()) {
636     if (failed(checkDefineNamedDecl(*name)))
637       return failure();
638     curDeclScope->add(*decl);
639   }
640   return decl;
641 }
642 
643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
644   std::string attrNameStr;
645   if (curToken.isString())
646     attrNameStr = curToken.getStringValue();
647   else if (curToken.is(Token::identifier) || curToken.isKeyword())
648     attrNameStr = curToken.getSpelling().str();
649   else
650     return emitError("expected identifier or string attribute name");
651   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
652   consumeToken();
653 
654   // Check for a value of the attribute.
655   ast::Expr *attrValue = nullptr;
656   if (consumeIf(Token::equal)) {
657     FailureOr<ast::Expr *> attrExpr = parseExpr();
658     if (failed(attrExpr))
659       return failure();
660     attrValue = *attrExpr;
661   } else {
662     // If there isn't a concrete value, create an expression representing a
663     // UnitAttr.
664     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
665   }
666 
667   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
668 }
669 
670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
671     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
672     bool expectTerminalSemicolon) {
673   consumeToken(Token::equal_arrow);
674 
675   // Parse the single statement of the lambda body.
676   SMLoc bodyStartLoc = curToken.getStartLoc();
677   pushDeclScope();
678   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
679   bool failedToParse =
680       failed(singleStatement) || failed(processStatementFn(*singleStatement));
681   popDeclScope();
682   if (failedToParse)
683     return failure();
684 
685   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
686   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
687 }
688 
689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
690   // Ensure that the argument is named.
691   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
692     return emitError("expected identifier argument name");
693 
694   // Parse the argument similarly to a normal variable.
695   StringRef name = curToken.getSpelling();
696   SMRange nameLoc = curToken.getLoc();
697   consumeToken();
698 
699   if (failed(
700           parseToken(Token::colon, "expected `:` before argument constraint")))
701     return failure();
702 
703   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
704   if (failed(cst))
705     return failure();
706 
707   return createArgOrResultVariableDecl(name, nameLoc, *cst);
708 }
709 
710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
711   // Check to see if this result is named.
712   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
713     // Check to see if this name actually refers to a Constraint.
714     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
715     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
716       // If yes, and this is a Rewrite, give a nice error message as non-Core
717       // constraints are not supported on Rewrite results.
718       if (parserContext == ParserContext::Rewrite) {
719         return emitError(
720             "`Rewrite` results are only permitted to use core constraints, "
721             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
722       }
723 
724       // Otherwise, parse this as an unnamed result variable.
725     } else {
726       // If it wasn't a constraint, parse the result similarly to a variable. If
727       // there is already an existing decl, we will emit an error when defining
728       // this variable later.
729       StringRef name = curToken.getSpelling();
730       SMRange nameLoc = curToken.getLoc();
731       consumeToken();
732 
733       if (failed(parseToken(Token::colon,
734                             "expected `:` before result constraint")))
735         return failure();
736 
737       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
738       if (failed(cst))
739         return failure();
740 
741       return createArgOrResultVariableDecl(name, nameLoc, *cst);
742     }
743   }
744 
745   // If it isn't named, we parse the constraint directly and create an unnamed
746   // result variable.
747   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
748   if (failed(cst))
749     return failure();
750 
751   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
752 }
753 
754 FailureOr<ast::UserConstraintDecl *>
755 Parser::parseUserConstraintDecl(bool isInline) {
756   // Constraints and rewrites have very similar formats, dispatch to a shared
757   // interface for parsing.
758   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
759       [&](auto &&...args) {
760         return this->parseUserPDLLConstraintDecl(args...);
761       },
762       ParserContext::Constraint, "constraint", isInline);
763 }
764 
765 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
766   FailureOr<ast::UserConstraintDecl *> decl =
767       parseUserConstraintDecl(/*isInline=*/true);
768   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
769     return failure();
770 
771   curDeclScope->add(*decl);
772   return decl;
773 }
774 
775 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
776     const ast::Name &name, bool isInline,
777     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
778     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
779   // Push the argument scope back onto the list, so that the body can
780   // reference arguments.
781   pushDeclScope(argumentScope);
782 
783   // Parse the body of the constraint. The body is either defined as a compound
784   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
785   ast::CompoundStmt *body;
786   if (curToken.is(Token::equal_arrow)) {
787     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
788         [&](ast::Stmt *&stmt) -> LogicalResult {
789           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
790           if (!stmtExpr) {
791             return emitError(stmt->getLoc(),
792                              "expected `Constraint` lambda body to contain a "
793                              "single expression");
794           }
795           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
796           return success();
797         },
798         /*expectTerminalSemicolon=*/!isInline);
799     if (failed(bodyResult))
800       return failure();
801     body = *bodyResult;
802   } else {
803     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
804     if (failed(bodyResult))
805       return failure();
806     body = *bodyResult;
807 
808     // Verify the structure of the body.
809     auto bodyIt = body->begin(), bodyE = body->end();
810     for (; bodyIt != bodyE; ++bodyIt)
811       if (isa<ast::ReturnStmt>(*bodyIt))
812         break;
813     if (failed(validateUserConstraintOrRewriteReturn(
814             "Constraint", body, bodyIt, bodyE, results, resultType)))
815       return failure();
816   }
817   popDeclScope();
818 
819   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
820       name, arguments, results, resultType, body);
821 }
822 
823 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
824   // Constraints and rewrites have very similar formats, dispatch to a shared
825   // interface for parsing.
826   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
827       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
828       ParserContext::Rewrite, "rewrite", isInline);
829 }
830 
831 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
832   FailureOr<ast::UserRewriteDecl *> decl =
833       parseUserRewriteDecl(/*isInline=*/true);
834   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
835     return failure();
836 
837   curDeclScope->add(*decl);
838   return decl;
839 }
840 
841 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
842     const ast::Name &name, bool isInline,
843     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
844     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
845   // Push the argument scope back onto the list, so that the body can
846   // reference arguments.
847   curDeclScope = argumentScope;
848   ast::CompoundStmt *body;
849   if (curToken.is(Token::equal_arrow)) {
850     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
851         [&](ast::Stmt *&statement) -> LogicalResult {
852           if (isa<ast::OpRewriteStmt>(statement))
853             return success();
854 
855           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
856           if (!statementExpr) {
857             return emitError(
858                 statement->getLoc(),
859                 "expected `Rewrite` lambda body to contain a single expression "
860                 "or an operation rewrite statement; such as `erase`, "
861                 "`replace`, or `rewrite`");
862           }
863           statement =
864               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
865           return success();
866         },
867         /*expectTerminalSemicolon=*/!isInline);
868     if (failed(bodyResult))
869       return failure();
870     body = *bodyResult;
871   } else {
872     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
873     if (failed(bodyResult))
874       return failure();
875     body = *bodyResult;
876   }
877   popDeclScope();
878 
879   // Verify the structure of the body.
880   auto bodyIt = body->begin(), bodyE = body->end();
881   for (; bodyIt != bodyE; ++bodyIt)
882     if (isa<ast::ReturnStmt>(*bodyIt))
883       break;
884   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
885                                                    bodyE, results, resultType)))
886     return failure();
887   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
888       name, arguments, results, resultType, body);
889 }
890 
891 template <typename T, typename ParseUserPDLLDeclFnT>
892 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
893     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
894     StringRef anonymousNamePrefix, bool isInline) {
895   SMRange loc = curToken.getLoc();
896   consumeToken();
897   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
898 
899   // Parse the name of the decl.
900   const ast::Name *name = nullptr;
901   if (curToken.isNot(Token::identifier)) {
902     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
903     // in C++, so being unnamed is fine.
904     if (!isInline)
905       return emitError("expected identifier name");
906 
907     // Create a unique anonymous name to use, as the name for this decl is not
908     // important.
909     std::string anonName =
910         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
911                       anonymousDeclNameCounter++)
912             .str();
913     name = &ast::Name::create(ctx, anonName, loc);
914   } else {
915     // If a name was provided, we can use it directly.
916     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
917     consumeToken(Token::identifier);
918   }
919 
920   // Parse the functional signature of the decl.
921   SmallVector<ast::VariableDecl *> arguments, results;
922   ast::DeclScope *argumentScope;
923   ast::Type resultType;
924   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
925                                                    argumentScope, resultType)))
926     return failure();
927 
928   // Check to see which type of constraint this is. If the constraint contains a
929   // compound body, this is a PDLL decl.
930   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
931     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
932                            resultType);
933 
934   // Otherwise, this is a native decl.
935   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
936                                                    results, resultType);
937 }
938 
939 template <typename T>
940 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
941     const ast::Name &name, bool isInline,
942     ArrayRef<ast::VariableDecl *> arguments,
943     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
944   // If followed by a string, the native code body has also been specified.
945   std::string codeStrStorage;
946   Optional<StringRef> optCodeStr;
947   if (curToken.isString()) {
948     codeStrStorage = curToken.getStringValue();
949     optCodeStr = codeStrStorage;
950     consumeToken();
951   } else if (isInline) {
952     return emitError(name.getLoc(),
953                      "external declarations must be declared in global scope");
954   }
955   if (failed(parseToken(Token::semicolon,
956                         "expected `;` after native declaration")))
957     return failure();
958   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
959 }
960 
961 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
962     SmallVectorImpl<ast::VariableDecl *> &arguments,
963     SmallVectorImpl<ast::VariableDecl *> &results,
964     ast::DeclScope *&argumentScope, ast::Type &resultType) {
965   // Parse the argument list of the decl.
966   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
967     return failure();
968 
969   argumentScope = pushDeclScope();
970   if (curToken.isNot(Token::r_paren)) {
971     do {
972       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
973       if (failed(argument))
974         return failure();
975       arguments.emplace_back(*argument);
976     } while (consumeIf(Token::comma));
977   }
978   popDeclScope();
979   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
980     return failure();
981 
982   // Parse the results of the decl.
983   pushDeclScope();
984   if (consumeIf(Token::arrow)) {
985     auto parseResultFn = [&]() -> LogicalResult {
986       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
987       if (failed(result))
988         return failure();
989       results.emplace_back(*result);
990       return success();
991     };
992 
993     // Check for a list of results.
994     if (consumeIf(Token::l_paren)) {
995       do {
996         if (failed(parseResultFn()))
997           return failure();
998       } while (consumeIf(Token::comma));
999       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1000         return failure();
1001 
1002       // Otherwise, there is only one result.
1003     } else if (failed(parseResultFn())) {
1004       return failure();
1005     }
1006   }
1007   popDeclScope();
1008 
1009   // Compute the result type of the decl.
1010   resultType = createUserConstraintRewriteResultType(results);
1011 
1012   // Verify that results are only named if there are more than one.
1013   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1014     return emitError(
1015         results.front()->getLoc(),
1016         "cannot create a single-element tuple with an element label");
1017   }
1018   return success();
1019 }
1020 
1021 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1022     StringRef declType, ast::CompoundStmt *body,
1023     ArrayRef<ast::Stmt *>::iterator bodyIt,
1024     ArrayRef<ast::Stmt *>::iterator bodyE,
1025     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1026   // Handle if a `return` was provided.
1027   if (bodyIt != bodyE) {
1028     // Emit an error if we have trailing statements after the return.
1029     if (std::next(bodyIt) != bodyE) {
1030       return emitError(
1031           (*std::next(bodyIt))->getLoc(),
1032           llvm::formatv("`return` terminated the `{0}` body, but found "
1033                         "trailing statements afterwards",
1034                         declType));
1035     }
1036 
1037     // Otherwise if a return wasn't provided, check that no results are
1038     // expected.
1039   } else if (!results.empty()) {
1040     return emitError(
1041         {body->getLoc().End, body->getLoc().End},
1042         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1043                       declType, resultType));
1044   }
1045   return success();
1046 }
1047 
1048 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1049   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1050     if (isa<ast::OpRewriteStmt>(statement))
1051       return success();
1052     return emitError(
1053         statement->getLoc(),
1054         "expected Pattern lambda body to contain a single operation "
1055         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1056   });
1057 }
1058 
1059 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1060   SMRange loc = curToken.getLoc();
1061   consumeToken(Token::kw_Pattern);
1062   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1063                                               ParserContext::PatternMatch);
1064 
1065   // Check for an optional identifier for the pattern name.
1066   const ast::Name *name = nullptr;
1067   if (curToken.is(Token::identifier)) {
1068     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1069     consumeToken(Token::identifier);
1070   }
1071 
1072   // Parse any pattern metadata.
1073   ParsedPatternMetadata metadata;
1074   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1075     return failure();
1076 
1077   // Parse the pattern body.
1078   ast::CompoundStmt *body;
1079 
1080   // Handle a lambda body.
1081   if (curToken.is(Token::equal_arrow)) {
1082     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1083     if (failed(bodyResult))
1084       return failure();
1085     body = *bodyResult;
1086   } else {
1087     if (curToken.isNot(Token::l_brace))
1088       return emitError("expected `{` or `=>` to start pattern body");
1089     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1090     if (failed(bodyResult))
1091       return failure();
1092     body = *bodyResult;
1093 
1094     // Verify the body of the pattern.
1095     auto bodyIt = body->begin(), bodyE = body->end();
1096     for (; bodyIt != bodyE; ++bodyIt) {
1097       if (isa<ast::ReturnStmt>(*bodyIt)) {
1098         return emitError((*bodyIt)->getLoc(),
1099                          "`return` statements are only permitted within a "
1100                          "`Constraint` or `Rewrite` body");
1101       }
1102       // Break when we've found the rewrite statement.
1103       if (isa<ast::OpRewriteStmt>(*bodyIt))
1104         break;
1105     }
1106     if (bodyIt == bodyE) {
1107       return emitError(loc,
1108                        "expected Pattern body to terminate with an operation "
1109                        "rewrite statement, such as `erase`");
1110     }
1111     if (std::next(bodyIt) != bodyE) {
1112       return emitError((*std::next(bodyIt))->getLoc(),
1113                        "Pattern body was terminated by an operation "
1114                        "rewrite statement, but found trailing statements");
1115     }
1116   }
1117 
1118   return createPatternDecl(loc, name, metadata, body);
1119 }
1120 
1121 LogicalResult
1122 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1123   Optional<SMRange> benefitLoc;
1124   Optional<SMRange> hasBoundedRecursionLoc;
1125 
1126   do {
1127     if (curToken.isNot(Token::identifier))
1128       return emitError("expected pattern metadata identifier");
1129     StringRef metadataStr = curToken.getSpelling();
1130     SMRange metadataLoc = curToken.getLoc();
1131     consumeToken(Token::identifier);
1132 
1133     // Parse the benefit metadata: benefit(<integer-value>)
1134     if (metadataStr == "benefit") {
1135       if (benefitLoc) {
1136         return emitErrorAndNote(metadataLoc,
1137                                 "pattern benefit has already been specified",
1138                                 *benefitLoc, "see previous definition here");
1139       }
1140       if (failed(parseToken(Token::l_paren,
1141                             "expected `(` before pattern benefit")))
1142         return failure();
1143 
1144       uint16_t benefitValue = 0;
1145       if (curToken.isNot(Token::integer))
1146         return emitError("expected integral pattern benefit");
1147       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1148         return emitError(
1149             "expected pattern benefit to fit within a 16-bit integer");
1150       consumeToken(Token::integer);
1151 
1152       metadata.benefit = benefitValue;
1153       benefitLoc = metadataLoc;
1154 
1155       if (failed(
1156               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1157         return failure();
1158       continue;
1159     }
1160 
1161     // Parse the bounded recursion metadata: recursion
1162     if (metadataStr == "recursion") {
1163       if (hasBoundedRecursionLoc) {
1164         return emitErrorAndNote(
1165             metadataLoc,
1166             "pattern recursion metadata has already been specified",
1167             *hasBoundedRecursionLoc, "see previous definition here");
1168       }
1169       metadata.hasBoundedRecursion = true;
1170       hasBoundedRecursionLoc = metadataLoc;
1171       continue;
1172     }
1173 
1174     return emitError(metadataLoc, "unknown pattern metadata");
1175   } while (consumeIf(Token::comma));
1176 
1177   return success();
1178 }
1179 
1180 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1181   consumeToken(Token::less);
1182 
1183   FailureOr<ast::Expr *> typeExpr = parseExpr();
1184   if (failed(typeExpr) ||
1185       failed(parseToken(Token::greater,
1186                         "expected `>` after variable type constraint")))
1187     return failure();
1188   return typeExpr;
1189 }
1190 
1191 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1192   assert(curDeclScope && "defining decl outside of a decl scope");
1193   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1194     return emitErrorAndNote(
1195         name.getLoc(), "`" + name.getName() + "` has already been defined",
1196         lastDecl->getName()->getLoc(), "see previous definition here");
1197   }
1198   return success();
1199 }
1200 
1201 FailureOr<ast::VariableDecl *>
1202 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1203                            ast::Expr *initExpr,
1204                            ArrayRef<ast::ConstraintRef> constraints) {
1205   assert(curDeclScope && "defining variable outside of decl scope");
1206   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1207 
1208   // If the name of the variable indicates a special variable, we don't add it
1209   // to the scope. This variable is local to the definition point.
1210   if (name.empty() || name == "_") {
1211     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1212                                      constraints);
1213   }
1214   if (failed(checkDefineNamedDecl(nameDecl)))
1215     return failure();
1216 
1217   auto *varDecl =
1218       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1219   curDeclScope->add(varDecl);
1220   return varDecl;
1221 }
1222 
1223 FailureOr<ast::VariableDecl *>
1224 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1225                            ArrayRef<ast::ConstraintRef> constraints) {
1226   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1227                             constraints);
1228 }
1229 
1230 LogicalResult Parser::parseVariableDeclConstraintList(
1231     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1232   Optional<SMRange> typeConstraint;
1233   auto parseSingleConstraint = [&] {
1234     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1235         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1236     if (failed(constraint))
1237       return failure();
1238     constraints.push_back(*constraint);
1239     return success();
1240   };
1241 
1242   // Check to see if this is a single constraint, or a list.
1243   if (!consumeIf(Token::l_square))
1244     return parseSingleConstraint();
1245 
1246   do {
1247     if (failed(parseSingleConstraint()))
1248       return failure();
1249   } while (consumeIf(Token::comma));
1250   return parseToken(Token::r_square, "expected `]` after constraint list");
1251 }
1252 
1253 FailureOr<ast::ConstraintRef>
1254 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1255                         ArrayRef<ast::ConstraintRef> existingConstraints,
1256                         bool allowInlineTypeConstraints) {
1257   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1258     if (!allowInlineTypeConstraints) {
1259       return emitError(
1260           curToken.getLoc(),
1261           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1262           "permitted on arguments or results");
1263     }
1264     if (typeConstraint)
1265       return emitErrorAndNote(
1266           curToken.getLoc(),
1267           "the type of this variable has already been constrained",
1268           *typeConstraint, "see previous constraint location here");
1269     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1270     if (failed(constraintExpr))
1271       return failure();
1272     typeExpr = *constraintExpr;
1273     typeConstraint = typeExpr->getLoc();
1274     return success();
1275   };
1276 
1277   SMRange loc = curToken.getLoc();
1278   switch (curToken.getKind()) {
1279   case Token::kw_Attr: {
1280     consumeToken(Token::kw_Attr);
1281 
1282     // Check for a type constraint.
1283     ast::Expr *typeExpr = nullptr;
1284     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1285       return failure();
1286     return ast::ConstraintRef(
1287         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1288   }
1289   case Token::kw_Op: {
1290     consumeToken(Token::kw_Op);
1291 
1292     // Parse an optional operation name. If the name isn't provided, this refers
1293     // to "any" operation.
1294     FailureOr<ast::OpNameDecl *> opName =
1295         parseWrappedOperationName(/*allowEmptyName=*/true);
1296     if (failed(opName))
1297       return failure();
1298 
1299     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1300                               loc);
1301   }
1302   case Token::kw_Type:
1303     consumeToken(Token::kw_Type);
1304     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1305   case Token::kw_TypeRange:
1306     consumeToken(Token::kw_TypeRange);
1307     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1308                               loc);
1309   case Token::kw_Value: {
1310     consumeToken(Token::kw_Value);
1311 
1312     // Check for a type constraint.
1313     ast::Expr *typeExpr = nullptr;
1314     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1315       return failure();
1316 
1317     return ast::ConstraintRef(
1318         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1319   }
1320   case Token::kw_ValueRange: {
1321     consumeToken(Token::kw_ValueRange);
1322 
1323     // Check for a type constraint.
1324     ast::Expr *typeExpr = nullptr;
1325     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1326       return failure();
1327 
1328     return ast::ConstraintRef(
1329         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1330   }
1331 
1332   case Token::kw_Constraint: {
1333     // Handle an inline constraint.
1334     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1335     if (failed(decl))
1336       return failure();
1337     return ast::ConstraintRef(*decl, loc);
1338   }
1339   case Token::identifier: {
1340     StringRef constraintName = curToken.getSpelling();
1341     consumeToken(Token::identifier);
1342 
1343     // Lookup the referenced constraint.
1344     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1345     if (!cstDecl) {
1346       return emitError(loc, "unknown reference to constraint `" +
1347                                 constraintName + "`");
1348     }
1349 
1350     // Handle a reference to a proper constraint.
1351     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1352       return ast::ConstraintRef(cst, loc);
1353 
1354     return emitErrorAndNote(
1355         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1356         "see the definition of `" + constraintName + "` here");
1357   }
1358   default:
1359     break;
1360   }
1361   return emitError(loc, "expected identifier constraint");
1362 }
1363 
1364 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1365   Optional<SMRange> typeConstraint;
1366   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1367                          /*allowInlineTypeConstraints=*/false);
1368 }
1369 
1370 //===----------------------------------------------------------------------===//
1371 // Exprs
1372 
1373 FailureOr<ast::Expr *> Parser::parseExpr() {
1374   if (curToken.is(Token::underscore))
1375     return parseUnderscoreExpr();
1376 
1377   // Parse the LHS expression.
1378   FailureOr<ast::Expr *> lhsExpr;
1379   switch (curToken.getKind()) {
1380   case Token::kw_attr:
1381     lhsExpr = parseAttributeExpr();
1382     break;
1383   case Token::kw_Constraint:
1384     lhsExpr = parseInlineConstraintLambdaExpr();
1385     break;
1386   case Token::identifier:
1387     lhsExpr = parseIdentifierExpr();
1388     break;
1389   case Token::kw_op:
1390     lhsExpr = parseOperationExpr();
1391     break;
1392   case Token::kw_Rewrite:
1393     lhsExpr = parseInlineRewriteLambdaExpr();
1394     break;
1395   case Token::kw_type:
1396     lhsExpr = parseTypeExpr();
1397     break;
1398   case Token::l_paren:
1399     lhsExpr = parseTupleExpr();
1400     break;
1401   default:
1402     return emitError("expected expression");
1403   }
1404   if (failed(lhsExpr))
1405     return failure();
1406 
1407   // Check for an operator expression.
1408   while (true) {
1409     switch (curToken.getKind()) {
1410     case Token::dot:
1411       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1412       break;
1413     case Token::l_paren:
1414       lhsExpr = parseCallExpr(*lhsExpr);
1415       break;
1416     default:
1417       return lhsExpr;
1418     }
1419     if (failed(lhsExpr))
1420       return failure();
1421   }
1422 }
1423 
1424 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1425   SMRange loc = curToken.getLoc();
1426   consumeToken(Token::kw_attr);
1427 
1428   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1429   // identifier.
1430   if (!consumeIf(Token::less)) {
1431     resetToken(loc);
1432     return parseIdentifierExpr();
1433   }
1434 
1435   if (!curToken.isString())
1436     return emitError("expected string literal containing MLIR attribute");
1437   std::string attrExpr = curToken.getStringValue();
1438   consumeToken();
1439 
1440   if (failed(
1441           parseToken(Token::greater, "expected `>` after attribute literal")))
1442     return failure();
1443   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1444 }
1445 
1446 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1447   SMRange loc = curToken.getLoc();
1448   consumeToken(Token::l_paren);
1449 
1450   // Parse the arguments of the call.
1451   SmallVector<ast::Expr *> arguments;
1452   if (curToken.isNot(Token::r_paren)) {
1453     do {
1454       FailureOr<ast::Expr *> argument = parseExpr();
1455       if (failed(argument))
1456         return failure();
1457       arguments.push_back(*argument);
1458     } while (consumeIf(Token::comma));
1459   }
1460   loc.End = curToken.getEndLoc();
1461   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1462     return failure();
1463 
1464   return createCallExpr(loc, parentExpr, arguments);
1465 }
1466 
1467 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1468   ast::Decl *decl = curDeclScope->lookup(name);
1469   if (!decl)
1470     return emitError(loc, "undefined reference to `" + name + "`");
1471 
1472   return createDeclRefExpr(loc, decl);
1473 }
1474 
1475 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1476   StringRef name = curToken.getSpelling();
1477   SMRange nameLoc = curToken.getLoc();
1478   consumeToken();
1479 
1480   // Check to see if this is a decl ref expression that defines a variable
1481   // inline.
1482   if (consumeIf(Token::colon)) {
1483     SmallVector<ast::ConstraintRef> constraints;
1484     if (failed(parseVariableDeclConstraintList(constraints)))
1485       return failure();
1486     ast::Type type;
1487     if (failed(validateVariableConstraints(constraints, type)))
1488       return failure();
1489     return createInlineVariableExpr(type, name, nameLoc, constraints);
1490   }
1491 
1492   return parseDeclRefExpr(name, nameLoc);
1493 }
1494 
1495 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1496   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1497   if (failed(decl))
1498     return failure();
1499 
1500   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1501                                   ast::ConstraintType::get(ctx));
1502 }
1503 
1504 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1505   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1506   if (failed(decl))
1507     return failure();
1508 
1509   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1510                                   ast::RewriteType::get(ctx));
1511 }
1512 
1513 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1514   SMRange loc = curToken.getLoc();
1515   consumeToken(Token::dot);
1516 
1517   // Parse the member name.
1518   Token memberNameTok = curToken;
1519   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1520       !memberNameTok.isKeyword())
1521     return emitError(loc, "expected identifier or numeric member name");
1522   StringRef memberName = memberNameTok.getSpelling();
1523   consumeToken();
1524 
1525   return createMemberAccessExpr(parentExpr, memberName, loc);
1526 }
1527 
1528 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1529   SMRange loc = curToken.getLoc();
1530 
1531   // Handle the case of an no operation name.
1532   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1533     if (allowEmptyName)
1534       return ast::OpNameDecl::create(ctx, SMRange());
1535     return emitError("expected dialect namespace");
1536   }
1537   StringRef name = curToken.getSpelling();
1538   consumeToken();
1539 
1540   // Otherwise, this is a literal operation name.
1541   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1542     return failure();
1543 
1544   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1545     return emitError("expected operation name after dialect namespace");
1546 
1547   name = StringRef(name.data(), name.size() + 1);
1548   do {
1549     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1550     loc.End = curToken.getEndLoc();
1551     consumeToken();
1552   } while (curToken.isAny(Token::identifier, Token::dot) ||
1553            curToken.isKeyword());
1554   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1555 }
1556 
1557 FailureOr<ast::OpNameDecl *>
1558 Parser::parseWrappedOperationName(bool allowEmptyName) {
1559   if (!consumeIf(Token::less))
1560     return ast::OpNameDecl::create(ctx, SMRange());
1561 
1562   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1563   if (failed(opNameDecl))
1564     return failure();
1565 
1566   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1567     return failure();
1568   return opNameDecl;
1569 }
1570 
1571 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1572   SMRange loc = curToken.getLoc();
1573   consumeToken(Token::kw_op);
1574 
1575   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1576   // identifier.
1577   if (curToken.isNot(Token::less)) {
1578     resetToken(loc);
1579     return parseIdentifierExpr();
1580   }
1581 
1582   // Parse the operation name. The name may be elided, in which case the
1583   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1584   // `Operation*`). Operation names within a rewrite context must be named.
1585   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1586   FailureOr<ast::OpNameDecl *> opNameDecl =
1587       parseWrappedOperationName(allowEmptyName);
1588   if (failed(opNameDecl))
1589     return failure();
1590 
1591   // Check for the optional list of operands.
1592   SmallVector<ast::Expr *> operands;
1593   if (consumeIf(Token::l_paren)) {
1594     do {
1595       FailureOr<ast::Expr *> operand = parseExpr();
1596       if (failed(operand))
1597         return failure();
1598       operands.push_back(*operand);
1599     } while (consumeIf(Token::comma));
1600 
1601     if (failed(parseToken(Token::r_paren,
1602                           "expected `)` after operation operand list")))
1603       return failure();
1604   }
1605 
1606   // Check for the optional list of attributes.
1607   SmallVector<ast::NamedAttributeDecl *> attributes;
1608   if (consumeIf(Token::l_brace)) {
1609     do {
1610       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1611       if (failed(decl))
1612         return failure();
1613       attributes.emplace_back(*decl);
1614     } while (consumeIf(Token::comma));
1615 
1616     if (failed(parseToken(Token::r_brace,
1617                           "expected `}` after operation attribute list")))
1618       return failure();
1619   }
1620 
1621   // Check for the optional list of result types.
1622   SmallVector<ast::Expr *> resultTypes;
1623   if (consumeIf(Token::arrow)) {
1624     if (failed(parseToken(Token::l_paren,
1625                           "expected `(` before operation result type list")))
1626       return failure();
1627 
1628     do {
1629       FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1630       if (failed(resultTypeExpr))
1631         return failure();
1632       resultTypes.push_back(*resultTypeExpr);
1633     } while (consumeIf(Token::comma));
1634 
1635     if (failed(parseToken(Token::r_paren,
1636                           "expected `)` after operation result type list")))
1637       return failure();
1638   }
1639 
1640   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1641                              resultTypes);
1642 }
1643 
1644 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
1645   SMRange loc = curToken.getLoc();
1646   consumeToken(Token::l_paren);
1647 
1648   DenseMap<StringRef, SMRange> usedNames;
1649   SmallVector<StringRef> elementNames;
1650   SmallVector<ast::Expr *> elements;
1651   if (curToken.isNot(Token::r_paren)) {
1652     do {
1653       // Check for the optional element name assignment before the value.
1654       StringRef elementName;
1655       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1656         Token elementNameTok = curToken;
1657         consumeToken();
1658 
1659         // The element name is only present if followed by an `=`.
1660         if (consumeIf(Token::equal)) {
1661           elementName = elementNameTok.getSpelling();
1662 
1663           // Check to see if this name is already used.
1664           auto elementNameIt =
1665               usedNames.try_emplace(elementName, elementNameTok.getLoc());
1666           if (!elementNameIt.second) {
1667             return emitErrorAndNote(
1668                 elementNameTok.getLoc(),
1669                 llvm::formatv("duplicate tuple element label `{0}`",
1670                               elementName),
1671                 elementNameIt.first->getSecond(),
1672                 "see previous label use here");
1673           }
1674         } else {
1675           // Otherwise, we treat this as part of an expression so reset the
1676           // lexer.
1677           resetToken(elementNameTok.getLoc());
1678         }
1679       }
1680       elementNames.push_back(elementName);
1681 
1682       // Parse the tuple element value.
1683       FailureOr<ast::Expr *> element = parseExpr();
1684       if (failed(element))
1685         return failure();
1686       elements.push_back(*element);
1687     } while (consumeIf(Token::comma));
1688   }
1689   loc.End = curToken.getEndLoc();
1690   if (failed(
1691           parseToken(Token::r_paren, "expected `)` after tuple element list")))
1692     return failure();
1693   return createTupleExpr(loc, elements, elementNames);
1694 }
1695 
1696 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1697   SMRange loc = curToken.getLoc();
1698   consumeToken(Token::kw_type);
1699 
1700   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1701   // identifier.
1702   if (!consumeIf(Token::less)) {
1703     resetToken(loc);
1704     return parseIdentifierExpr();
1705   }
1706 
1707   if (!curToken.isString())
1708     return emitError("expected string literal containing MLIR type");
1709   std::string attrExpr = curToken.getStringValue();
1710   consumeToken();
1711 
1712   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1713     return failure();
1714   return ast::TypeExpr::create(ctx, loc, attrExpr);
1715 }
1716 
1717 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1718   StringRef name = curToken.getSpelling();
1719   SMRange nameLoc = curToken.getLoc();
1720   consumeToken(Token::underscore);
1721 
1722   // Underscore expressions require a constraint list.
1723   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1724     return failure();
1725 
1726   // Parse the constraints for the expression.
1727   SmallVector<ast::ConstraintRef> constraints;
1728   if (failed(parseVariableDeclConstraintList(constraints)))
1729     return failure();
1730 
1731   ast::Type type;
1732   if (failed(validateVariableConstraints(constraints, type)))
1733     return failure();
1734   return createInlineVariableExpr(type, name, nameLoc, constraints);
1735 }
1736 
1737 //===----------------------------------------------------------------------===//
1738 // Stmts
1739 
1740 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1741   FailureOr<ast::Stmt *> stmt;
1742   switch (curToken.getKind()) {
1743   case Token::kw_erase:
1744     stmt = parseEraseStmt();
1745     break;
1746   case Token::kw_let:
1747     stmt = parseLetStmt();
1748     break;
1749   case Token::kw_replace:
1750     stmt = parseReplaceStmt();
1751     break;
1752   case Token::kw_return:
1753     stmt = parseReturnStmt();
1754     break;
1755   case Token::kw_rewrite:
1756     stmt = parseRewriteStmt();
1757     break;
1758   default:
1759     stmt = parseExpr();
1760     break;
1761   }
1762   if (failed(stmt) ||
1763       (expectTerminalSemicolon &&
1764        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1765     return failure();
1766   return stmt;
1767 }
1768 
1769 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1770   SMLoc startLoc = curToken.getStartLoc();
1771   consumeToken(Token::l_brace);
1772 
1773   // Push a new block scope and parse any nested statements.
1774   pushDeclScope();
1775   SmallVector<ast::Stmt *> statements;
1776   while (curToken.isNot(Token::r_brace)) {
1777     FailureOr<ast::Stmt *> statement = parseStmt();
1778     if (failed(statement))
1779       return popDeclScope(), failure();
1780     statements.push_back(*statement);
1781   }
1782   popDeclScope();
1783 
1784   // Consume the end brace.
1785   SMRange location(startLoc, curToken.getEndLoc());
1786   consumeToken(Token::r_brace);
1787 
1788   return ast::CompoundStmt::create(ctx, location, statements);
1789 }
1790 
1791 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1792   if (parserContext == ParserContext::Constraint)
1793     return emitError("`erase` cannot be used within a Constraint");
1794   SMRange loc = curToken.getLoc();
1795   consumeToken(Token::kw_erase);
1796 
1797   // Parse the root operation expression.
1798   FailureOr<ast::Expr *> rootOp = parseExpr();
1799   if (failed(rootOp))
1800     return failure();
1801 
1802   return createEraseStmt(loc, *rootOp);
1803 }
1804 
1805 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1806   SMRange loc = curToken.getLoc();
1807   consumeToken(Token::kw_let);
1808 
1809   // Parse the name of the new variable.
1810   SMRange varLoc = curToken.getLoc();
1811   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1812     // `_` is a reserved variable name.
1813     if (curToken.is(Token::underscore)) {
1814       return emitError(varLoc,
1815                        "`_` may only be used to define \"inline\" variables");
1816     }
1817     return emitError(varLoc,
1818                      "expected identifier after `let` to name a new variable");
1819   }
1820   StringRef varName = curToken.getSpelling();
1821   consumeToken();
1822 
1823   // Parse the optional set of constraints.
1824   SmallVector<ast::ConstraintRef> constraints;
1825   if (consumeIf(Token::colon) &&
1826       failed(parseVariableDeclConstraintList(constraints)))
1827     return failure();
1828 
1829   // Parse the optional initializer expression.
1830   ast::Expr *initializer = nullptr;
1831   if (consumeIf(Token::equal)) {
1832     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1833     if (failed(initOrFailure))
1834       return failure();
1835     initializer = *initOrFailure;
1836 
1837     // Check that the constraints are compatible with having an initializer,
1838     // e.g. type constraints cannot be used with initializers.
1839     for (ast::ConstraintRef constraint : constraints) {
1840       LogicalResult result =
1841           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1842               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1843                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1844                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1845                   return this->emitError(
1846                       constraint.referenceLoc,
1847                       "type constraints are not permitted on variables with "
1848                       "initializers");
1849                 }
1850                 return success();
1851               })
1852               .Default(success());
1853       if (failed(result))
1854         return failure();
1855     }
1856   }
1857 
1858   FailureOr<ast::VariableDecl *> varDecl =
1859       createVariableDecl(varName, varLoc, initializer, constraints);
1860   if (failed(varDecl))
1861     return failure();
1862   return ast::LetStmt::create(ctx, loc, *varDecl);
1863 }
1864 
1865 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1866   if (parserContext == ParserContext::Constraint)
1867     return emitError("`replace` cannot be used within a Constraint");
1868   SMRange loc = curToken.getLoc();
1869   consumeToken(Token::kw_replace);
1870 
1871   // Parse the root operation expression.
1872   FailureOr<ast::Expr *> rootOp = parseExpr();
1873   if (failed(rootOp))
1874     return failure();
1875 
1876   if (failed(
1877           parseToken(Token::kw_with, "expected `with` after root operation")))
1878     return failure();
1879 
1880   // The replacement portion of this statement is within a rewrite context.
1881   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1882                                               ParserContext::Rewrite);
1883 
1884   // Parse the replacement values.
1885   SmallVector<ast::Expr *> replValues;
1886   if (consumeIf(Token::l_paren)) {
1887     if (consumeIf(Token::r_paren)) {
1888       return emitError(
1889           loc, "expected at least one replacement value, consider using "
1890                "`erase` if no replacement values are desired");
1891     }
1892 
1893     do {
1894       FailureOr<ast::Expr *> replExpr = parseExpr();
1895       if (failed(replExpr))
1896         return failure();
1897       replValues.emplace_back(*replExpr);
1898     } while (consumeIf(Token::comma));
1899 
1900     if (failed(parseToken(Token::r_paren,
1901                           "expected `)` after replacement values")))
1902       return failure();
1903   } else {
1904     FailureOr<ast::Expr *> replExpr = parseExpr();
1905     if (failed(replExpr))
1906       return failure();
1907     replValues.emplace_back(*replExpr);
1908   }
1909 
1910   return createReplaceStmt(loc, *rootOp, replValues);
1911 }
1912 
1913 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
1914   SMRange loc = curToken.getLoc();
1915   consumeToken(Token::kw_return);
1916 
1917   // Parse the result value.
1918   FailureOr<ast::Expr *> resultExpr = parseExpr();
1919   if (failed(resultExpr))
1920     return failure();
1921 
1922   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
1923 }
1924 
1925 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1926   if (parserContext == ParserContext::Constraint)
1927     return emitError("`rewrite` cannot be used within a Constraint");
1928   SMRange loc = curToken.getLoc();
1929   consumeToken(Token::kw_rewrite);
1930 
1931   // Parse the root operation.
1932   FailureOr<ast::Expr *> rootOp = parseExpr();
1933   if (failed(rootOp))
1934     return failure();
1935 
1936   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1937     return failure();
1938 
1939   if (curToken.isNot(Token::l_brace))
1940     return emitError("expected `{` to start rewrite body");
1941 
1942   // The rewrite body of this statement is within a rewrite context.
1943   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1944                                               ParserContext::Rewrite);
1945 
1946   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1947   if (failed(rewriteBody))
1948     return failure();
1949 
1950   // Verify the rewrite body.
1951   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
1952     if (isa<ast::ReturnStmt>(stmt)) {
1953       return emitError(stmt->getLoc(),
1954                        "`return` statements are only permitted within a "
1955                        "`Constraint` or `Rewrite` body");
1956     }
1957   }
1958 
1959   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1960 }
1961 
1962 //===----------------------------------------------------------------------===//
1963 // Creation+Analysis
1964 //===----------------------------------------------------------------------===//
1965 
1966 //===----------------------------------------------------------------------===//
1967 // Decls
1968 
1969 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
1970   // Unwrap reference expressions.
1971   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
1972     node = init->getDecl();
1973   return dyn_cast<ast::CallableDecl>(node);
1974 }
1975 
1976 FailureOr<ast::PatternDecl *>
1977 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
1978                           const ParsedPatternMetadata &metadata,
1979                           ast::CompoundStmt *body) {
1980   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
1981                                   metadata.hasBoundedRecursion, body);
1982 }
1983 
1984 ast::Type Parser::createUserConstraintRewriteResultType(
1985     ArrayRef<ast::VariableDecl *> results) {
1986   // Single result decls use the type of the single result.
1987   if (results.size() == 1)
1988     return results[0]->getType();
1989 
1990   // Multiple results use a tuple type, with the types and names grabbed from
1991   // the result variable decls.
1992   auto resultTypes = llvm::map_range(
1993       results, [&](const auto *result) { return result->getType(); });
1994   auto resultNames = llvm::map_range(
1995       results, [&](const auto *result) { return result->getName().getName(); });
1996   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
1997                              llvm::to_vector(resultNames));
1998 }
1999 
2000 template <typename T>
2001 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2002     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2003     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2004     ast::CompoundStmt *body) {
2005   if (!body->getChildren().empty()) {
2006     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2007       ast::Expr *resultExpr = retStmt->getResultExpr();
2008 
2009       // Process the result of the decl. If no explicit signature results
2010       // were provided, check for return type inference. Otherwise, check that
2011       // the return expression can be converted to the expected type.
2012       if (results.empty())
2013         resultType = resultExpr->getType();
2014       else if (failed(convertExpressionTo(resultExpr, resultType)))
2015         return failure();
2016       else
2017         retStmt->setResultExpr(resultExpr);
2018     }
2019   }
2020   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2021 }
2022 
2023 FailureOr<ast::VariableDecl *>
2024 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2025                            ArrayRef<ast::ConstraintRef> constraints) {
2026   // The type of the variable, which is expected to be inferred by either a
2027   // constraint or an initializer expression.
2028   ast::Type type;
2029   if (failed(validateVariableConstraints(constraints, type)))
2030     return failure();
2031 
2032   if (initializer) {
2033     // Update the variable type based on the initializer, or try to convert the
2034     // initializer to the existing type.
2035     if (!type)
2036       type = initializer->getType();
2037     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2038       type = mergedType;
2039     else if (failed(convertExpressionTo(initializer, type)))
2040       return failure();
2041 
2042     // Otherwise, if there is no initializer check that the type has already
2043     // been resolved from the constraint list.
2044   } else if (!type) {
2045     return emitErrorAndNote(
2046         loc, "unable to infer type for variable `" + name + "`", loc,
2047         "the type of a variable must be inferable from the constraint "
2048         "list or the initializer");
2049   }
2050 
2051   // Constraint types cannot be used when defining variables.
2052   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2053     return emitError(
2054         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2055   }
2056 
2057   // Try to define a variable with the given name.
2058   FailureOr<ast::VariableDecl *> varDecl =
2059       defineVariableDecl(name, loc, type, initializer, constraints);
2060   if (failed(varDecl))
2061     return failure();
2062 
2063   return *varDecl;
2064 }
2065 
2066 FailureOr<ast::VariableDecl *>
2067 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2068                                       const ast::ConstraintRef &constraint) {
2069   // Constraint arguments may apply more complex constraints via the arguments.
2070   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2071   ast::Type argType;
2072   if (failed(validateVariableConstraint(constraint, argType,
2073                                         allowNonCoreConstraints)))
2074     return failure();
2075   return defineVariableDecl(name, loc, argType, constraint);
2076 }
2077 
2078 LogicalResult
2079 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2080                                     ast::Type &inferredType) {
2081   for (const ast::ConstraintRef &ref : constraints)
2082     if (failed(validateVariableConstraint(ref, inferredType)))
2083       return failure();
2084   return success();
2085 }
2086 
2087 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2088                                                  ast::Type &inferredType,
2089                                                  bool allowNonCoreConstraints) {
2090   ast::Type constraintType;
2091   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2092     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2093       if (failed(validateTypeConstraintExpr(typeExpr)))
2094         return failure();
2095     }
2096     constraintType = ast::AttributeType::get(ctx);
2097   } else if (const auto *cst =
2098                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2099     constraintType = ast::OperationType::get(ctx, cst->getName());
2100   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2101     constraintType = typeTy;
2102   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2103     constraintType = typeRangeTy;
2104   } else if (const auto *cst =
2105                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2106     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2107       if (failed(validateTypeConstraintExpr(typeExpr)))
2108         return failure();
2109     }
2110     constraintType = valueTy;
2111   } else if (const auto *cst =
2112                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2113     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2114       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2115         return failure();
2116     }
2117     constraintType = valueRangeTy;
2118   } else if (const auto *cst =
2119                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2120     if (!allowNonCoreConstraints) {
2121       return emitError(ref.referenceLoc,
2122                        "`Rewrite` arguments and results are only permitted to "
2123                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2124                        "`TypeRange`, `Value`, `ValueRange`");
2125     }
2126 
2127     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2128     if (inputs.size() != 1) {
2129       return emitErrorAndNote(ref.referenceLoc,
2130                               "`Constraint`s applied via a variable constraint "
2131                               "list must take a single input, but got " +
2132                                   Twine(inputs.size()),
2133                               cst->getLoc(),
2134                               "see definition of constraint here");
2135     }
2136     constraintType = inputs.front()->getType();
2137   } else {
2138     llvm_unreachable("unknown constraint type");
2139   }
2140 
2141   // Check that the constraint type is compatible with the current inferred
2142   // type.
2143   if (!inferredType) {
2144     inferredType = constraintType;
2145   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2146     inferredType = mergedTy;
2147   } else {
2148     return emitError(ref.referenceLoc,
2149                      llvm::formatv("constraint type `{0}` is incompatible "
2150                                    "with the previously inferred type `{1}`",
2151                                    constraintType, inferredType));
2152   }
2153   return success();
2154 }
2155 
2156 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2157   ast::Type typeExprType = typeExpr->getType();
2158   if (typeExprType != typeTy) {
2159     return emitError(typeExpr->getLoc(),
2160                      "expected expression of `Type` in type constraint");
2161   }
2162   return success();
2163 }
2164 
2165 LogicalResult
2166 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2167   ast::Type typeExprType = typeExpr->getType();
2168   if (typeExprType != typeRangeTy) {
2169     return emitError(typeExpr->getLoc(),
2170                      "expected expression of `TypeRange` in type constraint");
2171   }
2172   return success();
2173 }
2174 
2175 //===----------------------------------------------------------------------===//
2176 // Exprs
2177 
2178 FailureOr<ast::CallExpr *>
2179 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2180                        MutableArrayRef<ast::Expr *> arguments) {
2181   ast::Type parentType = parentExpr->getType();
2182 
2183   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2184   if (!callableDecl) {
2185     return emitError(loc,
2186                      llvm::formatv("expected a reference to a callable "
2187                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2188                                    parentType));
2189   }
2190   if (parserContext == ParserContext::Rewrite) {
2191     if (isa<ast::UserConstraintDecl>(callableDecl))
2192       return emitError(
2193           loc, "unable to invoke `Constraint` within a rewrite section");
2194   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2195     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2196   }
2197 
2198   // Verify the arguments of the call.
2199   /// Handle size mismatch.
2200   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2201   if (callArgs.size() != arguments.size()) {
2202     return emitErrorAndNote(
2203         loc,
2204         llvm::formatv("invalid number of arguments for {0} call; expected "
2205                       "{1}, but got {2}",
2206                       callableDecl->getCallableType(), callArgs.size(),
2207                       arguments.size()),
2208         callableDecl->getLoc(),
2209         llvm::formatv("see the definition of {0} here",
2210                       callableDecl->getName()->getName()));
2211   }
2212 
2213   /// Handle argument type mismatch.
2214   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2215     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2216                                   callableDecl->getName()->getName()),
2217                     callableDecl->getLoc());
2218   };
2219   for (auto it : llvm::zip(callArgs, arguments)) {
2220     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2221                                    attachDiagFn)))
2222       return failure();
2223   }
2224 
2225   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2226                                callableDecl->getResultType());
2227 }
2228 
2229 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2230                                                         ast::Decl *decl) {
2231   // Check the type of decl being referenced.
2232   ast::Type declType;
2233   if (isa<ast::ConstraintDecl>(decl))
2234     declType = ast::ConstraintType::get(ctx);
2235   else if (isa<ast::UserRewriteDecl>(decl))
2236     declType = ast::RewriteType::get(ctx);
2237   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2238     declType = varDecl->getType();
2239   else
2240     return emitError(loc, "invalid reference to `" +
2241                               decl->getName()->getName() + "`");
2242 
2243   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2244 }
2245 
2246 FailureOr<ast::DeclRefExpr *>
2247 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2248                                  ArrayRef<ast::ConstraintRef> constraints) {
2249   FailureOr<ast::VariableDecl *> decl =
2250       defineVariableDecl(name, loc, type, constraints);
2251   if (failed(decl))
2252     return failure();
2253   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2254 }
2255 
2256 FailureOr<ast::MemberAccessExpr *>
2257 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2258                                SMRange loc) {
2259   // Validate the member name for the given parent expression.
2260   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2261   if (failed(memberType))
2262     return failure();
2263 
2264   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2265 }
2266 
2267 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2268                                                   StringRef name, SMRange loc) {
2269   ast::Type parentType = parentExpr->getType();
2270   if (parentType.isa<ast::OperationType>()) {
2271     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2272       return valueRangeTy;
2273   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2274     // Handle indexed results.
2275     unsigned index = 0;
2276     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2277         index < tupleType.size()) {
2278       return tupleType.getElementTypes()[index];
2279     }
2280 
2281     // Handle named results.
2282     auto elementNames = tupleType.getElementNames();
2283     const auto *it = llvm::find(elementNames, name);
2284     if (it != elementNames.end())
2285       return tupleType.getElementTypes()[it - elementNames.begin()];
2286   }
2287   return emitError(
2288       loc,
2289       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2290                     name, parentType));
2291 }
2292 
2293 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2294     SMRange loc, const ast::OpNameDecl *name,
2295     MutableArrayRef<ast::Expr *> operands,
2296     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2297     MutableArrayRef<ast::Expr *> results) {
2298   Optional<StringRef> opNameRef = name->getName();
2299 
2300   // Verify the inputs operands.
2301   if (failed(validateOperationOperands(loc, opNameRef, operands)))
2302     return failure();
2303 
2304   // Verify the attribute list.
2305   for (ast::NamedAttributeDecl *attr : attributes) {
2306     // Check for an attribute type, or a type awaiting resolution.
2307     ast::Type attrType = attr->getValue()->getType();
2308     if (!attrType.isa<ast::AttributeType>()) {
2309       return emitError(
2310           attr->getValue()->getLoc(),
2311           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2312     }
2313   }
2314 
2315   // Verify the result types.
2316   if (failed(validateOperationResults(loc, opNameRef, results)))
2317     return failure();
2318 
2319   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2320                                     attributes);
2321 }
2322 
2323 LogicalResult
2324 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2325                                   MutableArrayRef<ast::Expr *> operands) {
2326   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
2327                                             valueRangeTy);
2328 }
2329 
2330 LogicalResult
2331 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2332                                  MutableArrayRef<ast::Expr *> results) {
2333   return validateOperationOperandsOrResults(loc, name, results, typeTy,
2334                                             typeRangeTy);
2335 }
2336 
2337 LogicalResult Parser::validateOperationOperandsOrResults(
2338     SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2339     ast::Type singleTy, ast::Type rangeTy) {
2340   // All operation types accept a single range parameter.
2341   if (values.size() == 1) {
2342     if (failed(convertExpressionTo(values[0], rangeTy)))
2343       return failure();
2344     return success();
2345   }
2346 
2347   // Otherwise, accept the value groups as they have been defined and just
2348   // ensure they are one of the expected types.
2349   for (ast::Expr *&valueExpr : values) {
2350     ast::Type valueExprType = valueExpr->getType();
2351 
2352     // Check if this is one of the expected types.
2353     if (valueExprType == rangeTy || valueExprType == singleTy)
2354       continue;
2355 
2356     // If the operand is an Operation, allow converting to a Value or
2357     // ValueRange. This situations arises quite often with nested operation
2358     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2359     if (singleTy == valueTy) {
2360       if (valueExprType.isa<ast::OperationType>()) {
2361         valueExpr = convertOpToValue(valueExpr);
2362         continue;
2363       }
2364     }
2365 
2366     return emitError(
2367         valueExpr->getLoc(),
2368         llvm::formatv(
2369             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2370             singleTy, rangeTy, valueExprType));
2371   }
2372   return success();
2373 }
2374 
2375 FailureOr<ast::TupleExpr *>
2376 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2377                         ArrayRef<StringRef> elementNames) {
2378   for (const ast::Expr *element : elements) {
2379     ast::Type eleTy = element->getType();
2380     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2381       return emitError(
2382           element->getLoc(),
2383           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2384     }
2385   }
2386   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2387 }
2388 
2389 //===----------------------------------------------------------------------===//
2390 // Stmts
2391 
2392 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2393                                                     ast::Expr *rootOp) {
2394   // Check that root is an Operation.
2395   ast::Type rootType = rootOp->getType();
2396   if (!rootType.isa<ast::OperationType>())
2397     return emitError(rootOp->getLoc(), "expected `Op` expression");
2398 
2399   return ast::EraseStmt::create(ctx, loc, rootOp);
2400 }
2401 
2402 FailureOr<ast::ReplaceStmt *>
2403 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2404                           MutableArrayRef<ast::Expr *> replValues) {
2405   // Check that root is an Operation.
2406   ast::Type rootType = rootOp->getType();
2407   if (!rootType.isa<ast::OperationType>()) {
2408     return emitError(
2409         rootOp->getLoc(),
2410         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2411   }
2412 
2413   // If there are multiple replacement values, we implicitly convert any Op
2414   // expressions to the value form.
2415   bool shouldConvertOpToValues = replValues.size() > 1;
2416   for (ast::Expr *&replExpr : replValues) {
2417     ast::Type replType = replExpr->getType();
2418 
2419     // Check that replExpr is an Operation, Value, or ValueRange.
2420     if (replType.isa<ast::OperationType>()) {
2421       if (shouldConvertOpToValues)
2422         replExpr = convertOpToValue(replExpr);
2423       continue;
2424     }
2425 
2426     if (replType != valueTy && replType != valueRangeTy) {
2427       return emitError(replExpr->getLoc(),
2428                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2429                                      "expression, but got `{0}`",
2430                                      replType));
2431     }
2432   }
2433 
2434   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2435 }
2436 
2437 FailureOr<ast::RewriteStmt *>
2438 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2439                           ast::CompoundStmt *rewriteBody) {
2440   // Check that root is an Operation.
2441   ast::Type rootType = rootOp->getType();
2442   if (!rootType.isa<ast::OperationType>()) {
2443     return emitError(
2444         rootOp->getLoc(),
2445         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2446   }
2447 
2448   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2449 }
2450 
2451 //===----------------------------------------------------------------------===//
2452 // Parser
2453 //===----------------------------------------------------------------------===//
2454 
2455 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2456                                                  llvm::SourceMgr &sourceMgr) {
2457   Parser parser(ctx, sourceMgr);
2458   return parser.parseModule();
2459 }
2460