1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// Parser for PDLL source. Tokens are pulled one at a time from `lexer` into
/// `curToken`, and parsed constructs are created as AST nodes using `ctx`.
/// Errors are reported as diagnostics and signalled by returning `failure()`.
class Parser {
public:
  /// Create a parser for the sources managed by `sourceMgr`. Note that the
  /// first token is lexed eagerly, as part of construction.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
        curToken(lexer.lexToken()), curDeclScope(nullptr),
        valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)) {}

  /// Try to parse a new module. Returns failure (not nullptr) if parsing the
  /// module body fails.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within a Constraint, which disallows all types
    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
    Constraint,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Push a new decl scope, whose parent is the current scope, and make it the
  /// current scope. The new scope is allocated from `scopeAllocator`.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make `scope` the current decl scope. Unlike the overload above, this does
  /// not link `scope` to the previous current scope; it is used as-is.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent the current scope.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module: directives and top-level decls up to the
  /// next EOF token. Parsed decls are appended to `decls`.
  LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a ValueRange typed
  /// expression that accesses all of the operation's results.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token. Only `#include` is recognized;
  /// decls parsed from an included file are appended to `decls`.
  LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
  /// Parse an `#include` directive. Only `.pdll` files are accepted, and their
  /// contents are parsed directly into `decls`.
  LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  /// Parse a top-level `Constraint`, `Pattern`, or `Rewrite` decl. If the decl
  /// is named, it is also registered in the current decl scope.
  FailureOr<ast::Decl *> parseTopLevelDecl();
  /// Parse a named attribute, i.e. `name` or `name = expr`, where the name may
  /// be an identifier, a keyword, or a string. Without an explicit value, the
  /// attribute is given a `unit` attribute expression.
  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();

  /// Parse an argument variable as part of the signature of a
  /// UserConstraintDecl or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseArgumentDecl();

  /// Parse a result variable as part of the signature of a UserConstraintDecl
  /// or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);

  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
  /// defined in a non-global context.
  FailureOr<ast::UserConstraintDecl *>
  parseUserConstraintDecl(bool isInline = false);

  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Constraint/etc.
  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();

  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
  /// using PDLL constructs.
  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
  /// defined in a non-global context.
  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);

  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Rewrite/etc.
  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();

  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
  /// PDLL constructs.
  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
  /// effectively the same syntax, and only differ on slight semantics (given
  /// the different parsing contexts).
  template <typename T, typename ParseUserPDLLDeclFnT>
  FailureOr<T *> parseUserConstraintOrRewriteDecl(
      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
      StringRef anonymousNamePrefix, bool isInline);

  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
  /// These decls have effectively the same syntax.
  template <typename T>
  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse the functional signature (i.e. the arguments and results) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult parseUserConstraintOrRewriteSignature(
      SmallVectorImpl<ast::VariableDecl *> &arguments,
      SmallVectorImpl<ast::VariableDecl *> &results,
      ast::DeclScope *&argumentScope, ast::Type &resultType);

  /// Validate the return (which if present is specified by bodyIt) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult validateUserConstraintOrRewriteReturn(
      StringRef declType, ast::CompoundStmt *body,
      ArrayRef<ast::Stmt *>::iterator bodyIt,
      ArrayRef<ast::Stmt *>::iterator bodyE,
      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);

  /// Parse a lambda body of the form `=> stmt`. The single statement is parsed
  /// within a fresh decl scope and then handed to `processStatementFn`, which
  /// may reject it (by returning failure) or replace it through its reference
  /// parameter. On success the statement is wrapped in a CompoundStmt.
  FailureOr<ast::CompoundStmt *>
  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                  bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name. If
  /// one has, emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints,
                  bool allowInlineTypeConstraints);

  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
  /// argument or result variable. The constraints for these variables do not
  /// allow inline type constraints, and only permit a single constraint.
  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *> parseOperationExpr();
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::ReturnStmt *> parseReturnStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to extract a callable from the given AST node. Returns nullptr on
  /// failure.
  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
  /// of results, defined as part of the signature.
  ast::Type
  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);

  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
  template <typename T>
  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
      ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Create a variable for an argument or result defined as part of the
  /// signature of a UserConstraintDecl/UserRewriteDecl.
  FailureOr<ast::VariableDecl *>
  createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                const ast::ConstraintRef &constraint);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType,
                                           bool allowNonCoreConstraints = true);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::CallExpr *>
  createCallExpr(SMRange loc, ast::Expr *parentExpr,
                 MutableArrayRef<ast::Expr *> arguments);
  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(SMRange loc, Optional<StringRef> name,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                         MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
                                     MutableArrayRef<ast::Expr *> values,
                                     ast::Type singleTy, ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of `tokLoc` and re-lex the current token
  /// from that position.
  void resetToken(SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. Otherwise,
  /// emit `msg` as an error at the current token and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error at `loc`. Always returns failure so that callers can write
  /// `return emitError(...)`.
  LogicalResult emitError(SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error at the location of the current token. Always returns
  /// failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error at `loc` with an attached note at `noteLoc`. Always returns
  /// failure.
  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The most recently defined decl scope.
  ast::DeclScope *curDeclScope;
  /// Storage for the decl scopes created by `pushDeclScope()`.
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;

  /// A counter used when naming anonymous constraints and rewrites.
  unsigned anonymousDeclNameCounter = 0;
};
} // namespace
448 
449 FailureOr<ast::Module *> Parser::parseModule() {
450   SMLoc moduleLoc = curToken.getStartLoc();
451   pushDeclScope();
452 
453   // Parse the top-level decls of the module.
454   SmallVector<ast::Decl *> decls;
455   if (failed(parseModuleBody(decls)))
456     return popDeclScope(), failure();
457 
458   popDeclScope();
459   return ast::Module::create(ctx, moduleLoc, decls);
460 }
461 
462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
463   while (curToken.isNot(Token::eof)) {
464     if (curToken.is(Token::directive)) {
465       if (failed(parseDirective(decls)))
466         return failure();
467       continue;
468     }
469 
470     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
471     if (failed(decl))
472       return failure();
473     decls.push_back(*decl);
474   }
475   return success();
476 }
477 
478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
479   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
480                                                  valueRangeTy);
481 }
482 
483 LogicalResult Parser::convertExpressionTo(
484     ast::Expr *&expr, ast::Type type,
485     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
486   ast::Type exprType = expr->getType();
487   if (exprType == type)
488     return success();
489 
490   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
491     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
492         expr->getLoc(), llvm::formatv("unable to convert expression of type "
493                                       "`{0}` to the expected type of "
494                                       "`{1}`",
495                                       exprType, type));
496     if (noteAttachFn)
497       noteAttachFn(*diag);
498     return diag;
499   };
500 
501   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
502     // Two operation types are compatible if they have the same name, or if the
503     // expected type is more general.
504     if (auto opType = type.dyn_cast<ast::OperationType>()) {
505       if (opType.getName())
506         return emitConvertError();
507       return success();
508     }
509 
510     // An operation can always convert to a ValueRange.
511     if (type == valueRangeTy) {
512       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
513                                                      valueRangeTy);
514       return success();
515     }
516 
517     // Allow conversion to a single value by constraining the result range.
518     if (type == valueTy) {
519       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
520                                                      valueTy);
521       return success();
522     }
523     return emitConvertError();
524   }
525 
526   // FIXME: Decide how to allow/support converting a single result to multiple,
527   // and multiple to a single result. For now, we just allow Single->Range,
528   // but this isn't something really supported in the PDL dialect. We should
529   // figure out some way to support both.
530   if ((exprType == valueTy || exprType == valueRangeTy) &&
531       (type == valueTy || type == valueRangeTy))
532     return success();
533   if ((exprType == typeTy || exprType == typeRangeTy) &&
534       (type == typeTy || type == typeRangeTy))
535     return success();
536 
537   // Handle tuple types.
538   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
539     auto tupleType = type.dyn_cast<ast::TupleType>();
540     if (!tupleType || tupleType.size() != exprTupleType.size())
541       return emitConvertError();
542 
543     // Build a new tuple expression using each of the elements of the current
544     // tuple.
545     SmallVector<ast::Expr *> newExprs;
546     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
547       newExprs.push_back(ast::MemberAccessExpr::create(
548           ctx, expr->getLoc(), expr, llvm::to_string(i),
549           exprTupleType.getElementTypes()[i]));
550 
551       auto diagFn = [&](ast::Diagnostic &diag) {
552         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
553                                       i, exprTupleType));
554         if (noteAttachFn)
555           noteAttachFn(diag);
556       };
557       if (failed(convertExpressionTo(newExprs.back(),
558                                      tupleType.getElementTypes()[i], diagFn)))
559         return failure();
560     }
561     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
562                                   tupleType.getElementNames());
563     return success();
564   }
565 
566   return emitConvertError();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Directives
571 
572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
573   StringRef directive = curToken.getSpelling();
574   if (directive == "#include")
575     return parseInclude(decls);
576 
577   return emitError("unknown directive `" + directive + "`");
578 }
579 
580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
581   SMRange loc = curToken.getLoc();
582   consumeToken(Token::directive);
583 
584   // Parse the file being included.
585   if (!curToken.isString())
586     return emitError(loc,
587                      "expected string file name after `include` directive");
588   SMRange fileLoc = curToken.getLoc();
589   std::string filenameStr = curToken.getStringValue();
590   StringRef filename = filenameStr;
591   consumeToken();
592 
593   // Check the type of include. If ending with `.pdll`, this is another pdl file
594   // to be parsed along with the current module.
595   if (filename.endswith(".pdll")) {
596     if (failed(lexer.pushInclude(filename)))
597       return emitError(fileLoc,
598                        "unable to open include file `" + filename + "`");
599 
600     // If we added the include successfully, parse it into the current module.
601     // Make sure to save the current token so that we can restore it when we
602     // finish parsing the nested file.
603     Token oldToken = curToken;
604     curToken = lexer.lexToken();
605     LogicalResult result = parseModuleBody(decls);
606     curToken = oldToken;
607     return result;
608   }
609 
610   return emitError(fileLoc, "expected include filename to end with `.pdll`");
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Decls
615 
616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
617   FailureOr<ast::Decl *> decl;
618   switch (curToken.getKind()) {
619   case Token::kw_Constraint:
620     decl = parseUserConstraintDecl();
621     break;
622   case Token::kw_Pattern:
623     decl = parsePatternDecl();
624     break;
625   case Token::kw_Rewrite:
626     decl = parseUserRewriteDecl();
627     break;
628   default:
629     return emitError("expected top-level declaration, such as a `Pattern`");
630   }
631   if (failed(decl))
632     return failure();
633 
634   // If the decl has a name, add it to the current scope.
635   if (const ast::Name *name = (*decl)->getName()) {
636     if (failed(checkDefineNamedDecl(*name)))
637       return failure();
638     curDeclScope->add(*decl);
639   }
640   return decl;
641 }
642 
643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
644   std::string attrNameStr;
645   if (curToken.isString())
646     attrNameStr = curToken.getStringValue();
647   else if (curToken.is(Token::identifier) || curToken.isKeyword())
648     attrNameStr = curToken.getSpelling().str();
649   else
650     return emitError("expected identifier or string attribute name");
651   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
652   consumeToken();
653 
654   // Check for a value of the attribute.
655   ast::Expr *attrValue = nullptr;
656   if (consumeIf(Token::equal)) {
657     FailureOr<ast::Expr *> attrExpr = parseExpr();
658     if (failed(attrExpr))
659       return failure();
660     attrValue = *attrExpr;
661   } else {
662     // If there isn't a concrete value, create an expression representing a
663     // UnitAttr.
664     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
665   }
666 
667   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
668 }
669 
670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
671     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
672     bool expectTerminalSemicolon) {
673   consumeToken(Token::equal_arrow);
674 
675   // Parse the single statement of the lambda body.
676   SMLoc bodyStartLoc = curToken.getStartLoc();
677   pushDeclScope();
678   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
679   bool failedToParse =
680       failed(singleStatement) || failed(processStatementFn(*singleStatement));
681   popDeclScope();
682   if (failedToParse)
683     return failure();
684 
685   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
686   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
687 }
688 
689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
690   // Ensure that the argument is named.
691   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
692     return emitError("expected identifier argument name");
693 
694   // Parse the argument similarly to a normal variable.
695   StringRef name = curToken.getSpelling();
696   SMRange nameLoc = curToken.getLoc();
697   consumeToken();
698 
699   if (failed(
700           parseToken(Token::colon, "expected `:` before argument constraint")))
701     return failure();
702 
703   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
704   if (failed(cst))
705     return failure();
706 
707   return createArgOrResultVariableDecl(name, nameLoc, *cst);
708 }
709 
710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
711   // Check to see if this result is named.
712   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
713     // Check to see if this name actually refers to a Constraint.
714     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
715     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
716       // If yes, and this is a Rewrite, give a nice error message as non-Core
717       // constraints are not supported on Rewrite results.
718       if (parserContext == ParserContext::Rewrite) {
719         return emitError(
720             "`Rewrite` results are only permitted to use core constraints, "
721             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
722       }
723 
724       // Otherwise, parse this as an unnamed result variable.
725     } else {
726       // If it wasn't a constraint, parse the result similarly to a variable. If
727       // there is already an existing decl, we will emit an error when defining
728       // this variable later.
729       StringRef name = curToken.getSpelling();
730       SMRange nameLoc = curToken.getLoc();
731       consumeToken();
732 
733       if (failed(parseToken(Token::colon,
734                             "expected `:` before result constraint")))
735         return failure();
736 
737       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
738       if (failed(cst))
739         return failure();
740 
741       return createArgOrResultVariableDecl(name, nameLoc, *cst);
742     }
743   }
744 
745   // If it isn't named, we parse the constraint directly and create an unnamed
746   // result variable.
747   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
748   if (failed(cst))
749     return failure();
750 
751   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
752 }
753 
754 FailureOr<ast::UserConstraintDecl *>
755 Parser::parseUserConstraintDecl(bool isInline) {
756   // Constraints and rewrites have very similar formats, dispatch to a shared
757   // interface for parsing.
758   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
759       [&](auto &&...args) { return parseUserPDLLConstraintDecl(args...); },
760       ParserContext::Constraint, "constraint", isInline);
761 }
762 
763 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
764   FailureOr<ast::UserConstraintDecl *> decl =
765       parseUserConstraintDecl(/*isInline=*/true);
766   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
767     return failure();
768 
769   curDeclScope->add(*decl);
770   return decl;
771 }
772 
773 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
774     const ast::Name &name, bool isInline,
775     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
776     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
777   // Push the argument scope back onto the list, so that the body can
778   // reference arguments.
779   pushDeclScope(argumentScope);
780 
781   // Parse the body of the constraint. The body is either defined as a compound
782   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
783   ast::CompoundStmt *body;
784   if (curToken.is(Token::equal_arrow)) {
785     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
786         [&](ast::Stmt *&stmt) -> LogicalResult {
787           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
788           if (!stmtExpr) {
789             return emitError(stmt->getLoc(),
790                              "expected `Constraint` lambda body to contain a "
791                              "single expression");
792           }
793           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
794           return success();
795         },
796         /*expectTerminalSemicolon=*/!isInline);
797     if (failed(bodyResult))
798       return failure();
799     body = *bodyResult;
800   } else {
801     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
802     if (failed(bodyResult))
803       return failure();
804     body = *bodyResult;
805 
806     // Verify the structure of the body.
807     auto bodyIt = body->begin(), bodyE = body->end();
808     for (; bodyIt != bodyE; ++bodyIt)
809       if (isa<ast::ReturnStmt>(*bodyIt))
810         break;
811     if (failed(validateUserConstraintOrRewriteReturn(
812             "Constraint", body, bodyIt, bodyE, results, resultType)))
813       return failure();
814   }
815   popDeclScope();
816 
817   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
818       name, arguments, results, resultType, body);
819 }
820 
821 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
822   // Constraints and rewrites have very similar formats, dispatch to a shared
823   // interface for parsing.
824   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
825       [&](auto &&...args) { return parseUserPDLLRewriteDecl(args...); },
826       ParserContext::Rewrite, "rewrite", isInline);
827 }
828 
829 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
830   FailureOr<ast::UserRewriteDecl *> decl =
831       parseUserRewriteDecl(/*isInline=*/true);
832   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
833     return failure();
834 
835   curDeclScope->add(*decl);
836   return decl;
837 }
838 
839 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
840     const ast::Name &name, bool isInline,
841     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
842     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
843   // Push the argument scope back onto the list, so that the body can
844   // reference arguments.
845   curDeclScope = argumentScope;
846   ast::CompoundStmt *body;
847   if (curToken.is(Token::equal_arrow)) {
848     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
849         [&](ast::Stmt *&statement) -> LogicalResult {
850           if (isa<ast::OpRewriteStmt>(statement))
851             return success();
852 
853           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
854           if (!statementExpr) {
855             return emitError(
856                 statement->getLoc(),
857                 "expected `Rewrite` lambda body to contain a single expression "
858                 "or an operation rewrite statement; such as `erase`, "
859                 "`replace`, or `rewrite`");
860           }
861           statement =
862               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
863           return success();
864         },
865         /*expectTerminalSemicolon=*/!isInline);
866     if (failed(bodyResult))
867       return failure();
868     body = *bodyResult;
869   } else {
870     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
871     if (failed(bodyResult))
872       return failure();
873     body = *bodyResult;
874   }
875   popDeclScope();
876 
877   // Verify the structure of the body.
878   auto bodyIt = body->begin(), bodyE = body->end();
879   for (; bodyIt != bodyE; ++bodyIt)
880     if (isa<ast::ReturnStmt>(*bodyIt))
881       break;
882   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
883                                                    bodyE, results, resultType)))
884     return failure();
885   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
886       name, arguments, results, resultType, body);
887 }
888 
889 template <typename T, typename ParseUserPDLLDeclFnT>
890 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
891     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
892     StringRef anonymousNamePrefix, bool isInline) {
893   SMRange loc = curToken.getLoc();
894   consumeToken();
895   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
896 
897   // Parse the name of the decl.
898   const ast::Name *name = nullptr;
899   if (curToken.isNot(Token::identifier)) {
900     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
901     // in C++, so being unnamed is fine.
902     if (!isInline)
903       return emitError("expected identifier name");
904 
905     // Create a unique anonymous name to use, as the name for this decl is not
906     // important.
907     std::string anonName =
908         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
909                       anonymousDeclNameCounter++)
910             .str();
911     name = &ast::Name::create(ctx, anonName, loc);
912   } else {
913     // If a name was provided, we can use it directly.
914     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
915     consumeToken(Token::identifier);
916   }
917 
918   // Parse the functional signature of the decl.
919   SmallVector<ast::VariableDecl *> arguments, results;
920   ast::DeclScope *argumentScope;
921   ast::Type resultType;
922   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
923                                                    argumentScope, resultType)))
924     return failure();
925 
926   // Check to see which type of constraint this is. If the constraint contains a
927   // compound body, this is a PDLL decl.
928   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
929     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
930                            resultType);
931 
932   // Otherwise, this is a native decl.
933   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
934                                                    results, resultType);
935 }
936 
937 template <typename T>
938 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
939     const ast::Name &name, bool isInline,
940     ArrayRef<ast::VariableDecl *> arguments,
941     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
942   // If followed by a string, the native code body has also been specified.
943   std::string codeStrStorage;
944   Optional<StringRef> optCodeStr;
945   if (curToken.isString()) {
946     codeStrStorage = curToken.getStringValue();
947     optCodeStr = codeStrStorage;
948     consumeToken();
949   } else if (isInline) {
950     return emitError(name.getLoc(),
951                      "external declarations must be declared in global scope");
952   }
953   if (failed(parseToken(Token::semicolon,
954                         "expected `;` after native declaration")))
955     return failure();
956   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
957 }
958 
959 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
960     SmallVectorImpl<ast::VariableDecl *> &arguments,
961     SmallVectorImpl<ast::VariableDecl *> &results,
962     ast::DeclScope *&argumentScope, ast::Type &resultType) {
963   // Parse the argument list of the decl.
964   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
965     return failure();
966 
967   argumentScope = pushDeclScope();
968   if (curToken.isNot(Token::r_paren)) {
969     do {
970       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
971       if (failed(argument))
972         return failure();
973       arguments.emplace_back(*argument);
974     } while (consumeIf(Token::comma));
975   }
976   popDeclScope();
977   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
978     return failure();
979 
980   // Parse the results of the decl.
981   pushDeclScope();
982   if (consumeIf(Token::arrow)) {
983     auto parseResultFn = [&]() -> LogicalResult {
984       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
985       if (failed(result))
986         return failure();
987       results.emplace_back(*result);
988       return success();
989     };
990 
991     // Check for a list of results.
992     if (consumeIf(Token::l_paren)) {
993       do {
994         if (failed(parseResultFn()))
995           return failure();
996       } while (consumeIf(Token::comma));
997       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
998         return failure();
999 
1000       // Otherwise, there is only one result.
1001     } else if (failed(parseResultFn())) {
1002       return failure();
1003     }
1004   }
1005   popDeclScope();
1006 
1007   // Compute the result type of the decl.
1008   resultType = createUserConstraintRewriteResultType(results);
1009 
1010   // Verify that results are only named if there are more than one.
1011   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1012     return emitError(
1013         results.front()->getLoc(),
1014         "cannot create a single-element tuple with an element label");
1015   }
1016   return success();
1017 }
1018 
1019 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1020     StringRef declType, ast::CompoundStmt *body,
1021     ArrayRef<ast::Stmt *>::iterator bodyIt,
1022     ArrayRef<ast::Stmt *>::iterator bodyE,
1023     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1024   // Handle if a `return` was provided.
1025   if (bodyIt != bodyE) {
1026     // Emit an error if we have trailing statements after the return.
1027     if (std::next(bodyIt) != bodyE) {
1028       return emitError(
1029           (*std::next(bodyIt))->getLoc(),
1030           llvm::formatv("`return` terminated the `{0}` body, but found "
1031                         "trailing statements afterwards",
1032                         declType));
1033     }
1034 
1035     // Otherwise if a return wasn't provided, check that no results are
1036     // expected.
1037   } else if (!results.empty()) {
1038     return emitError(
1039         {body->getLoc().End, body->getLoc().End},
1040         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1041                       declType, resultType));
1042   }
1043   return success();
1044 }
1045 
1046 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1047   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1048     if (isa<ast::OpRewriteStmt>(statement))
1049       return success();
1050     return emitError(
1051         statement->getLoc(),
1052         "expected Pattern lambda body to contain a single operation "
1053         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1054   });
1055 }
1056 
1057 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1058   SMRange loc = curToken.getLoc();
1059   consumeToken(Token::kw_Pattern);
1060   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1061                                               ParserContext::PatternMatch);
1062 
1063   // Check for an optional identifier for the pattern name.
1064   const ast::Name *name = nullptr;
1065   if (curToken.is(Token::identifier)) {
1066     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1067     consumeToken(Token::identifier);
1068   }
1069 
1070   // Parse any pattern metadata.
1071   ParsedPatternMetadata metadata;
1072   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1073     return failure();
1074 
1075   // Parse the pattern body.
1076   ast::CompoundStmt *body;
1077 
1078   // Handle a lambda body.
1079   if (curToken.is(Token::equal_arrow)) {
1080     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1081     if (failed(bodyResult))
1082       return failure();
1083     body = *bodyResult;
1084   } else {
1085     if (curToken.isNot(Token::l_brace))
1086       return emitError("expected `{` or `=>` to start pattern body");
1087     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1088     if (failed(bodyResult))
1089       return failure();
1090     body = *bodyResult;
1091 
1092     // Verify the body of the pattern.
1093     auto bodyIt = body->begin(), bodyE = body->end();
1094     for (; bodyIt != bodyE; ++bodyIt) {
1095       if (isa<ast::ReturnStmt>(*bodyIt)) {
1096         return emitError((*bodyIt)->getLoc(),
1097                          "`return` statements are only permitted within a "
1098                          "`Constraint` or `Rewrite` body");
1099       }
1100       // Break when we've found the rewrite statement.
1101       if (isa<ast::OpRewriteStmt>(*bodyIt))
1102         break;
1103     }
1104     if (bodyIt == bodyE) {
1105       return emitError(loc,
1106                        "expected Pattern body to terminate with an operation "
1107                        "rewrite statement, such as `erase`");
1108     }
1109     if (std::next(bodyIt) != bodyE) {
1110       return emitError((*std::next(bodyIt))->getLoc(),
1111                        "Pattern body was terminated by an operation "
1112                        "rewrite statement, but found trailing statements");
1113     }
1114   }
1115 
1116   return createPatternDecl(loc, name, metadata, body);
1117 }
1118 
1119 LogicalResult
1120 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1121   Optional<SMRange> benefitLoc;
1122   Optional<SMRange> hasBoundedRecursionLoc;
1123 
1124   do {
1125     if (curToken.isNot(Token::identifier))
1126       return emitError("expected pattern metadata identifier");
1127     StringRef metadataStr = curToken.getSpelling();
1128     SMRange metadataLoc = curToken.getLoc();
1129     consumeToken(Token::identifier);
1130 
1131     // Parse the benefit metadata: benefit(<integer-value>)
1132     if (metadataStr == "benefit") {
1133       if (benefitLoc) {
1134         return emitErrorAndNote(metadataLoc,
1135                                 "pattern benefit has already been specified",
1136                                 *benefitLoc, "see previous definition here");
1137       }
1138       if (failed(parseToken(Token::l_paren,
1139                             "expected `(` before pattern benefit")))
1140         return failure();
1141 
1142       uint16_t benefitValue = 0;
1143       if (curToken.isNot(Token::integer))
1144         return emitError("expected integral pattern benefit");
1145       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1146         return emitError(
1147             "expected pattern benefit to fit within a 16-bit integer");
1148       consumeToken(Token::integer);
1149 
1150       metadata.benefit = benefitValue;
1151       benefitLoc = metadataLoc;
1152 
1153       if (failed(
1154               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1155         return failure();
1156       continue;
1157     }
1158 
1159     // Parse the bounded recursion metadata: recursion
1160     if (metadataStr == "recursion") {
1161       if (hasBoundedRecursionLoc) {
1162         return emitErrorAndNote(
1163             metadataLoc,
1164             "pattern recursion metadata has already been specified",
1165             *hasBoundedRecursionLoc, "see previous definition here");
1166       }
1167       metadata.hasBoundedRecursion = true;
1168       hasBoundedRecursionLoc = metadataLoc;
1169       continue;
1170     }
1171 
1172     return emitError(metadataLoc, "unknown pattern metadata");
1173   } while (consumeIf(Token::comma));
1174 
1175   return success();
1176 }
1177 
1178 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1179   consumeToken(Token::less);
1180 
1181   FailureOr<ast::Expr *> typeExpr = parseExpr();
1182   if (failed(typeExpr) ||
1183       failed(parseToken(Token::greater,
1184                         "expected `>` after variable type constraint")))
1185     return failure();
1186   return typeExpr;
1187 }
1188 
1189 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1190   assert(curDeclScope && "defining decl outside of a decl scope");
1191   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1192     return emitErrorAndNote(
1193         name.getLoc(), "`" + name.getName() + "` has already been defined",
1194         lastDecl->getName()->getLoc(), "see previous definition here");
1195   }
1196   return success();
1197 }
1198 
1199 FailureOr<ast::VariableDecl *>
1200 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1201                            ast::Expr *initExpr,
1202                            ArrayRef<ast::ConstraintRef> constraints) {
1203   assert(curDeclScope && "defining variable outside of decl scope");
1204   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1205 
1206   // If the name of the variable indicates a special variable, we don't add it
1207   // to the scope. This variable is local to the definition point.
1208   if (name.empty() || name == "_") {
1209     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1210                                      constraints);
1211   }
1212   if (failed(checkDefineNamedDecl(nameDecl)))
1213     return failure();
1214 
1215   auto *varDecl =
1216       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1217   curDeclScope->add(varDecl);
1218   return varDecl;
1219 }
1220 
1221 FailureOr<ast::VariableDecl *>
1222 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1223                            ArrayRef<ast::ConstraintRef> constraints) {
1224   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1225                             constraints);
1226 }
1227 
1228 LogicalResult Parser::parseVariableDeclConstraintList(
1229     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1230   Optional<SMRange> typeConstraint;
1231   auto parseSingleConstraint = [&] {
1232     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1233         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1234     if (failed(constraint))
1235       return failure();
1236     constraints.push_back(*constraint);
1237     return success();
1238   };
1239 
1240   // Check to see if this is a single constraint, or a list.
1241   if (!consumeIf(Token::l_square))
1242     return parseSingleConstraint();
1243 
1244   do {
1245     if (failed(parseSingleConstraint()))
1246       return failure();
1247   } while (consumeIf(Token::comma));
1248   return parseToken(Token::r_square, "expected `]` after constraint list");
1249 }
1250 
1251 FailureOr<ast::ConstraintRef>
1252 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1253                         ArrayRef<ast::ConstraintRef> existingConstraints,
1254                         bool allowInlineTypeConstraints) {
1255   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1256     if (!allowInlineTypeConstraints) {
1257       return emitError(
1258           curToken.getLoc(),
1259           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1260           "permitted on arguments or results");
1261     }
1262     if (typeConstraint)
1263       return emitErrorAndNote(
1264           curToken.getLoc(),
1265           "the type of this variable has already been constrained",
1266           *typeConstraint, "see previous constraint location here");
1267     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1268     if (failed(constraintExpr))
1269       return failure();
1270     typeExpr = *constraintExpr;
1271     typeConstraint = typeExpr->getLoc();
1272     return success();
1273   };
1274 
1275   SMRange loc = curToken.getLoc();
1276   switch (curToken.getKind()) {
1277   case Token::kw_Attr: {
1278     consumeToken(Token::kw_Attr);
1279 
1280     // Check for a type constraint.
1281     ast::Expr *typeExpr = nullptr;
1282     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1283       return failure();
1284     return ast::ConstraintRef(
1285         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1286   }
1287   case Token::kw_Op: {
1288     consumeToken(Token::kw_Op);
1289 
1290     // Parse an optional operation name. If the name isn't provided, this refers
1291     // to "any" operation.
1292     FailureOr<ast::OpNameDecl *> opName =
1293         parseWrappedOperationName(/*allowEmptyName=*/true);
1294     if (failed(opName))
1295       return failure();
1296 
1297     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1298                               loc);
1299   }
1300   case Token::kw_Type:
1301     consumeToken(Token::kw_Type);
1302     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1303   case Token::kw_TypeRange:
1304     consumeToken(Token::kw_TypeRange);
1305     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1306                               loc);
1307   case Token::kw_Value: {
1308     consumeToken(Token::kw_Value);
1309 
1310     // Check for a type constraint.
1311     ast::Expr *typeExpr = nullptr;
1312     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1313       return failure();
1314 
1315     return ast::ConstraintRef(
1316         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1317   }
1318   case Token::kw_ValueRange: {
1319     consumeToken(Token::kw_ValueRange);
1320 
1321     // Check for a type constraint.
1322     ast::Expr *typeExpr = nullptr;
1323     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1324       return failure();
1325 
1326     return ast::ConstraintRef(
1327         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1328   }
1329 
1330   case Token::kw_Constraint: {
1331     // Handle an inline constraint.
1332     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1333     if (failed(decl))
1334       return failure();
1335     return ast::ConstraintRef(*decl, loc);
1336   }
1337   case Token::identifier: {
1338     StringRef constraintName = curToken.getSpelling();
1339     consumeToken(Token::identifier);
1340 
1341     // Lookup the referenced constraint.
1342     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1343     if (!cstDecl) {
1344       return emitError(loc, "unknown reference to constraint `" +
1345                                 constraintName + "`");
1346     }
1347 
1348     // Handle a reference to a proper constraint.
1349     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1350       return ast::ConstraintRef(cst, loc);
1351 
1352     return emitErrorAndNote(
1353         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1354         "see the definition of `" + constraintName + "` here");
1355   }
1356   default:
1357     break;
1358   }
1359   return emitError(loc, "expected identifier constraint");
1360 }
1361 
1362 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1363   Optional<SMRange> typeConstraint;
1364   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1365                          /*allowInlineTypeConstraints=*/false);
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // Exprs
1370 
1371 FailureOr<ast::Expr *> Parser::parseExpr() {
1372   if (curToken.is(Token::underscore))
1373     return parseUnderscoreExpr();
1374 
1375   // Parse the LHS expression.
1376   FailureOr<ast::Expr *> lhsExpr;
1377   switch (curToken.getKind()) {
1378   case Token::kw_attr:
1379     lhsExpr = parseAttributeExpr();
1380     break;
1381   case Token::kw_Constraint:
1382     lhsExpr = parseInlineConstraintLambdaExpr();
1383     break;
1384   case Token::identifier:
1385     lhsExpr = parseIdentifierExpr();
1386     break;
1387   case Token::kw_op:
1388     lhsExpr = parseOperationExpr();
1389     break;
1390   case Token::kw_Rewrite:
1391     lhsExpr = parseInlineRewriteLambdaExpr();
1392     break;
1393   case Token::kw_type:
1394     lhsExpr = parseTypeExpr();
1395     break;
1396   case Token::l_paren:
1397     lhsExpr = parseTupleExpr();
1398     break;
1399   default:
1400     return emitError("expected expression");
1401   }
1402   if (failed(lhsExpr))
1403     return failure();
1404 
1405   // Check for an operator expression.
1406   while (true) {
1407     switch (curToken.getKind()) {
1408     case Token::dot:
1409       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1410       break;
1411     case Token::l_paren:
1412       lhsExpr = parseCallExpr(*lhsExpr);
1413       break;
1414     default:
1415       return lhsExpr;
1416     }
1417     if (failed(lhsExpr))
1418       return failure();
1419   }
1420 }
1421 
1422 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1423   SMRange loc = curToken.getLoc();
1424   consumeToken(Token::kw_attr);
1425 
1426   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1427   // identifier.
1428   if (!consumeIf(Token::less)) {
1429     resetToken(loc);
1430     return parseIdentifierExpr();
1431   }
1432 
1433   if (!curToken.isString())
1434     return emitError("expected string literal containing MLIR attribute");
1435   std::string attrExpr = curToken.getStringValue();
1436   consumeToken();
1437 
1438   if (failed(
1439           parseToken(Token::greater, "expected `>` after attribute literal")))
1440     return failure();
1441   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1442 }
1443 
1444 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1445   SMRange loc = curToken.getLoc();
1446   consumeToken(Token::l_paren);
1447 
1448   // Parse the arguments of the call.
1449   SmallVector<ast::Expr *> arguments;
1450   if (curToken.isNot(Token::r_paren)) {
1451     do {
1452       FailureOr<ast::Expr *> argument = parseExpr();
1453       if (failed(argument))
1454         return failure();
1455       arguments.push_back(*argument);
1456     } while (consumeIf(Token::comma));
1457   }
1458   loc.End = curToken.getEndLoc();
1459   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1460     return failure();
1461 
1462   return createCallExpr(loc, parentExpr, arguments);
1463 }
1464 
1465 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1466   ast::Decl *decl = curDeclScope->lookup(name);
1467   if (!decl)
1468     return emitError(loc, "undefined reference to `" + name + "`");
1469 
1470   return createDeclRefExpr(loc, decl);
1471 }
1472 
1473 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1474   StringRef name = curToken.getSpelling();
1475   SMRange nameLoc = curToken.getLoc();
1476   consumeToken();
1477 
1478   // Check to see if this is a decl ref expression that defines a variable
1479   // inline.
1480   if (consumeIf(Token::colon)) {
1481     SmallVector<ast::ConstraintRef> constraints;
1482     if (failed(parseVariableDeclConstraintList(constraints)))
1483       return failure();
1484     ast::Type type;
1485     if (failed(validateVariableConstraints(constraints, type)))
1486       return failure();
1487     return createInlineVariableExpr(type, name, nameLoc, constraints);
1488   }
1489 
1490   return parseDeclRefExpr(name, nameLoc);
1491 }
1492 
1493 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1494   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1495   if (failed(decl))
1496     return failure();
1497 
1498   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1499                                   ast::ConstraintType::get(ctx));
1500 }
1501 
1502 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1503   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1504   if (failed(decl))
1505     return failure();
1506 
1507   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1508                                   ast::RewriteType::get(ctx));
1509 }
1510 
1511 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1512   SMRange loc = curToken.getLoc();
1513   consumeToken(Token::dot);
1514 
1515   // Parse the member name.
1516   Token memberNameTok = curToken;
1517   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1518       !memberNameTok.isKeyword())
1519     return emitError(loc, "expected identifier or numeric member name");
1520   StringRef memberName = memberNameTok.getSpelling();
1521   consumeToken();
1522 
1523   return createMemberAccessExpr(parentExpr, memberName, loc);
1524 }
1525 
1526 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1527   SMRange loc = curToken.getLoc();
1528 
1529   // Handle the case of an no operation name.
1530   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1531     if (allowEmptyName)
1532       return ast::OpNameDecl::create(ctx, SMRange());
1533     return emitError("expected dialect namespace");
1534   }
1535   StringRef name = curToken.getSpelling();
1536   consumeToken();
1537 
1538   // Otherwise, this is a literal operation name.
1539   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1540     return failure();
1541 
1542   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1543     return emitError("expected operation name after dialect namespace");
1544 
1545   name = StringRef(name.data(), name.size() + 1);
1546   do {
1547     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1548     loc.End = curToken.getEndLoc();
1549     consumeToken();
1550   } while (curToken.isAny(Token::identifier, Token::dot) ||
1551            curToken.isKeyword());
1552   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1553 }
1554 
1555 FailureOr<ast::OpNameDecl *>
1556 Parser::parseWrappedOperationName(bool allowEmptyName) {
1557   if (!consumeIf(Token::less))
1558     return ast::OpNameDecl::create(ctx, SMRange());
1559 
1560   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1561   if (failed(opNameDecl))
1562     return failure();
1563 
1564   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1565     return failure();
1566   return opNameDecl;
1567 }
1568 
1569 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1570   SMRange loc = curToken.getLoc();
1571   consumeToken(Token::kw_op);
1572 
1573   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1574   // identifier.
1575   if (curToken.isNot(Token::less)) {
1576     resetToken(loc);
1577     return parseIdentifierExpr();
1578   }
1579 
1580   // Parse the operation name. The name may be elided, in which case the
1581   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1582   // `Operation*`). Operation names within a rewrite context must be named.
1583   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1584   FailureOr<ast::OpNameDecl *> opNameDecl =
1585       parseWrappedOperationName(allowEmptyName);
1586   if (failed(opNameDecl))
1587     return failure();
1588 
1589   // Check for the optional list of operands.
1590   SmallVector<ast::Expr *> operands;
1591   if (consumeIf(Token::l_paren)) {
1592     do {
1593       FailureOr<ast::Expr *> operand = parseExpr();
1594       if (failed(operand))
1595         return failure();
1596       operands.push_back(*operand);
1597     } while (consumeIf(Token::comma));
1598 
1599     if (failed(parseToken(Token::r_paren,
1600                           "expected `)` after operation operand list")))
1601       return failure();
1602   }
1603 
1604   // Check for the optional list of attributes.
1605   SmallVector<ast::NamedAttributeDecl *> attributes;
1606   if (consumeIf(Token::l_brace)) {
1607     do {
1608       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1609       if (failed(decl))
1610         return failure();
1611       attributes.emplace_back(*decl);
1612     } while (consumeIf(Token::comma));
1613 
1614     if (failed(parseToken(Token::r_brace,
1615                           "expected `}` after operation attribute list")))
1616       return failure();
1617   }
1618 
1619   // Check for the optional list of result types.
1620   SmallVector<ast::Expr *> resultTypes;
1621   if (consumeIf(Token::arrow)) {
1622     if (failed(parseToken(Token::l_paren,
1623                           "expected `(` before operation result type list")))
1624       return failure();
1625 
1626     do {
1627       FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1628       if (failed(resultTypeExpr))
1629         return failure();
1630       resultTypes.push_back(*resultTypeExpr);
1631     } while (consumeIf(Token::comma));
1632 
1633     if (failed(parseToken(Token::r_paren,
1634                           "expected `)` after operation result type list")))
1635       return failure();
1636   }
1637 
1638   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1639                              resultTypes);
1640 }
1641 
/// Parse a tuple expression: `(` (element (`,` element)*)? `)`. Each element
/// is an expression, optionally preceded by a label of the form `label =`.
/// Labels must be unique within a single tuple.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Maps each label seen so far to the location of its first use, so that
  // duplicates can point back at the original.
  DenseMap<StringRef, SMRange> usedNames;
  // Parallel to `elements`; an empty StringRef marks an unlabeled element.
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer. The identifier is lexed again from its start as the first
          // token of the element expression parsed below.
          resetToken(elementNameTok.getLoc());
        }
      }
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  // Extend the tuple's range to cover the closing `)` before consuming it.
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
1693 
1694 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1695   SMRange loc = curToken.getLoc();
1696   consumeToken(Token::kw_type);
1697 
1698   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1699   // identifier.
1700   if (!consumeIf(Token::less)) {
1701     resetToken(loc);
1702     return parseIdentifierExpr();
1703   }
1704 
1705   if (!curToken.isString())
1706     return emitError("expected string literal containing MLIR type");
1707   std::string attrExpr = curToken.getStringValue();
1708   consumeToken();
1709 
1710   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1711     return failure();
1712   return ast::TypeExpr::create(ctx, loc, attrExpr);
1713 }
1714 
1715 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1716   StringRef name = curToken.getSpelling();
1717   SMRange nameLoc = curToken.getLoc();
1718   consumeToken(Token::underscore);
1719 
1720   // Underscore expressions require a constraint list.
1721   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1722     return failure();
1723 
1724   // Parse the constraints for the expression.
1725   SmallVector<ast::ConstraintRef> constraints;
1726   if (failed(parseVariableDeclConstraintList(constraints)))
1727     return failure();
1728 
1729   ast::Type type;
1730   if (failed(validateVariableConstraints(constraints, type)))
1731     return failure();
1732   return createInlineVariableExpr(type, name, nameLoc, constraints);
1733 }
1734 
1735 //===----------------------------------------------------------------------===//
1736 // Stmts
1737 
1738 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1739   FailureOr<ast::Stmt *> stmt;
1740   switch (curToken.getKind()) {
1741   case Token::kw_erase:
1742     stmt = parseEraseStmt();
1743     break;
1744   case Token::kw_let:
1745     stmt = parseLetStmt();
1746     break;
1747   case Token::kw_replace:
1748     stmt = parseReplaceStmt();
1749     break;
1750   case Token::kw_return:
1751     stmt = parseReturnStmt();
1752     break;
1753   case Token::kw_rewrite:
1754     stmt = parseRewriteStmt();
1755     break;
1756   default:
1757     stmt = parseExpr();
1758     break;
1759   }
1760   if (failed(stmt) ||
1761       (expectTerminalSemicolon &&
1762        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1763     return failure();
1764   return stmt;
1765 }
1766 
1767 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1768   SMLoc startLoc = curToken.getStartLoc();
1769   consumeToken(Token::l_brace);
1770 
1771   // Push a new block scope and parse any nested statements.
1772   pushDeclScope();
1773   SmallVector<ast::Stmt *> statements;
1774   while (curToken.isNot(Token::r_brace)) {
1775     FailureOr<ast::Stmt *> statement = parseStmt();
1776     if (failed(statement))
1777       return popDeclScope(), failure();
1778     statements.push_back(*statement);
1779   }
1780   popDeclScope();
1781 
1782   // Consume the end brace.
1783   SMRange location(startLoc, curToken.getEndLoc());
1784   consumeToken(Token::r_brace);
1785 
1786   return ast::CompoundStmt::create(ctx, location, statements);
1787 }
1788 
1789 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1790   if (parserContext == ParserContext::Constraint)
1791     return emitError("`erase` cannot be used within a Constraint");
1792   SMRange loc = curToken.getLoc();
1793   consumeToken(Token::kw_erase);
1794 
1795   // Parse the root operation expression.
1796   FailureOr<ast::Expr *> rootOp = parseExpr();
1797   if (failed(rootOp))
1798     return failure();
1799 
1800   return createEraseStmt(loc, *rootOp);
1801 }
1802 
1803 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1804   SMRange loc = curToken.getLoc();
1805   consumeToken(Token::kw_let);
1806 
1807   // Parse the name of the new variable.
1808   SMRange varLoc = curToken.getLoc();
1809   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1810     // `_` is a reserved variable name.
1811     if (curToken.is(Token::underscore)) {
1812       return emitError(varLoc,
1813                        "`_` may only be used to define \"inline\" variables");
1814     }
1815     return emitError(varLoc,
1816                      "expected identifier after `let` to name a new variable");
1817   }
1818   StringRef varName = curToken.getSpelling();
1819   consumeToken();
1820 
1821   // Parse the optional set of constraints.
1822   SmallVector<ast::ConstraintRef> constraints;
1823   if (consumeIf(Token::colon) &&
1824       failed(parseVariableDeclConstraintList(constraints)))
1825     return failure();
1826 
1827   // Parse the optional initializer expression.
1828   ast::Expr *initializer = nullptr;
1829   if (consumeIf(Token::equal)) {
1830     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1831     if (failed(initOrFailure))
1832       return failure();
1833     initializer = *initOrFailure;
1834 
1835     // Check that the constraints are compatible with having an initializer,
1836     // e.g. type constraints cannot be used with initializers.
1837     for (ast::ConstraintRef constraint : constraints) {
1838       LogicalResult result =
1839           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1840               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1841                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1842                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1843                   return this->emitError(
1844                       constraint.referenceLoc,
1845                       "type constraints are not permitted on variables with "
1846                       "initializers");
1847                 }
1848                 return success();
1849               })
1850               .Default(success());
1851       if (failed(result))
1852         return failure();
1853     }
1854   }
1855 
1856   FailureOr<ast::VariableDecl *> varDecl =
1857       createVariableDecl(varName, varLoc, initializer, constraints);
1858   if (failed(varDecl))
1859     return failure();
1860   return ast::LetStmt::create(ctx, loc, *varDecl);
1861 }
1862 
1863 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1864   if (parserContext == ParserContext::Constraint)
1865     return emitError("`replace` cannot be used within a Constraint");
1866   SMRange loc = curToken.getLoc();
1867   consumeToken(Token::kw_replace);
1868 
1869   // Parse the root operation expression.
1870   FailureOr<ast::Expr *> rootOp = parseExpr();
1871   if (failed(rootOp))
1872     return failure();
1873 
1874   if (failed(
1875           parseToken(Token::kw_with, "expected `with` after root operation")))
1876     return failure();
1877 
1878   // The replacement portion of this statement is within a rewrite context.
1879   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1880                                               ParserContext::Rewrite);
1881 
1882   // Parse the replacement values.
1883   SmallVector<ast::Expr *> replValues;
1884   if (consumeIf(Token::l_paren)) {
1885     if (consumeIf(Token::r_paren)) {
1886       return emitError(
1887           loc, "expected at least one replacement value, consider using "
1888                "`erase` if no replacement values are desired");
1889     }
1890 
1891     do {
1892       FailureOr<ast::Expr *> replExpr = parseExpr();
1893       if (failed(replExpr))
1894         return failure();
1895       replValues.emplace_back(*replExpr);
1896     } while (consumeIf(Token::comma));
1897 
1898     if (failed(parseToken(Token::r_paren,
1899                           "expected `)` after replacement values")))
1900       return failure();
1901   } else {
1902     FailureOr<ast::Expr *> replExpr = parseExpr();
1903     if (failed(replExpr))
1904       return failure();
1905     replValues.emplace_back(*replExpr);
1906   }
1907 
1908   return createReplaceStmt(loc, *rootOp, replValues);
1909 }
1910 
1911 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
1912   SMRange loc = curToken.getLoc();
1913   consumeToken(Token::kw_return);
1914 
1915   // Parse the result value.
1916   FailureOr<ast::Expr *> resultExpr = parseExpr();
1917   if (failed(resultExpr))
1918     return failure();
1919 
1920   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
1921 }
1922 
1923 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1924   if (parserContext == ParserContext::Constraint)
1925     return emitError("`rewrite` cannot be used within a Constraint");
1926   SMRange loc = curToken.getLoc();
1927   consumeToken(Token::kw_rewrite);
1928 
1929   // Parse the root operation.
1930   FailureOr<ast::Expr *> rootOp = parseExpr();
1931   if (failed(rootOp))
1932     return failure();
1933 
1934   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1935     return failure();
1936 
1937   if (curToken.isNot(Token::l_brace))
1938     return emitError("expected `{` to start rewrite body");
1939 
1940   // The rewrite body of this statement is within a rewrite context.
1941   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1942                                               ParserContext::Rewrite);
1943 
1944   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1945   if (failed(rewriteBody))
1946     return failure();
1947 
1948   // Verify the rewrite body.
1949   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
1950     if (isa<ast::ReturnStmt>(stmt)) {
1951       return emitError(stmt->getLoc(),
1952                        "`return` statements are only permitted within a "
1953                        "`Constraint` or `Rewrite` body");
1954     }
1955   }
1956 
1957   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1958 }
1959 
1960 //===----------------------------------------------------------------------===//
1961 // Creation+Analysis
1962 //===----------------------------------------------------------------------===//
1963 
1964 //===----------------------------------------------------------------------===//
1965 // Decls
1966 
1967 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
1968   // Unwrap reference expressions.
1969   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
1970     node = init->getDecl();
1971   return dyn_cast<ast::CallableDecl>(node);
1972 }
1973 
1974 FailureOr<ast::PatternDecl *>
1975 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
1976                           const ParsedPatternMetadata &metadata,
1977                           ast::CompoundStmt *body) {
1978   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
1979                                   metadata.hasBoundedRecursion, body);
1980 }
1981 
1982 ast::Type Parser::createUserConstraintRewriteResultType(
1983     ArrayRef<ast::VariableDecl *> results) {
1984   // Single result decls use the type of the single result.
1985   if (results.size() == 1)
1986     return results[0]->getType();
1987 
1988   // Multiple results use a tuple type, with the types and names grabbed from
1989   // the result variable decls.
1990   auto resultTypes = llvm::map_range(
1991       results, [&](const auto *result) { return result->getType(); });
1992   auto resultNames = llvm::map_range(
1993       results, [&](const auto *result) { return result->getName().getName(); });
1994   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
1995                              llvm::to_vector(resultNames));
1996 }
1997 
1998 template <typename T>
1999 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2000     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2001     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2002     ast::CompoundStmt *body) {
2003   if (!body->getChildren().empty()) {
2004     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2005       ast::Expr *resultExpr = retStmt->getResultExpr();
2006 
2007       // Process the result of the decl. If no explicit signature results
2008       // were provided, check for return type inference. Otherwise, check that
2009       // the return expression can be converted to the expected type.
2010       if (results.empty())
2011         resultType = resultExpr->getType();
2012       else if (failed(convertExpressionTo(resultExpr, resultType)))
2013         return failure();
2014       else
2015         retStmt->setResultExpr(resultExpr);
2016     }
2017   }
2018   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2019 }
2020 
2021 FailureOr<ast::VariableDecl *>
2022 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2023                            ArrayRef<ast::ConstraintRef> constraints) {
2024   // The type of the variable, which is expected to be inferred by either a
2025   // constraint or an initializer expression.
2026   ast::Type type;
2027   if (failed(validateVariableConstraints(constraints, type)))
2028     return failure();
2029 
2030   if (initializer) {
2031     // Update the variable type based on the initializer, or try to convert the
2032     // initializer to the existing type.
2033     if (!type)
2034       type = initializer->getType();
2035     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2036       type = mergedType;
2037     else if (failed(convertExpressionTo(initializer, type)))
2038       return failure();
2039 
2040     // Otherwise, if there is no initializer check that the type has already
2041     // been resolved from the constraint list.
2042   } else if (!type) {
2043     return emitErrorAndNote(
2044         loc, "unable to infer type for variable `" + name + "`", loc,
2045         "the type of a variable must be inferable from the constraint "
2046         "list or the initializer");
2047   }
2048 
2049   // Constraint types cannot be used when defining variables.
2050   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2051     return emitError(
2052         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2053   }
2054 
2055   // Try to define a variable with the given name.
2056   FailureOr<ast::VariableDecl *> varDecl =
2057       defineVariableDecl(name, loc, type, initializer, constraints);
2058   if (failed(varDecl))
2059     return failure();
2060 
2061   return *varDecl;
2062 }
2063 
2064 FailureOr<ast::VariableDecl *>
2065 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2066                                       const ast::ConstraintRef &constraint) {
2067   // Constraint arguments may apply more complex constraints via the arguments.
2068   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2069   ast::Type argType;
2070   if (failed(validateVariableConstraint(constraint, argType,
2071                                         allowNonCoreConstraints)))
2072     return failure();
2073   return defineVariableDecl(name, loc, argType, constraint);
2074 }
2075 
2076 LogicalResult
2077 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2078                                     ast::Type &inferredType) {
2079   for (const ast::ConstraintRef &ref : constraints)
2080     if (failed(validateVariableConstraint(ref, inferredType)))
2081       return failure();
2082   return success();
2083 }
2084 
/// Validate a single variable constraint and merge the type it implies into
/// `inferredType`. If `inferredType` is null it is set to the constraint's
/// type; otherwise the two are refined together, and an error is emitted if
/// they are incompatible. User constraints are only accepted when
/// `allowNonCoreConstraints` is set, and must take exactly one input (whose
/// type becomes the constraint type).
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType,
                                                 bool allowNonCoreConstraints) {
  // The type implied by the constraint, computed per constraint kind below.
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    // An optional type expression on `Attr` must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    constraintType = ast::OperationType::get(ctx, cst->getName());
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    // An optional type expression on `Value` must be a single `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // An optional type expression on `ValueRange` must be a `TypeRange`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    if (!allowNonCoreConstraints) {
      return emitError(ref.referenceLoc,
                       "`Rewrite` arguments and results are only permitted to "
                       "use core constraints, such as `Attr`, `Op`, `Type`, "
                       "`TypeRange`, `Value`, `ValueRange`");
    }

    // A user constraint applied to a variable constrains exactly that one
    // variable, so it must take a single input; the input's type is used.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
