1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
namespace {
class Parser {
public:
  /// Construct a parser that lexes from `sourceMgr` and reports diagnostics
  /// through the diagnostic engine of `ctx`. The first token is lexed eagerly,
  /// and commonly used AST types are cached up front.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
        curToken(lexer.lexToken()), curDeclScope(nullptr),
        valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)) {}

  /// Try to parse a new module. Returns failure if the module could not be
  /// parsed.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within a Constraint, which disallows all types
    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
    Constraint,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Push a new decl scope, whose parent is the current scope, and make it the
  /// current scope. The scope is allocated from `scopeAllocator`. Returns the
  /// newly created scope.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make an already existing `scope` the current decl scope. This only
  /// changes which scope is current; it does not modify `scope` itself.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent the current scope.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse top-level decls and directives until EOF is reached, appending the
  /// parsed decls to `decls`.
  LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a ValueRange typed
  /// expression that references all of the operation's results.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token (currently only `#include` is
  /// supported). Decls produced by the directive are appended to `decls`.
  LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
  /// Parse an `#include` directive. Included `.pdll` files are parsed into
  /// the current module, with their decls appended to `decls`.
  LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  /// Parse a top-level `Constraint`, `Pattern`, or `Rewrite` decl. Named decls
  /// are also registered in the current decl scope.
  FailureOr<ast::Decl *> parseTopLevelDecl();
  /// Parse a named attribute, i.e. `name` or `name = expr`. A missing value is
  /// represented as a `unit` attribute expression.
  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();

  /// Parse an argument variable as part of the signature of a
  /// UserConstraintDecl or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseArgumentDecl();

  /// Parse a result variable as part of the signature of a UserConstraintDecl
  /// or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);

  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
  /// defined in a non-global context.
  FailureOr<ast::UserConstraintDecl *>
  parseUserConstraintDecl(bool isInline = false);

  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Constraint/etc.
  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();

  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
  /// using PDLL constructs.
  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
  /// defined in a non-global context.
  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);

  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Rewrite/etc.
  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();

  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
  /// PDLL constructs.
  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
  /// effectively the same syntax, and only differ on slight semantics (given
  /// the different parsing contexts).
  template <typename T, typename ParseUserPDLLDeclFnT>
  FailureOr<T *> parseUserConstraintOrRewriteDecl(
      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
      StringRef anonymousNamePrefix, bool isInline);

  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
  /// These decls have effectively the same syntax.
  template <typename T>
  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse the functional signature (i.e. the arguments and results) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult parseUserConstraintOrRewriteSignature(
      SmallVectorImpl<ast::VariableDecl *> &arguments,
      SmallVectorImpl<ast::VariableDecl *> &results,
      ast::DeclScope *&argumentScope, ast::Type &resultType);

  /// Validate the return (which if present is specified by bodyIt) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult validateUserConstraintOrRewriteReturn(
      StringRef declType, ast::CompoundStmt *body,
      ArrayRef<ast::Stmt *>::iterator bodyIt,
      ArrayRef<ast::Stmt *>::iterator bodyE,
      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);

  /// Parse a lambda body, i.e. `=> stmt`, returning the single statement
  /// wrapped in a CompoundStmt. The statement is parsed in a fresh decl scope,
  /// and `processStatementFn` is invoked on it while that scope is active.
  /// `expectTerminalSemicolon` is forwarded to the statement parser.
  FailureOr<ast::CompoundStmt *>
  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                  bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name. If
  /// one has, emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints,
                  bool allowInlineTypeConstraints);

  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
  /// argument or result variable. The constraints for these variables do not
  /// allow inline type constraints, and only permit a single constraint.
  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *> parseOperationExpr();
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::ReturnStmt *> parseReturnStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to extract a callable from the given AST node. Returns nullptr on
  /// failure.
  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
  /// of results, defined as part of the signature.
  ast::Type
  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);

  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
  template <typename T>
  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
      ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Create a variable for an argument or result defined as part of the
  /// signature of a UserConstraintDecl/UserRewriteDecl.
  FailureOr<ast::VariableDecl *>
  createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                const ast::ConstraintRef &constraint);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined by the type defined by
  /// the constraint. Returns success if the constraint is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType,
                                           bool allowNonCoreConstraints = true);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::CallExpr *>
  createCallExpr(SMRange loc, ast::Expr *parentExpr,
                 MutableArrayRef<ast::Expr *> arguments);
  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(SMRange loc, Optional<StringRef> name,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                         MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
                                     MutableArrayRef<ast::Expr *> values,
                                     ast::Type singleTy, ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of the given location and re-lex the
  /// current token from there.
  void resetToken(SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// emit `msg` as an error at the current token and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error diagnostic with `msg` at `loc`, and return failure.
  LogicalResult emitError(SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error diagnostic with `msg` at the location of the current token,
  /// and return failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error diagnostic at `loc` with an attached note at `noteLoc`, and
  /// return failure.
  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The current (innermost) decl scope, maintained by pushDeclScope() and
  /// popDeclScope().
  ast::DeclScope *curDeclScope;
  /// Allocator providing the storage for scopes created by pushDeclScope().
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;

  /// A counter used when naming anonymous constraints and rewrites.
  unsigned anonymousDeclNameCounter = 0;
};
} // namespace
448 
449 FailureOr<ast::Module *> Parser::parseModule() {
450   SMLoc moduleLoc = curToken.getStartLoc();
451   pushDeclScope();
452 
453   // Parse the top-level decls of the module.
454   SmallVector<ast::Decl *> decls;
455   if (failed(parseModuleBody(decls)))
456     return popDeclScope(), failure();
457 
458   popDeclScope();
459   return ast::Module::create(ctx, moduleLoc, decls);
460 }
461 
462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
463   while (curToken.isNot(Token::eof)) {
464     if (curToken.is(Token::directive)) {
465       if (failed(parseDirective(decls)))
466         return failure();
467       continue;
468     }
469 
470     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
471     if (failed(decl))
472       return failure();
473     decls.push_back(*decl);
474   }
475   return success();
476 }
477 
478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
479   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
480                                                  valueRangeTy);
481 }
482 
483 LogicalResult Parser::convertExpressionTo(
484     ast::Expr *&expr, ast::Type type,
485     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
486   ast::Type exprType = expr->getType();
487   if (exprType == type)
488     return success();
489 
490   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
491     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
492         expr->getLoc(), llvm::formatv("unable to convert expression of type "
493                                       "`{0}` to the expected type of "
494                                       "`{1}`",
495                                       exprType, type));
496     if (noteAttachFn)
497       noteAttachFn(*diag);
498     return diag;
499   };
500 
501   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
502     // Two operation types are compatible if they have the same name, or if the
503     // expected type is more general.
504     if (auto opType = type.dyn_cast<ast::OperationType>()) {
505       if (opType.getName())
506         return emitConvertError();
507       return success();
508     }
509 
510     // An operation can always convert to a ValueRange.
511     if (type == valueRangeTy) {
512       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
513                                                      valueRangeTy);
514       return success();
515     }
516 
517     // Allow conversion to a single value by constraining the result range.
518     if (type == valueTy) {
519       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
520                                                      valueTy);
521       return success();
522     }
523     return emitConvertError();
524   }
525 
526   // FIXME: Decide how to allow/support converting a single result to multiple,
527   // and multiple to a single result. For now, we just allow Single->Range,
528   // but this isn't something really supported in the PDL dialect. We should
529   // figure out some way to support both.
530   if ((exprType == valueTy || exprType == valueRangeTy) &&
531       (type == valueTy || type == valueRangeTy))
532     return success();
533   if ((exprType == typeTy || exprType == typeRangeTy) &&
534       (type == typeTy || type == typeRangeTy))
535     return success();
536 
537   // Handle tuple types.
538   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
539     auto tupleType = type.dyn_cast<ast::TupleType>();
540     if (!tupleType || tupleType.size() != exprTupleType.size())
541       return emitConvertError();
542 
543     // Build a new tuple expression using each of the elements of the current
544     // tuple.
545     SmallVector<ast::Expr *> newExprs;
546     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
547       newExprs.push_back(ast::MemberAccessExpr::create(
548           ctx, expr->getLoc(), expr, llvm::to_string(i),
549           exprTupleType.getElementTypes()[i]));
550 
551       auto diagFn = [&](ast::Diagnostic &diag) {
552         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
553                                       i, exprTupleType));
554         if (noteAttachFn)
555           noteAttachFn(diag);
556       };
557       if (failed(convertExpressionTo(newExprs.back(),
558                                      tupleType.getElementTypes()[i], diagFn)))
559         return failure();
560     }
561     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
562                                   tupleType.getElementNames());
563     return success();
564   }
565 
566   return emitConvertError();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Directives
571 
572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
573   StringRef directive = curToken.getSpelling();
574   if (directive == "#include")
575     return parseInclude(decls);
576 
577   return emitError("unknown directive `" + directive + "`");
578 }
579 
580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
581   SMRange loc = curToken.getLoc();
582   consumeToken(Token::directive);
583 
584   // Parse the file being included.
585   if (!curToken.isString())
586     return emitError(loc,
587                      "expected string file name after `include` directive");
588   SMRange fileLoc = curToken.getLoc();
589   std::string filenameStr = curToken.getStringValue();
590   StringRef filename = filenameStr;
591   consumeToken();
592 
593   // Check the type of include. If ending with `.pdll`, this is another pdl file
594   // to be parsed along with the current module.
595   if (filename.endswith(".pdll")) {
596     if (failed(lexer.pushInclude(filename)))
597       return emitError(fileLoc,
598                        "unable to open include file `" + filename + "`");
599 
600     // If we added the include successfully, parse it into the current module.
601     // Make sure to save the current token so that we can restore it when we
602     // finish parsing the nested file.
603     Token oldToken = curToken;
604     curToken = lexer.lexToken();
605     LogicalResult result = parseModuleBody(decls);
606     curToken = oldToken;
607     return result;
608   }
609 
610   return emitError(fileLoc, "expected include filename to end with `.pdll`");
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Decls
615 
616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
617   FailureOr<ast::Decl *> decl;
618   switch (curToken.getKind()) {
619   case Token::kw_Constraint:
620     decl = parseUserConstraintDecl();
621     break;
622   case Token::kw_Pattern:
623     decl = parsePatternDecl();
624     break;
625   case Token::kw_Rewrite:
626     decl = parseUserRewriteDecl();
627     break;
628   default:
629     return emitError("expected top-level declaration, such as a `Pattern`");
630   }
631   if (failed(decl))
632     return failure();
633 
634   // If the decl has a name, add it to the current scope.
635   if (const ast::Name *name = (*decl)->getName()) {
636     if (failed(checkDefineNamedDecl(*name)))
637       return failure();
638     curDeclScope->add(*decl);
639   }
640   return decl;
641 }
642 
643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
644   std::string attrNameStr;
645   if (curToken.isString())
646     attrNameStr = curToken.getStringValue();
647   else if (curToken.is(Token::identifier) || curToken.isKeyword())
648     attrNameStr = curToken.getSpelling().str();
649   else
650     return emitError("expected identifier or string attribute name");
651   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
652   consumeToken();
653 
654   // Check for a value of the attribute.
655   ast::Expr *attrValue = nullptr;
656   if (consumeIf(Token::equal)) {
657     FailureOr<ast::Expr *> attrExpr = parseExpr();
658     if (failed(attrExpr))
659       return failure();
660     attrValue = *attrExpr;
661   } else {
662     // If there isn't a concrete value, create an expression representing a
663     // UnitAttr.
664     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
665   }
666 
667   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
668 }
669 
670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
671     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
672     bool expectTerminalSemicolon) {
673   consumeToken(Token::equal_arrow);
674 
675   // Parse the single statement of the lambda body.
676   SMLoc bodyStartLoc = curToken.getStartLoc();
677   pushDeclScope();
678   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
679   bool failedToParse =
680       failed(singleStatement) || failed(processStatementFn(*singleStatement));
681   popDeclScope();
682   if (failedToParse)
683     return failure();
684 
685   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
686   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
687 }
688 
689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
690   // Ensure that the argument is named.
691   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
692     return emitError("expected identifier argument name");
693 
694   // Parse the argument similarly to a normal variable.
695   StringRef name = curToken.getSpelling();
696   SMRange nameLoc = curToken.getLoc();
697   consumeToken();
698 
699   if (failed(
700           parseToken(Token::colon, "expected `:` before argument constraint")))
701     return failure();
702 
703   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
704   if (failed(cst))
705     return failure();
706 
707   return createArgOrResultVariableDecl(name, nameLoc, *cst);
708 }
709 
710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
711   // Check to see if this result is named.
712   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
713     // Check to see if this name actually refers to a Constraint.
714     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
715     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
716       // If yes, and this is a Rewrite, give a nice error message as non-Core
717       // constraints are not supported on Rewrite results.
718       if (parserContext == ParserContext::Rewrite) {
719         return emitError(
720             "`Rewrite` results are only permitted to use core constraints, "
721             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
722       }
723 
724       // Otherwise, parse this as an unnamed result variable.
725     } else {
726       // If it wasn't a constraint, parse the result similarly to a variable. If
727       // there is already an existing decl, we will emit an error when defining
728       // this variable later.
729       StringRef name = curToken.getSpelling();
730       SMRange nameLoc = curToken.getLoc();
731       consumeToken();
732 
733       if (failed(parseToken(Token::colon,
734                             "expected `:` before result constraint")))
735         return failure();
736 
737       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
738       if (failed(cst))
739         return failure();
740 
741       return createArgOrResultVariableDecl(name, nameLoc, *cst);
742     }
743   }
744 
745   // If it isn't named, we parse the constraint directly and create an unnamed
746   // result variable.
747   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
748   if (failed(cst))
749     return failure();
750 
751   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
752 }
753 
754 FailureOr<ast::UserConstraintDecl *>
755 Parser::parseUserConstraintDecl(bool isInline) {
756   // Constraints and rewrites have very similar formats, dispatch to a shared
757   // interface for parsing.
758   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
759       [&](auto &&...args) {
760         return this->parseUserPDLLConstraintDecl(args...);
761       },
762       ParserContext::Constraint, "constraint", isInline);
763 }
764 
765 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
766   FailureOr<ast::UserConstraintDecl *> decl =
767       parseUserConstraintDecl(/*isInline=*/true);
768   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
769     return failure();
770 
771   curDeclScope->add(*decl);
772   return decl;
773 }
774 
775 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
776     const ast::Name &name, bool isInline,
777     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
778     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
779   // Push the argument scope back onto the list, so that the body can
780   // reference arguments.
781   pushDeclScope(argumentScope);
782 
783   // Parse the body of the constraint. The body is either defined as a compound
784   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
785   ast::CompoundStmt *body;
786   if (curToken.is(Token::equal_arrow)) {
787     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
788         [&](ast::Stmt *&stmt) -> LogicalResult {
789           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
790           if (!stmtExpr) {
791             return emitError(stmt->getLoc(),
792                              "expected `Constraint` lambda body to contain a "
793                              "single expression");
794           }
795           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
796           return success();
797         },
798         /*expectTerminalSemicolon=*/!isInline);
799     if (failed(bodyResult))
800       return failure();
801     body = *bodyResult;
802   } else {
803     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
804     if (failed(bodyResult))
805       return failure();
806     body = *bodyResult;
807 
808     // Verify the structure of the body.
809     auto bodyIt = body->begin(), bodyE = body->end();
810     for (; bodyIt != bodyE; ++bodyIt)
811       if (isa<ast::ReturnStmt>(*bodyIt))
812         break;
813     if (failed(validateUserConstraintOrRewriteReturn(
814             "Constraint", body, bodyIt, bodyE, results, resultType)))
815       return failure();
816   }
817   popDeclScope();
818 
819   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
820       name, arguments, results, resultType, body);
821 }
822 
823 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
824   // Constraints and rewrites have very similar formats, dispatch to a shared
825   // interface for parsing.
826   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
827       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
828       ParserContext::Rewrite, "rewrite", isInline);
829 }
830 
831 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
832   FailureOr<ast::UserRewriteDecl *> decl =
833       parseUserRewriteDecl(/*isInline=*/true);
834   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
835     return failure();
836 
837   curDeclScope->add(*decl);
838   return decl;
839 }
840 
841 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
842     const ast::Name &name, bool isInline,
843     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
844     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
845   // Push the argument scope back onto the list, so that the body can
846   // reference arguments.
847   curDeclScope = argumentScope;
848   ast::CompoundStmt *body;
849   if (curToken.is(Token::equal_arrow)) {
850     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
851         [&](ast::Stmt *&statement) -> LogicalResult {
852           if (isa<ast::OpRewriteStmt>(statement))
853             return success();
854 
855           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
856           if (!statementExpr) {
857             return emitError(
858                 statement->getLoc(),
859                 "expected `Rewrite` lambda body to contain a single expression "
860                 "or an operation rewrite statement; such as `erase`, "
861                 "`replace`, or `rewrite`");
862           }
863           statement =
864               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
865           return success();
866         },
867         /*expectTerminalSemicolon=*/!isInline);
868     if (failed(bodyResult))
869       return failure();
870     body = *bodyResult;
871   } else {
872     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
873     if (failed(bodyResult))
874       return failure();
875     body = *bodyResult;
876   }
877   popDeclScope();
878 
879   // Verify the structure of the body.
880   auto bodyIt = body->begin(), bodyE = body->end();
881   for (; bodyIt != bodyE; ++bodyIt)
882     if (isa<ast::ReturnStmt>(*bodyIt))
883       break;
884   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
885                                                    bodyE, results, resultType)))
886     return failure();
887   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
888       name, arguments, results, resultType, body);
889 }
890 
891 template <typename T, typename ParseUserPDLLDeclFnT>
892 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
893     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
894     StringRef anonymousNamePrefix, bool isInline) {
895   SMRange loc = curToken.getLoc();
896   consumeToken();
897   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
898 
899   // Parse the name of the decl.
900   const ast::Name *name = nullptr;
901   if (curToken.isNot(Token::identifier)) {
902     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
903     // in C++, so being unnamed is fine.
904     if (!isInline)
905       return emitError("expected identifier name");
906 
907     // Create a unique anonymous name to use, as the name for this decl is not
908     // important.
909     std::string anonName =
910         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
911                       anonymousDeclNameCounter++)
912             .str();
913     name = &ast::Name::create(ctx, anonName, loc);
914   } else {
915     // If a name was provided, we can use it directly.
916     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
917     consumeToken(Token::identifier);
918   }
919 
920   // Parse the functional signature of the decl.
921   SmallVector<ast::VariableDecl *> arguments, results;
922   ast::DeclScope *argumentScope;
923   ast::Type resultType;
924   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
925                                                    argumentScope, resultType)))
926     return failure();
927 
928   // Check to see which type of constraint this is. If the constraint contains a
929   // compound body, this is a PDLL decl.
930   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
931     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
932                            resultType);
933 
934   // Otherwise, this is a native decl.
935   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
936                                                    results, resultType);
937 }
938 
939 template <typename T>
940 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
941     const ast::Name &name, bool isInline,
942     ArrayRef<ast::VariableDecl *> arguments,
943     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
944   // If followed by a string, the native code body has also been specified.
945   std::string codeStrStorage;
946   Optional<StringRef> optCodeStr;
947   if (curToken.isString()) {
948     codeStrStorage = curToken.getStringValue();
949     optCodeStr = codeStrStorage;
950     consumeToken();
951   } else if (isInline) {
952     return emitError(name.getLoc(),
953                      "external declarations must be declared in global scope");
954   }
955   if (failed(parseToken(Token::semicolon,
956                         "expected `;` after native declaration")))
957     return failure();
958   // TODO: PDL should be able to support constraint results in certain
959   // situations, we should revise this.
960   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
961     return emitError(
962         "native Constraints currently do not support returning results");
963   }
964   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
965 }
966 
967 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
968     SmallVectorImpl<ast::VariableDecl *> &arguments,
969     SmallVectorImpl<ast::VariableDecl *> &results,
970     ast::DeclScope *&argumentScope, ast::Type &resultType) {
971   // Parse the argument list of the decl.
972   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
973     return failure();
974 
975   argumentScope = pushDeclScope();
976   if (curToken.isNot(Token::r_paren)) {
977     do {
978       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
979       if (failed(argument))
980         return failure();
981       arguments.emplace_back(*argument);
982     } while (consumeIf(Token::comma));
983   }
984   popDeclScope();
985   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
986     return failure();
987 
988   // Parse the results of the decl.
989   pushDeclScope();
990   if (consumeIf(Token::arrow)) {
991     auto parseResultFn = [&]() -> LogicalResult {
992       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
993       if (failed(result))
994         return failure();
995       results.emplace_back(*result);
996       return success();
997     };
998 
999     // Check for a list of results.
1000     if (consumeIf(Token::l_paren)) {
1001       do {
1002         if (failed(parseResultFn()))
1003           return failure();
1004       } while (consumeIf(Token::comma));
1005       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1006         return failure();
1007 
1008       // Otherwise, there is only one result.
1009     } else if (failed(parseResultFn())) {
1010       return failure();
1011     }
1012   }
1013   popDeclScope();
1014 
1015   // Compute the result type of the decl.
1016   resultType = createUserConstraintRewriteResultType(results);
1017 
1018   // Verify that results are only named if there are more than one.
1019   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1020     return emitError(
1021         results.front()->getLoc(),
1022         "cannot create a single-element tuple with an element label");
1023   }
1024   return success();
1025 }
1026 
1027 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1028     StringRef declType, ast::CompoundStmt *body,
1029     ArrayRef<ast::Stmt *>::iterator bodyIt,
1030     ArrayRef<ast::Stmt *>::iterator bodyE,
1031     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1032   // Handle if a `return` was provided.
1033   if (bodyIt != bodyE) {
1034     // Emit an error if we have trailing statements after the return.
1035     if (std::next(bodyIt) != bodyE) {
1036       return emitError(
1037           (*std::next(bodyIt))->getLoc(),
1038           llvm::formatv("`return` terminated the `{0}` body, but found "
1039                         "trailing statements afterwards",
1040                         declType));
1041     }
1042 
1043     // Otherwise if a return wasn't provided, check that no results are
1044     // expected.
1045   } else if (!results.empty()) {
1046     return emitError(
1047         {body->getLoc().End, body->getLoc().End},
1048         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1049                       declType, resultType));
1050   }
1051   return success();
1052 }
1053 
1054 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1055   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1056     if (isa<ast::OpRewriteStmt>(statement))
1057       return success();
1058     return emitError(
1059         statement->getLoc(),
1060         "expected Pattern lambda body to contain a single operation "
1061         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1062   });
1063 }
1064 
1065 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1066   SMRange loc = curToken.getLoc();
1067   consumeToken(Token::kw_Pattern);
1068   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1069                                               ParserContext::PatternMatch);
1070 
1071   // Check for an optional identifier for the pattern name.
1072   const ast::Name *name = nullptr;
1073   if (curToken.is(Token::identifier)) {
1074     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1075     consumeToken(Token::identifier);
1076   }
1077 
1078   // Parse any pattern metadata.
1079   ParsedPatternMetadata metadata;
1080   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1081     return failure();
1082 
1083   // Parse the pattern body.
1084   ast::CompoundStmt *body;
1085 
1086   // Handle a lambda body.
1087   if (curToken.is(Token::equal_arrow)) {
1088     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1089     if (failed(bodyResult))
1090       return failure();
1091     body = *bodyResult;
1092   } else {
1093     if (curToken.isNot(Token::l_brace))
1094       return emitError("expected `{` or `=>` to start pattern body");
1095     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1096     if (failed(bodyResult))
1097       return failure();
1098     body = *bodyResult;
1099 
1100     // Verify the body of the pattern.
1101     auto bodyIt = body->begin(), bodyE = body->end();
1102     for (; bodyIt != bodyE; ++bodyIt) {
1103       if (isa<ast::ReturnStmt>(*bodyIt)) {
1104         return emitError((*bodyIt)->getLoc(),
1105                          "`return` statements are only permitted within a "
1106                          "`Constraint` or `Rewrite` body");
1107       }
1108       // Break when we've found the rewrite statement.
1109       if (isa<ast::OpRewriteStmt>(*bodyIt))
1110         break;
1111     }
1112     if (bodyIt == bodyE) {
1113       return emitError(loc,
1114                        "expected Pattern body to terminate with an operation "
1115                        "rewrite statement, such as `erase`");
1116     }
1117     if (std::next(bodyIt) != bodyE) {
1118       return emitError((*std::next(bodyIt))->getLoc(),
1119                        "Pattern body was terminated by an operation "
1120                        "rewrite statement, but found trailing statements");
1121     }
1122   }
1123 
1124   return createPatternDecl(loc, name, metadata, body);
1125 }
1126 
1127 LogicalResult
1128 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1129   Optional<SMRange> benefitLoc;
1130   Optional<SMRange> hasBoundedRecursionLoc;
1131 
1132   do {
1133     if (curToken.isNot(Token::identifier))
1134       return emitError("expected pattern metadata identifier");
1135     StringRef metadataStr = curToken.getSpelling();
1136     SMRange metadataLoc = curToken.getLoc();
1137     consumeToken(Token::identifier);
1138 
1139     // Parse the benefit metadata: benefit(<integer-value>)
1140     if (metadataStr == "benefit") {
1141       if (benefitLoc) {
1142         return emitErrorAndNote(metadataLoc,
1143                                 "pattern benefit has already been specified",
1144                                 *benefitLoc, "see previous definition here");
1145       }
1146       if (failed(parseToken(Token::l_paren,
1147                             "expected `(` before pattern benefit")))
1148         return failure();
1149 
1150       uint16_t benefitValue = 0;
1151       if (curToken.isNot(Token::integer))
1152         return emitError("expected integral pattern benefit");
1153       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1154         return emitError(
1155             "expected pattern benefit to fit within a 16-bit integer");
1156       consumeToken(Token::integer);
1157 
1158       metadata.benefit = benefitValue;
1159       benefitLoc = metadataLoc;
1160 
1161       if (failed(
1162               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1163         return failure();
1164       continue;
1165     }
1166 
1167     // Parse the bounded recursion metadata: recursion
1168     if (metadataStr == "recursion") {
1169       if (hasBoundedRecursionLoc) {
1170         return emitErrorAndNote(
1171             metadataLoc,
1172             "pattern recursion metadata has already been specified",
1173             *hasBoundedRecursionLoc, "see previous definition here");
1174       }
1175       metadata.hasBoundedRecursion = true;
1176       hasBoundedRecursionLoc = metadataLoc;
1177       continue;
1178     }
1179 
1180     return emitError(metadataLoc, "unknown pattern metadata");
1181   } while (consumeIf(Token::comma));
1182 
1183   return success();
1184 }
1185 
1186 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1187   consumeToken(Token::less);
1188 
1189   FailureOr<ast::Expr *> typeExpr = parseExpr();
1190   if (failed(typeExpr) ||
1191       failed(parseToken(Token::greater,
1192                         "expected `>` after variable type constraint")))
1193     return failure();
1194   return typeExpr;
1195 }
1196 
1197 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1198   assert(curDeclScope && "defining decl outside of a decl scope");
1199   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1200     return emitErrorAndNote(
1201         name.getLoc(), "`" + name.getName() + "` has already been defined",
1202         lastDecl->getName()->getLoc(), "see previous definition here");
1203   }
1204   return success();
1205 }
1206 
1207 FailureOr<ast::VariableDecl *>
1208 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1209                            ast::Expr *initExpr,
1210                            ArrayRef<ast::ConstraintRef> constraints) {
1211   assert(curDeclScope && "defining variable outside of decl scope");
1212   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1213 
1214   // If the name of the variable indicates a special variable, we don't add it
1215   // to the scope. This variable is local to the definition point.
1216   if (name.empty() || name == "_") {
1217     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1218                                      constraints);
1219   }
1220   if (failed(checkDefineNamedDecl(nameDecl)))
1221     return failure();
1222 
1223   auto *varDecl =
1224       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1225   curDeclScope->add(varDecl);
1226   return varDecl;
1227 }
1228 
1229 FailureOr<ast::VariableDecl *>
1230 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1231                            ArrayRef<ast::ConstraintRef> constraints) {
1232   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1233                             constraints);
1234 }
1235 
1236 LogicalResult Parser::parseVariableDeclConstraintList(
1237     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1238   Optional<SMRange> typeConstraint;
1239   auto parseSingleConstraint = [&] {
1240     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1241         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1242     if (failed(constraint))
1243       return failure();
1244     constraints.push_back(*constraint);
1245     return success();
1246   };
1247 
1248   // Check to see if this is a single constraint, or a list.
1249   if (!consumeIf(Token::l_square))
1250     return parseSingleConstraint();
1251 
1252   do {
1253     if (failed(parseSingleConstraint()))
1254       return failure();
1255   } while (consumeIf(Token::comma));
1256   return parseToken(Token::r_square, "expected `]` after constraint list");
1257 }
1258 
/// Parse a single constraint reference. Accepted forms are the core
/// constraints (`Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`),
/// an inline `Constraint` decl, or an identifier naming a previously defined
/// constraint decl.
///
/// `typeConstraint` records the location of an inline `<...>` type constraint
/// once one has been parsed. A second one is diagnosed as a duplicate, and
/// callers that share the same Optional across several calls get that check
/// across the whole list. When `allowInlineTypeConstraints` is false, any
/// `Attr<...>`, `Value<...>`, or `ValueRange<...>` is rejected.
/// `existingConstraints` is not referenced in this function body.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints) {
  // Parses the `<expr>` type constraint that may follow `Attr`, `Value`, or
  // `ValueRange`, enforcing that it is permitted and not already present.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Remember where the type was constrained for later duplicate diagnostics.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  // Location of the constraint's first token; every ConstraintRef uses it.
  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  // `Type` and `TypeRange` take no type constraint.
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint. The inline decl is also added to the
    // current scope by parseInlineUserConstraintDecl.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a decl that is not a constraint (e.g. a
    // variable); point at its definition.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