2153 
2154 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2155   ast::Type typeExprType = typeExpr->getType();
2156   if (typeExprType != typeTy) {
2157     return emitError(typeExpr->getLoc(),
2158                      "expected expression of `Type` in type constraint");
2159   }
2160   return success();
2161 }
2162 
2163 LogicalResult
2164 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2165   ast::Type typeExprType = typeExpr->getType();
2166   if (typeExprType != typeRangeTy) {
2167     return emitError(typeExpr->getLoc(),
2168                      "expected expression of `TypeRange` in type constraint");
2169   }
2170   return success();
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // Exprs
2175 
2176 FailureOr<ast::CallExpr *>
2177 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2178                        MutableArrayRef<ast::Expr *> arguments) {
2179   ast::Type parentType = parentExpr->getType();
2180 
2181   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2182   if (!callableDecl) {
2183     return emitError(loc,
2184                      llvm::formatv("expected a reference to a callable "
2185                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2186                                    parentType));
2187   }
2188   if (parserContext == ParserContext::Rewrite) {
2189     if (isa<ast::UserConstraintDecl>(callableDecl))
2190       return emitError(
2191           loc, "unable to invoke `Constraint` within a rewrite section");
2192   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2193     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2194   }
2195 
2196   // Verify the arguments of the call.
2197   /// Handle size mismatch.
2198   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2199   if (callArgs.size() != arguments.size()) {
2200     return emitErrorAndNote(
2201         loc,
2202         llvm::formatv("invalid number of arguments for {0} call; expected "
2203                       "{1}, but got {2}",
2204                       callableDecl->getCallableType(), callArgs.size(),
2205                       arguments.size()),
2206         callableDecl->getLoc(),
2207         llvm::formatv("see the definition of {0} here",
2208                       callableDecl->getName()->getName()));
2209   }
2210 
2211   /// Handle argument type mismatch.
2212   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2213     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2214                                   callableDecl->getName()->getName()),
2215                     callableDecl->getLoc());
2216   };
2217   for (auto it : llvm::zip(callArgs, arguments)) {
2218     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2219                                    attachDiagFn)))
2220       return failure();
2221   }
2222 
2223   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2224                                callableDecl->getResultType());
2225 }
2226 
2227 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2228                                                         ast::Decl *decl) {
2229   // Check the type of decl being referenced.
2230   ast::Type declType;
2231   if (isa<ast::ConstraintDecl>(decl))
2232     declType = ast::ConstraintType::get(ctx);
2233   else if (isa<ast::UserRewriteDecl>(decl))
2234     declType = ast::RewriteType::get(ctx);
2235   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2236     declType = varDecl->getType();
2237   else
2238     return emitError(loc, "invalid reference to `" +
2239                               decl->getName()->getName() + "`");
2240 
2241   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2242 }
2243 
2244 FailureOr<ast::DeclRefExpr *>
2245 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2246                                  ArrayRef<ast::ConstraintRef> constraints) {
2247   FailureOr<ast::VariableDecl *> decl =
2248       defineVariableDecl(name, loc, type, constraints);
2249   if (failed(decl))
2250     return failure();
2251   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2252 }
2253 
2254 FailureOr<ast::MemberAccessExpr *>
2255 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2256                                SMRange loc) {
2257   // Validate the member name for the given parent expression.
2258   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2259   if (failed(memberType))
2260     return failure();
2261 
2262   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2263 }
2264 
2265 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2266                                                   StringRef name, SMRange loc) {
2267   ast::Type parentType = parentExpr->getType();
2268   if (parentType.isa<ast::OperationType>()) {
2269     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2270       return valueRangeTy;
2271   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2272     // Handle indexed results.
2273     unsigned index = 0;
2274     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2275         index < tupleType.size()) {
2276       return tupleType.getElementTypes()[index];
2277     }
2278 
2279     // Handle named results.
2280     auto elementNames = tupleType.getElementNames();
2281     const auto *it = llvm::find(elementNames, name);
2282     if (it != elementNames.end())
2283       return tupleType.getElementTypes()[it - elementNames.begin()];
2284   }
2285   return emitError(
2286       loc,
2287       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2288                     name, parentType));
2289 }
2290 
2291 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2292     SMRange loc, const ast::OpNameDecl *name,
2293     MutableArrayRef<ast::Expr *> operands,
2294     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2295     MutableArrayRef<ast::Expr *> results) {
2296   Optional<StringRef> opNameRef = name->getName();
2297 
2298   // Verify the inputs operands.
2299   if (failed(validateOperationOperands(loc, opNameRef, operands)))
2300     return failure();
2301 
2302   // Verify the attribute list.
2303   for (ast::NamedAttributeDecl *attr : attributes) {
2304     // Check for an attribute type, or a type awaiting resolution.
2305     ast::Type attrType = attr->getValue()->getType();
2306     if (!attrType.isa<ast::AttributeType>()) {
2307       return emitError(
2308           attr->getValue()->getLoc(),
2309           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2310     }
2311   }
2312 
2313   // Verify the result types.
2314   if (failed(validateOperationResults(loc, opNameRef, results)))
2315     return failure();
2316 
2317   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2318                                     attributes);
2319 }
2320 
2321 LogicalResult
2322 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2323                                   MutableArrayRef<ast::Expr *> operands) {
2324   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
2325                                             valueRangeTy);
2326 }
2327 
2328 LogicalResult
2329 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2330                                  MutableArrayRef<ast::Expr *> results) {
2331   return validateOperationOperandsOrResults(loc, name, results, typeTy,
2332                                             typeRangeTy);
2333 }
2334 
2335 LogicalResult Parser::validateOperationOperandsOrResults(
2336     SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2337     ast::Type singleTy, ast::Type rangeTy) {
2338   // All operation types accept a single range parameter.
2339   if (values.size() == 1) {
2340     if (failed(convertExpressionTo(values[0], rangeTy)))
2341       return failure();
2342     return success();
2343   }
2344 
2345   // Otherwise, accept the value groups as they have been defined and just
2346   // ensure they are one of the expected types.
2347   for (ast::Expr *&valueExpr : values) {
2348     ast::Type valueExprType = valueExpr->getType();
2349 
2350     // Check if this is one of the expected types.
2351     if (valueExprType == rangeTy || valueExprType == singleTy)
2352       continue;
2353 
2354     // If the operand is an Operation, allow converting to a Value or
2355     // ValueRange. This situations arises quite often with nested operation
2356     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2357     if (singleTy == valueTy) {
2358       if (valueExprType.isa<ast::OperationType>()) {
2359         valueExpr = convertOpToValue(valueExpr);
2360         continue;
2361       }
2362     }
2363 
2364     return emitError(
2365         valueExpr->getLoc(),
2366         llvm::formatv(
2367             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2368             singleTy, rangeTy, valueExprType));
2369   }
2370   return success();
2371 }
2372 
2373 FailureOr<ast::TupleExpr *>
2374 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2375                         ArrayRef<StringRef> elementNames) {
2376   for (const ast::Expr *element : elements) {
2377     ast::Type eleTy = element->getType();
2378     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2379       return emitError(
2380           element->getLoc(),
2381           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2382     }
2383   }
2384   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // Stmts
2389 
2390 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2391                                                     ast::Expr *rootOp) {
2392   // Check that root is an Operation.
2393   ast::Type rootType = rootOp->getType();
2394   if (!rootType.isa<ast::OperationType>())
2395     return emitError(rootOp->getLoc(), "expected `Op` expression");
2396 
2397   return ast::EraseStmt::create(ctx, loc, rootOp);
2398 }
2399 
2400 FailureOr<ast::ReplaceStmt *>
2401 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2402                           MutableArrayRef<ast::Expr *> replValues) {
2403   // Check that root is an Operation.
2404   ast::Type rootType = rootOp->getType();
2405   if (!rootType.isa<ast::OperationType>()) {
2406     return emitError(
2407         rootOp->getLoc(),
2408         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2409   }
2410 
2411   // If there are multiple replacement values, we implicitly convert any Op
2412   // expressions to the value form.
2413   bool shouldConvertOpToValues = replValues.size() > 1;
2414   for (ast::Expr *&replExpr : replValues) {
2415     ast::Type replType = replExpr->getType();
2416 
2417     // Check that replExpr is an Operation, Value, or ValueRange.
2418     if (replType.isa<ast::OperationType>()) {
2419       if (shouldConvertOpToValues)
2420         replExpr = convertOpToValue(replExpr);
2421       continue;
2422     }
2423 
2424     if (replType != valueTy && replType != valueRangeTy) {
2425       return emitError(replExpr->getLoc(),
2426                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2427                                      "expression, but got `{0}`",
2428                                      replType));
2429     }
2430   }
2431 
2432   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2433 }
2434 
2435 FailureOr<ast::RewriteStmt *>
2436 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2437                           ast::CompoundStmt *rewriteBody) {
2438   // Check that root is an Operation.
2439   ast::Type rootType = rootOp->getType();
2440   if (!rootType.isa<ast::OperationType>()) {
2441     return emitError(
2442         rootOp->getLoc(),
2443         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2444   }
2445 
2446   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2447 }
2448 
2449 //===----------------------------------------------------------------------===//
2450 // Parser
2451 //===----------------------------------------------------------------------===//
2452 
2453 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2454                                                  llvm::SourceMgr &sourceMgr) {
2455   Parser parser(ctx, sourceMgr);
2456   return parser.parseModule();
2457 }
2458