1369 
1370 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1371   Optional<SMRange> typeConstraint;
1372   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1373                          /*allowInlineTypeConstraints=*/false);
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // Exprs
1378 
1379 FailureOr<ast::Expr *> Parser::parseExpr() {
1380   if (curToken.is(Token::underscore))
1381     return parseUnderscoreExpr();
1382 
1383   // Parse the LHS expression.
1384   FailureOr<ast::Expr *> lhsExpr;
1385   switch (curToken.getKind()) {
1386   case Token::kw_attr:
1387     lhsExpr = parseAttributeExpr();
1388     break;
1389   case Token::kw_Constraint:
1390     lhsExpr = parseInlineConstraintLambdaExpr();
1391     break;
1392   case Token::identifier:
1393     lhsExpr = parseIdentifierExpr();
1394     break;
1395   case Token::kw_op:
1396     lhsExpr = parseOperationExpr();
1397     break;
1398   case Token::kw_Rewrite:
1399     lhsExpr = parseInlineRewriteLambdaExpr();
1400     break;
1401   case Token::kw_type:
1402     lhsExpr = parseTypeExpr();
1403     break;
1404   case Token::l_paren:
1405     lhsExpr = parseTupleExpr();
1406     break;
1407   default:
1408     return emitError("expected expression");
1409   }
1410   if (failed(lhsExpr))
1411     return failure();
1412 
1413   // Check for an operator expression.
1414   while (true) {
1415     switch (curToken.getKind()) {
1416     case Token::dot:
1417       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1418       break;
1419     case Token::l_paren:
1420       lhsExpr = parseCallExpr(*lhsExpr);
1421       break;
1422     default:
1423       return lhsExpr;
1424     }
1425     if (failed(lhsExpr))
1426       return failure();
1427   }
1428 }
1429 
1430 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1431   SMRange loc = curToken.getLoc();
1432   consumeToken(Token::kw_attr);
1433 
1434   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1435   // identifier.
1436   if (!consumeIf(Token::less)) {
1437     resetToken(loc);
1438     return parseIdentifierExpr();
1439   }
1440 
1441   if (!curToken.isString())
1442     return emitError("expected string literal containing MLIR attribute");
1443   std::string attrExpr = curToken.getStringValue();
1444   consumeToken();
1445 
1446   if (failed(
1447           parseToken(Token::greater, "expected `>` after attribute literal")))
1448     return failure();
1449   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1450 }
1451 
1452 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1453   SMRange loc = curToken.getLoc();
1454   consumeToken(Token::l_paren);
1455 
1456   // Parse the arguments of the call.
1457   SmallVector<ast::Expr *> arguments;
1458   if (curToken.isNot(Token::r_paren)) {
1459     do {
1460       FailureOr<ast::Expr *> argument = parseExpr();
1461       if (failed(argument))
1462         return failure();
1463       arguments.push_back(*argument);
1464     } while (consumeIf(Token::comma));
1465   }
1466   loc.End = curToken.getEndLoc();
1467   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1468     return failure();
1469 
1470   return createCallExpr(loc, parentExpr, arguments);
1471 }
1472 
1473 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1474   ast::Decl *decl = curDeclScope->lookup(name);
1475   if (!decl)
1476     return emitError(loc, "undefined reference to `" + name + "`");
1477 
1478   return createDeclRefExpr(loc, decl);
1479 }
1480 
1481 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1482   StringRef name = curToken.getSpelling();
1483   SMRange nameLoc = curToken.getLoc();
1484   consumeToken();
1485 
1486   // Check to see if this is a decl ref expression that defines a variable
1487   // inline.
1488   if (consumeIf(Token::colon)) {
1489     SmallVector<ast::ConstraintRef> constraints;
1490     if (failed(parseVariableDeclConstraintList(constraints)))
1491       return failure();
1492     ast::Type type;
1493     if (failed(validateVariableConstraints(constraints, type)))
1494       return failure();
1495     return createInlineVariableExpr(type, name, nameLoc, constraints);
1496   }
1497 
1498   return parseDeclRefExpr(name, nameLoc);
1499 }
1500 
1501 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1502   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1503   if (failed(decl))
1504     return failure();
1505 
1506   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1507                                   ast::ConstraintType::get(ctx));
1508 }
1509 
1510 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1511   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1512   if (failed(decl))
1513     return failure();
1514 
1515   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1516                                   ast::RewriteType::get(ctx));
1517 }
1518 
1519 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1520   SMRange loc = curToken.getLoc();
1521   consumeToken(Token::dot);
1522 
1523   // Parse the member name.
1524   Token memberNameTok = curToken;
1525   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1526       !memberNameTok.isKeyword())
1527     return emitError(loc, "expected identifier or numeric member name");
1528   StringRef memberName = memberNameTok.getSpelling();
1529   consumeToken();
1530 
1531   return createMemberAccessExpr(parentExpr, memberName, loc);
1532 }
1533 
1534 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1535   SMRange loc = curToken.getLoc();
1536 
1537   // Handle the case of an no operation name.
1538   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1539     if (allowEmptyName)
1540       return ast::OpNameDecl::create(ctx, SMRange());
1541     return emitError("expected dialect namespace");
1542   }
1543   StringRef name = curToken.getSpelling();
1544   consumeToken();
1545 
1546   // Otherwise, this is a literal operation name.
1547   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1548     return failure();
1549 
1550   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1551     return emitError("expected operation name after dialect namespace");
1552 
1553   name = StringRef(name.data(), name.size() + 1);
1554   do {
1555     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1556     loc.End = curToken.getEndLoc();
1557     consumeToken();
1558   } while (curToken.isAny(Token::identifier, Token::dot) ||
1559            curToken.isKeyword());
1560   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1561 }
1562 
1563 FailureOr<ast::OpNameDecl *>
1564 Parser::parseWrappedOperationName(bool allowEmptyName) {
1565   if (!consumeIf(Token::less))
1566     return ast::OpNameDecl::create(ctx, SMRange());
1567 
1568   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1569   if (failed(opNameDecl))
1570     return failure();
1571 
1572   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1573     return failure();
1574   return opNameDecl;
1575 }
1576 
/// Parse an operation expression of the general form:
///
///   `op` `<` name? `>` (`(` operands? `)`)? (`{` attributes `}`)?
///       (`->` `(` result-types? `)`)?
///
/// If `op` is not followed by `<`, the lexer is reset and the keyword is
/// parsed as a plain identifier expression instead. Outside of a rewrite
/// context, leaving out the operand list or the result type list does not
/// mean "zero". An implicit, unconstrained range variable is created in its
/// place instead. An explicit empty list (`()` / `-> ()`) does mean zero.
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The variable is named `_`, is located at the
  // `op` keyword, and is constrained only by `cst`. Defining it is asserted to
  // always succeed.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an in-place unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    // (An immediately closed `()` leaves `operands` empty.)
    do {
      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes. When the `{` is present, at
  // least one named attribute is parsed before looking for the closing `}`.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Check for the optional list of result types.
  SmallVector<ast::Expr *> resultTypes;
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();

    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));

      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an in-place unconstrained result range corresponding to all of the
    // results of the operation. This avoids treating zero results the same way
    // as "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  }

  return createOperationExpr(loc, *opNameDecl, operands, attributes,
                             resultTypes);
}
1678 
/// Parse a tuple expression: `(` element (`,` element)* `)` or the empty tuple
/// `()`, where each element is `(label `=`)? expr`. Labels must be unique
/// within a tuple. Unlabeled elements get an empty name in `elementNames`. The
/// resulting location spans from the `(` through the `)`.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Maps each label already used in this tuple to its location, for duplicate
  // detection.
  DenseMap<StringRef, SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer back to the identifier and let parseExpr() re-lex it.
          resetToken(elementNameTok.getLoc());
        }
      }
      // Always record a name (possibly empty) so names stay index-aligned
      // with `elements`.
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  // Extend the location to cover the closing `)` (checked just below).
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
1730 
1731 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1732   SMRange loc = curToken.getLoc();
1733   consumeToken(Token::kw_type);
1734 
1735   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1736   // identifier.
1737   if (!consumeIf(Token::less)) {
1738     resetToken(loc);
1739     return parseIdentifierExpr();
1740   }
1741 
1742   if (!curToken.isString())
1743     return emitError("expected string literal containing MLIR type");
1744   std::string attrExpr = curToken.getStringValue();
1745   consumeToken();
1746 
1747   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1748     return failure();
1749   return ast::TypeExpr::create(ctx, loc, attrExpr);
1750 }
1751 
1752 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1753   StringRef name = curToken.getSpelling();
1754   SMRange nameLoc = curToken.getLoc();
1755   consumeToken(Token::underscore);
1756 
1757   // Underscore expressions require a constraint list.
1758   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1759     return failure();
1760 
1761   // Parse the constraints for the expression.
1762   SmallVector<ast::ConstraintRef> constraints;
1763   if (failed(parseVariableDeclConstraintList(constraints)))
1764     return failure();
1765 
1766   ast::Type type;
1767   if (failed(validateVariableConstraints(constraints, type)))
1768     return failure();
1769   return createInlineVariableExpr(type, name, nameLoc, constraints);
1770 }
1771 
1772 //===----------------------------------------------------------------------===//
1773 // Stmts
1774 
1775 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1776   FailureOr<ast::Stmt *> stmt;
1777   switch (curToken.getKind()) {
1778   case Token::kw_erase:
1779     stmt = parseEraseStmt();
1780     break;
1781   case Token::kw_let:
1782     stmt = parseLetStmt();
1783     break;
1784   case Token::kw_replace:
1785     stmt = parseReplaceStmt();
1786     break;
1787   case Token::kw_return:
1788     stmt = parseReturnStmt();
1789     break;
1790   case Token::kw_rewrite:
1791     stmt = parseRewriteStmt();
1792     break;
1793   default:
1794     stmt = parseExpr();
1795     break;
1796   }
1797   if (failed(stmt) ||
1798       (expectTerminalSemicolon &&
1799        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1800     return failure();
1801   return stmt;
1802 }
1803 
1804 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1805   SMLoc startLoc = curToken.getStartLoc();
1806   consumeToken(Token::l_brace);
1807 
1808   // Push a new block scope and parse any nested statements.
1809   pushDeclScope();
1810   SmallVector<ast::Stmt *> statements;
1811   while (curToken.isNot(Token::r_brace)) {
1812     FailureOr<ast::Stmt *> statement = parseStmt();
1813     if (failed(statement))
1814       return popDeclScope(), failure();
1815     statements.push_back(*statement);
1816   }
1817   popDeclScope();
1818 
1819   // Consume the end brace.
1820   SMRange location(startLoc, curToken.getEndLoc());
1821   consumeToken(Token::r_brace);
1822 
1823   return ast::CompoundStmt::create(ctx, location, statements);
1824 }
1825 
1826 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1827   if (parserContext == ParserContext::Constraint)
1828     return emitError("`erase` cannot be used within a Constraint");
1829   SMRange loc = curToken.getLoc();
1830   consumeToken(Token::kw_erase);
1831 
1832   // Parse the root operation expression.
1833   FailureOr<ast::Expr *> rootOp = parseExpr();
1834   if (failed(rootOp))
1835     return failure();
1836 
1837   return createEraseStmt(loc, *rootOp);
1838 }
1839 
1840 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1841   SMRange loc = curToken.getLoc();
1842   consumeToken(Token::kw_let);
1843 
1844   // Parse the name of the new variable.
1845   SMRange varLoc = curToken.getLoc();
1846   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1847     // `_` is a reserved variable name.
1848     if (curToken.is(Token::underscore)) {
1849       return emitError(varLoc,
1850                        "`_` may only be used to define \"inline\" variables");
1851     }
1852     return emitError(varLoc,
1853                      "expected identifier after `let` to name a new variable");
1854   }
1855   StringRef varName = curToken.getSpelling();
1856   consumeToken();
1857 
1858   // Parse the optional set of constraints.
1859   SmallVector<ast::ConstraintRef> constraints;
1860   if (consumeIf(Token::colon) &&
1861       failed(parseVariableDeclConstraintList(constraints)))
1862     return failure();
1863 
1864   // Parse the optional initializer expression.
1865   ast::Expr *initializer = nullptr;
1866   if (consumeIf(Token::equal)) {
1867     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1868     if (failed(initOrFailure))
1869       return failure();
1870     initializer = *initOrFailure;
1871 
1872     // Check that the constraints are compatible with having an initializer,
1873     // e.g. type constraints cannot be used with initializers.
1874     for (ast::ConstraintRef constraint : constraints) {
1875       LogicalResult result =
1876           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1877               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1878                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1879                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1880                   return this->emitError(
1881                       constraint.referenceLoc,
1882                       "type constraints are not permitted on variables with "
1883                       "initializers");
1884                 }
1885                 return success();
1886               })
1887               .Default(success());
1888       if (failed(result))
1889         return failure();
1890     }
1891   }
1892 
1893   FailureOr<ast::VariableDecl *> varDecl =
1894       createVariableDecl(varName, varLoc, initializer, constraints);
1895   if (failed(varDecl))
1896     return failure();
1897   return ast::LetStmt::create(ctx, loc, *varDecl);
1898 }
1899 
1900 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1901   if (parserContext == ParserContext::Constraint)
1902     return emitError("`replace` cannot be used within a Constraint");
1903   SMRange loc = curToken.getLoc();
1904   consumeToken(Token::kw_replace);
1905 
1906   // Parse the root operation expression.
1907   FailureOr<ast::Expr *> rootOp = parseExpr();
1908   if (failed(rootOp))
1909     return failure();
1910 
1911   if (failed(
1912           parseToken(Token::kw_with, "expected `with` after root operation")))
1913     return failure();
1914 
1915   // The replacement portion of this statement is within a rewrite context.
1916   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1917                                               ParserContext::Rewrite);
1918 
1919   // Parse the replacement values.
1920   SmallVector<ast::Expr *> replValues;
1921   if (consumeIf(Token::l_paren)) {
1922     if (consumeIf(Token::r_paren)) {
1923       return emitError(
1924           loc, "expected at least one replacement value, consider using "
1925                "`erase` if no replacement values are desired");
1926     }
1927 
1928     do {
1929       FailureOr<ast::Expr *> replExpr = parseExpr();
1930       if (failed(replExpr))
1931         return failure();
1932       replValues.emplace_back(*replExpr);
1933     } while (consumeIf(Token::comma));
1934 
1935     if (failed(parseToken(Token::r_paren,
1936                           "expected `)` after replacement values")))
1937       return failure();
1938   } else {
1939     FailureOr<ast::Expr *> replExpr = parseExpr();
1940     if (failed(replExpr))
1941       return failure();
1942     replValues.emplace_back(*replExpr);
1943   }
1944 
1945   return createReplaceStmt(loc, *rootOp, replValues);
1946 }
1947 
1948 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
1949   SMRange loc = curToken.getLoc();
1950   consumeToken(Token::kw_return);
1951 
1952   // Parse the result value.
1953   FailureOr<ast::Expr *> resultExpr = parseExpr();
1954   if (failed(resultExpr))
1955     return failure();
1956 
1957   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
1958 }
1959 
1960 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1961   if (parserContext == ParserContext::Constraint)
1962     return emitError("`rewrite` cannot be used within a Constraint");
1963   SMRange loc = curToken.getLoc();
1964   consumeToken(Token::kw_rewrite);
1965 
1966   // Parse the root operation.
1967   FailureOr<ast::Expr *> rootOp = parseExpr();
1968   if (failed(rootOp))
1969     return failure();
1970 
1971   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1972     return failure();
1973 
1974   if (curToken.isNot(Token::l_brace))
1975     return emitError("expected `{` to start rewrite body");
1976 
1977   // The rewrite body of this statement is within a rewrite context.
1978   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1979                                               ParserContext::Rewrite);
1980 
1981   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1982   if (failed(rewriteBody))
1983     return failure();
1984 
1985   // Verify the rewrite body.
1986   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
1987     if (isa<ast::ReturnStmt>(stmt)) {
1988       return emitError(stmt->getLoc(),
1989                        "`return` statements are only permitted within a "
1990                        "`Constraint` or `Rewrite` body");
1991     }
1992   }
1993 
1994   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1995 }
1996 
1997 //===----------------------------------------------------------------------===//
1998 // Creation+Analysis
1999 //===----------------------------------------------------------------------===//
2000 
2001 //===----------------------------------------------------------------------===//
2002 // Decls
2003 
2004 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2005   // Unwrap reference expressions.
2006   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2007     node = init->getDecl();
2008   return dyn_cast<ast::CallableDecl>(node);
2009 }
2010 
2011 FailureOr<ast::PatternDecl *>
2012 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2013                           const ParsedPatternMetadata &metadata,
2014                           ast::CompoundStmt *body) {
2015   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2016                                   metadata.hasBoundedRecursion, body);
2017 }
2018 
2019 ast::Type Parser::createUserConstraintRewriteResultType(
2020     ArrayRef<ast::VariableDecl *> results) {
2021   // Single result decls use the type of the single result.
2022   if (results.size() == 1)
2023     return results[0]->getType();
2024 
2025   // Multiple results use a tuple type, with the types and names grabbed from
2026   // the result variable decls.
2027   auto resultTypes = llvm::map_range(
2028       results, [&](const auto *result) { return result->getType(); });
2029   auto resultNames = llvm::map_range(
2030       results, [&](const auto *result) { return result->getName().getName(); });
2031   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2032                              llvm::to_vector(resultNames));
2033 }
2034 
2035 template <typename T>
2036 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2037     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2038     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2039     ast::CompoundStmt *body) {
2040   if (!body->getChildren().empty()) {
2041     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2042       ast::Expr *resultExpr = retStmt->getResultExpr();
2043 
2044       // Process the result of the decl. If no explicit signature results
2045       // were provided, check for return type inference. Otherwise, check that
2046       // the return expression can be converted to the expected type.
2047       if (results.empty())
2048         resultType = resultExpr->getType();
2049       else if (failed(convertExpressionTo(resultExpr, resultType)))
2050         return failure();
2051       else
2052         retStmt->setResultExpr(resultExpr);
2053     }
2054   }
2055   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2056 }
2057 
2058 FailureOr<ast::VariableDecl *>
2059 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2060                            ArrayRef<ast::ConstraintRef> constraints) {
2061   // The type of the variable, which is expected to be inferred by either a
2062   // constraint or an initializer expression.
2063   ast::Type type;
2064   if (failed(validateVariableConstraints(constraints, type)))
2065     return failure();
2066 
2067   if (initializer) {
2068     // Update the variable type based on the initializer, or try to convert the
2069     // initializer to the existing type.
2070     if (!type)
2071       type = initializer->getType();
2072     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2073       type = mergedType;
2074     else if (failed(convertExpressionTo(initializer, type)))
2075       return failure();
2076 
2077     // Otherwise, if there is no initializer check that the type has already
2078     // been resolved from the constraint list.
2079   } else if (!type) {
2080     return emitErrorAndNote(
2081         loc, "unable to infer type for variable `" + name + "`", loc,
2082         "the type of a variable must be inferable from the constraint "
2083         "list or the initializer");
2084   }
2085 
2086   // Constraint types cannot be used when defining variables.
2087   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2088     return emitError(
2089         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2090   }
2091 
2092   // Try to define a variable with the given name.
2093   FailureOr<ast::VariableDecl *> varDecl =
2094       defineVariableDecl(name, loc, type, initializer, constraints);
2095   if (failed(varDecl))
2096     return failure();
2097 
2098   return *varDecl;
2099 }
2100 
2101 FailureOr<ast::VariableDecl *>
2102 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2103                                       const ast::ConstraintRef &constraint) {
2104   // Constraint arguments may apply more complex constraints via the arguments.
2105   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2106   ast::Type argType;
2107   if (failed(validateVariableConstraint(constraint, argType,
2108                                         allowNonCoreConstraints)))
2109     return failure();
2110   return defineVariableDecl(name, loc, argType, constraint);
2111 }
2112 
2113 LogicalResult
2114 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2115                                     ast::Type &inferredType) {
2116   for (const ast::ConstraintRef &ref : constraints)
2117     if (failed(validateVariableConstraint(ref, inferredType)))
2118       return failure();
2119   return success();
2120 }
2121 
/// Validate one variable constraint reference and merge the type it implies
/// into `inferredType`.
///
/// Core constraints map to fixed types: Attr -> Attribute, Op -> Operation
/// (carrying the constraint's op name), Type, TypeRange, Value and ValueRange.
/// An optional type expression on Attr/Value must be a `Type`; one on
/// ValueRange must be a `TypeRange`. A user `Constraint` is only accepted when
/// `allowNonCoreConstraints` is set. It must take exactly one input, and it
/// implies that input's type. If `inferredType` is already set, the implied
/// type must be refinable with it; otherwise an incompatibility error is
/// emitted.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType,
                                                 bool allowNonCoreConstraints) {
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    constraintType = ast::OperationType::get(ctx, cst->getName());
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // A value range is constrained by a *range* of types.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    if (!allowNonCoreConstraints) {
      return emitError(ref.referenceLoc,
                       "`Rewrite` arguments and results are only permitted to "
                       "use core constraints, such as `Attr`, `Op`, `Type`, "
                       "`TypeRange`, `Value`, `ValueRange`");
    }

    // A user constraint used in a constraint list is applied to the single
    // variable being constrained, so it must take exactly one input; that
    // input's type becomes the implied type.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
2190 
2191 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2192   ast::Type typeExprType = typeExpr->getType();
2193   if (typeExprType != typeTy) {
2194     return emitError(typeExpr->getLoc(),
2195                      "expected expression of `Type` in type constraint");
2196   }
2197   return success();
2198 }
2199 
2200 LogicalResult
2201 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2202   ast::Type typeExprType = typeExpr->getType();
2203   if (typeExprType != typeRangeTy) {
2204     return emitError(typeExpr->getLoc(),
2205                      "expected expression of `TypeRange` in type constraint");
2206   }
2207   return success();
2208 }
2209 
2210 //===----------------------------------------------------------------------===//
2211 // Exprs
2212 
/// Create a call expression that invokes the callable referenced by
/// `parentExpr` with `arguments`. The call's type is the callable's result
/// type.
///
/// Emits a diagnostic and fails in these cases:
///  * `parentExpr` does not reference a `Constraint` or `Rewrite`;
///  * a `Constraint` is invoked in a rewrite section, or a `Rewrite` outside
///    of one;
///  * the argument count differs from the callable's input count;
///  * an argument cannot be converted to the corresponding input type.
/// Arguments are passed to convertExpressionTo through the mutable
/// `arguments` view, presumably so that converted expressions replace the
/// originals in place.
FailureOr<ast::CallExpr *>
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
                       MutableArrayRef<ast::Expr *> arguments) {
  ast::Type parentType = parentExpr->getType();

  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
  if (!callableDecl) {
    return emitError(loc,
                     llvm::formatv("expected a reference to a callable "
                                   "`Constraint` or `Rewrite`, but got: `{0}`",
                                   parentType));
  }
  // Constraints may only be called while matching, rewrites only while
  // rewriting.
  if (parserContext == ParserContext::Rewrite) {
    if (isa<ast::UserConstraintDecl>(callableDecl))
      return emitError(
          loc, "unable to invoke `Constraint` within a rewrite section");
  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
    return emitError(loc, "unable to invoke `Rewrite` within a match section");
  }

  // Verify the arguments of the call.
  // Handle size mismatch.
  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
  if (callArgs.size() != arguments.size()) {
    return emitErrorAndNote(
        loc,
        llvm::formatv("invalid number of arguments for {0} call; expected "
                      "{1}, but got {2}",
                      callableDecl->getCallableType(), callArgs.size(),
                      arguments.size()),
        callableDecl->getLoc(),
        llvm::formatv("see the definition of {0} here",
                      callableDecl->getName()->getName()));
  }

  // Handle argument type mismatch. Any conversion diagnostic gets a note
  // pointing at the callable's definition.
  auto attachDiagFn = [&](ast::Diagnostic &diag) {
    diag.attachNote(llvm::formatv("see the definition of `{0}` here",
                                  callableDecl->getName()->getName()),
                    callableDecl->getLoc());
  };
  for (auto it : llvm::zip(callArgs, arguments)) {
    if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
                                   attachDiagFn)))
      return failure();
  }

  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
                               callableDecl->getResultType());
}
2263 
2264 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2265                                                         ast::Decl *decl) {
2266   // Check the type of decl being referenced.
2267   ast::Type declType;
2268   if (isa<ast::ConstraintDecl>(decl))
2269     declType = ast::ConstraintType::get(ctx);
2270   else if (isa<ast::UserRewriteDecl>(decl))
2271     declType = ast::RewriteType::get(ctx);
2272   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2273     declType = varDecl->getType();
2274   else
2275     return emitError(loc, "invalid reference to `" +
2276                               decl->getName()->getName() + "`");
2277 
2278   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2279 }
2280 
2281 FailureOr<ast::DeclRefExpr *>
2282 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2283                                  ArrayRef<ast::ConstraintRef> constraints) {
2284   FailureOr<ast::VariableDecl *> decl =
2285       defineVariableDecl(name, loc, type, constraints);
2286   if (failed(decl))
2287     return failure();
2288   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2289 }
2290 
2291 FailureOr<ast::MemberAccessExpr *>
2292 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2293                                SMRange loc) {
2294   // Validate the member name for the given parent expression.
2295   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2296   if (failed(memberType))
2297     return failure();
2298 
2299   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2300 }
2301 
2302 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2303                                                   StringRef name, SMRange loc) {
2304   ast::Type parentType = parentExpr->getType();
2305   if (parentType.isa<ast::OperationType>()) {
2306     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2307       return valueRangeTy;
2308   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2309     // Handle indexed results.
2310     unsigned index = 0;
2311     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2312         index < tupleType.size()) {
2313       return tupleType.getElementTypes()[index];
2314     }
2315 
2316     // Handle named results.
2317     auto elementNames = tupleType.getElementNames();
2318     const auto *it = llvm::find(elementNames, name);
2319     if (it != elementNames.end())
2320       return tupleType.getElementTypes()[it - elementNames.begin()];
2321   }
2322   return emitError(
2323       loc,
2324       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2325                     name, parentType));
2326 }
2327 
2328 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2329     SMRange loc, const ast::OpNameDecl *name,
2330     MutableArrayRef<ast::Expr *> operands,
2331     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2332     MutableArrayRef<ast::Expr *> results) {
2333   Optional<StringRef> opNameRef = name->getName();
2334 
2335   // Verify the inputs operands.
2336   if (failed(validateOperationOperands(loc, opNameRef, operands)))
2337     return failure();
2338 
2339   // Verify the attribute list.
2340   for (ast::NamedAttributeDecl *attr : attributes) {
2341     // Check for an attribute type, or a type awaiting resolution.
2342     ast::Type attrType = attr->getValue()->getType();
2343     if (!attrType.isa<ast::AttributeType>()) {
2344       return emitError(
2345           attr->getValue()->getLoc(),
2346           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2347     }
2348   }
2349 
2350   // Verify the result types.
2351   if (failed(validateOperationResults(loc, opNameRef, results)))
2352     return failure();
2353 
2354   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2355                                     attributes);
2356 }
2357 
2358 LogicalResult
2359 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2360                                   MutableArrayRef<ast::Expr *> operands) {
2361   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
2362                                             valueRangeTy);
2363 }
2364 
2365 LogicalResult
2366 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2367                                  MutableArrayRef<ast::Expr *> results) {
2368   return validateOperationOperandsOrResults(loc, name, results, typeTy,
2369                                             typeRangeTy);
2370 }
2371 
2372 LogicalResult Parser::validateOperationOperandsOrResults(
2373     SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2374     ast::Type singleTy, ast::Type rangeTy) {
2375   // All operation types accept a single range parameter.
2376   if (values.size() == 1) {
2377     if (failed(convertExpressionTo(values[0], rangeTy)))
2378       return failure();
2379     return success();
2380   }
2381 
2382   // Otherwise, accept the value groups as they have been defined and just
2383   // ensure they are one of the expected types.
2384   for (ast::Expr *&valueExpr : values) {
2385     ast::Type valueExprType = valueExpr->getType();
2386 
2387     // Check if this is one of the expected types.
2388     if (valueExprType == rangeTy || valueExprType == singleTy)
2389       continue;
2390 
2391     // If the operand is an Operation, allow converting to a Value or
2392     // ValueRange. This situations arises quite often with nested operation
2393     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2394     if (singleTy == valueTy) {
2395       if (valueExprType.isa<ast::OperationType>()) {
2396         valueExpr = convertOpToValue(valueExpr);
2397         continue;
2398       }
2399     }
2400 
2401     return emitError(
2402         valueExpr->getLoc(),
2403         llvm::formatv(
2404             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2405             singleTy, rangeTy, valueExprType));
2406   }
2407   return success();
2408 }
2409 
2410 FailureOr<ast::TupleExpr *>
2411 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2412                         ArrayRef<StringRef> elementNames) {
2413   for (const ast::Expr *element : elements) {
2414     ast::Type eleTy = element->getType();
2415     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2416       return emitError(
2417           element->getLoc(),
2418           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2419     }
2420   }
2421   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2422 }
2423 
2424 //===----------------------------------------------------------------------===//
2425 // Stmts
2426 
2427 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2428                                                     ast::Expr *rootOp) {
2429   // Check that root is an Operation.
2430   ast::Type rootType = rootOp->getType();
2431   if (!rootType.isa<ast::OperationType>())
2432     return emitError(rootOp->getLoc(), "expected `Op` expression");
2433 
2434   return ast::EraseStmt::create(ctx, loc, rootOp);
2435 }
2436 
2437 FailureOr<ast::ReplaceStmt *>
2438 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2439                           MutableArrayRef<ast::Expr *> replValues) {
2440   // Check that root is an Operation.
2441   ast::Type rootType = rootOp->getType();
2442   if (!rootType.isa<ast::OperationType>()) {
2443     return emitError(
2444         rootOp->getLoc(),
2445         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2446   }
2447 
2448   // If there are multiple replacement values, we implicitly convert any Op
2449   // expressions to the value form.
2450   bool shouldConvertOpToValues = replValues.size() > 1;
2451   for (ast::Expr *&replExpr : replValues) {
2452     ast::Type replType = replExpr->getType();
2453 
2454     // Check that replExpr is an Operation, Value, or ValueRange.
2455     if (replType.isa<ast::OperationType>()) {
2456       if (shouldConvertOpToValues)
2457         replExpr = convertOpToValue(replExpr);
2458       continue;
2459     }
2460 
2461     if (replType != valueTy && replType != valueRangeTy) {
2462       return emitError(replExpr->getLoc(),
2463                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2464                                      "expression, but got `{0}`",
2465                                      replType));
2466     }
2467   }
2468 
2469   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2470 }
2471 
2472 FailureOr<ast::RewriteStmt *>
2473 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2474                           ast::CompoundStmt *rewriteBody) {
2475   // Check that root is an Operation.
2476   ast::Type rootType = rootOp->getType();
2477   if (!rootType.isa<ast::OperationType>()) {
2478     return emitError(
2479         rootOp->getLoc(),
2480         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2481   }
2482 
2483   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // Parser
2488 //===----------------------------------------------------------------------===//
2489 
2490 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2491                                                  llvm::SourceMgr &sourceMgr) {
2492   Parser parser(ctx, sourceMgr);
2493   return parser.parseModule();
2494 }
2495