1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 class Parser {
32 public:
33   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
34       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
35         curToken(lexer.lexToken()), curDeclScope(nullptr),
36         valueTy(ast::ValueType::get(ctx)),
37         valueRangeTy(ast::ValueRangeType::get(ctx)),
38         typeTy(ast::TypeType::get(ctx)),
39         typeRangeTy(ast::TypeRangeType::get(ctx)) {}
40 
41   /// Try to parse a new module. Returns nullptr in the case of failure.
42   FailureOr<ast::Module *> parseModule();
43 
44 private:
45   /// The current context of the parser. It allows for the parser to know a bit
46   /// about the construct it is nested within during parsing. This is used
47   /// specifically to provide additional verification during parsing, e.g. to
48   /// prevent using rewrites within a match context, matcher constraints within
49   /// a rewrite section, etc.
50   enum class ParserContext {
51     /// The parser is in the global context.
52     Global,
53     /// The parser is currently within a Constraint, which disallows all types
54     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
55     Constraint,
56     /// The parser is currently within the matcher portion of a Pattern, which
57     /// is allows a terminal operation rewrite statement but no other rewrite
58     /// transformations.
59     PatternMatch,
60     /// The parser is currently within a Rewrite, which disallows calls to
61     /// constraints, requires operation expressions to have names, etc.
62     Rewrite,
63   };
64 
65   //===--------------------------------------------------------------------===//
66   // Parsing
67   //===--------------------------------------------------------------------===//
68 
69   /// Push a new decl scope onto the lexer.
70   ast::DeclScope *pushDeclScope() {
71     ast::DeclScope *newScope =
72         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
73     return (curDeclScope = newScope);
74   }
75   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
76 
77   /// Pop the last decl scope from the lexer.
78   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
79 
80   /// Parse the body of an AST module.
81   LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
82 
83   /// Try to convert the given expression to `type`. Returns failure and emits
84   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
85   /// invoked to attach notes to the emitted error diagnostic. On success,
86   /// `expr` is updated to the expression used to convert to `type`.
87   LogicalResult convertExpressionTo(
88       ast::Expr *&expr, ast::Type type,
89       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
90 
91   /// Given an operation expression, convert it to a Value or ValueRange
92   /// typed expression.
93   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
94 
95   //===--------------------------------------------------------------------===//
96   // Directives
97 
98   LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
99   LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
100 
101   //===--------------------------------------------------------------------===//
102   // Decls
103 
104   /// This structure contains the set of pattern metadata that may be parsed.
105   struct ParsedPatternMetadata {
106     Optional<uint16_t> benefit;
107     bool hasBoundedRecursion = false;
108   };
109 
110   FailureOr<ast::Decl *> parseTopLevelDecl();
111   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
112 
113   /// Parse an argument variable as part of the signature of a
114   /// UserConstraintDecl or UserRewriteDecl.
115   FailureOr<ast::VariableDecl *> parseArgumentDecl();
116 
117   /// Parse a result variable as part of the signature of a UserConstraintDecl
118   /// or UserRewriteDecl.
119   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
120 
121   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
122   /// defined in a non-global context.
123   FailureOr<ast::UserConstraintDecl *>
124   parseUserConstraintDecl(bool isInline = false);
125 
126   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
127   /// non-global context, such as within a Pattern/Constraint/etc.
128   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
129 
130   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
131   /// PDLL constructs.
132   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
133       const ast::Name &name, bool isInline,
134       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
135       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
136 
137   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
138   /// defined in a non-global context.
139   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
140 
141   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
142   /// non-global context, such as within a Pattern/Rewrite/etc.
143   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
144 
145   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
146   /// PDLL constructs.
147   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
148       const ast::Name &name, bool isInline,
149       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
150       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
151 
152   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
153   /// effectively the same syntax, and only differ on slight semantics (given
154   /// the different parsing contexts).
155   template <typename T, typename ParseUserPDLLDeclFnT>
156   FailureOr<T *> parseUserConstraintOrRewriteDecl(
157       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
158       StringRef anonymousNamePrefix, bool isInline);
159 
160   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
161   /// These decls have effectively the same syntax.
162   template <typename T>
163   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
164       const ast::Name &name, bool isInline,
165       ArrayRef<ast::VariableDecl *> arguments,
166       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
167 
168   /// Parse the functional signature (i.e. the arguments and results) of a
169   /// UserConstraintDecl or UserRewriteDecl.
170   LogicalResult parseUserConstraintOrRewriteSignature(
171       SmallVectorImpl<ast::VariableDecl *> &arguments,
172       SmallVectorImpl<ast::VariableDecl *> &results,
173       ast::DeclScope *&argumentScope, ast::Type &resultType);
174 
175   /// Validate the return (which if present is specified by bodyIt) of a
176   /// UserConstraintDecl or UserRewriteDecl.
177   LogicalResult validateUserConstraintOrRewriteReturn(
178       StringRef declType, ast::CompoundStmt *body,
179       ArrayRef<ast::Stmt *>::iterator bodyIt,
180       ArrayRef<ast::Stmt *>::iterator bodyE,
181       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
182 
183   FailureOr<ast::CompoundStmt *>
184   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
185                   bool expectTerminalSemicolon = true);
186   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
187   FailureOr<ast::Decl *> parsePatternDecl();
188   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
189 
190   /// Check to see if a decl has already been defined with the given name, if
191   /// one has emit and error and return failure. Returns success otherwise.
192   LogicalResult checkDefineNamedDecl(const ast::Name &name);
193 
194   /// Try to define a variable decl with the given components, returns the
195   /// variable on success.
196   FailureOr<ast::VariableDecl *>
197   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
198                      ast::Expr *initExpr,
199                      ArrayRef<ast::ConstraintRef> constraints);
200   FailureOr<ast::VariableDecl *>
201   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
202                      ArrayRef<ast::ConstraintRef> constraints);
203 
204   /// Parse the constraint reference list for a variable decl.
205   LogicalResult parseVariableDeclConstraintList(
206       SmallVectorImpl<ast::ConstraintRef> &constraints);
207 
208   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
209   FailureOr<ast::Expr *> parseTypeConstraintExpr();
210 
211   /// Try to parse a single reference to a constraint. `typeConstraint` is the
212   /// location of a previously parsed type constraint for the entity that will
213   /// be constrained by the parsed constraint. `existingConstraints` are any
214   /// existing constraints that have already been parsed for the same entity
215   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
216   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
217   FailureOr<ast::ConstraintRef>
218   parseConstraint(Optional<SMRange> &typeConstraint,
219                   ArrayRef<ast::ConstraintRef> existingConstraints,
220                   bool allowInlineTypeConstraints);
221 
222   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
223   /// argument or result variable. The constraints for these variables do not
224   /// allow inline type constraints, and only permit a single constraint.
225   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
226 
227   //===--------------------------------------------------------------------===//
228   // Exprs
229 
230   FailureOr<ast::Expr *> parseExpr();
231 
232   /// Identifier expressions.
233   FailureOr<ast::Expr *> parseAttributeExpr();
234   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
235   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
236   FailureOr<ast::Expr *> parseIdentifierExpr();
237   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
238   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
239   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
240   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
241   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
242   FailureOr<ast::Expr *> parseOperationExpr();
243   FailureOr<ast::Expr *> parseTupleExpr();
244   FailureOr<ast::Expr *> parseTypeExpr();
245   FailureOr<ast::Expr *> parseUnderscoreExpr();
246 
247   //===--------------------------------------------------------------------===//
248   // Stmts
249 
250   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
251   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
252   FailureOr<ast::EraseStmt *> parseEraseStmt();
253   FailureOr<ast::LetStmt *> parseLetStmt();
254   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
255   FailureOr<ast::ReturnStmt *> parseReturnStmt();
256   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
257 
258   //===--------------------------------------------------------------------===//
259   // Creation+Analysis
260   //===--------------------------------------------------------------------===//
261 
262   //===--------------------------------------------------------------------===//
263   // Decls
264 
265   /// Try to extract a callable from the given AST node. Returns nullptr on
266   /// failure.
267   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
268 
269   /// Try to create a pattern decl with the given components, returning the
270   /// Pattern on success.
271   FailureOr<ast::PatternDecl *>
272   createPatternDecl(SMRange loc, const ast::Name *name,
273                     const ParsedPatternMetadata &metadata,
274                     ast::CompoundStmt *body);
275 
276   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
277   /// of results, defined as part of the signature.
278   ast::Type
279   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
280 
281   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
282   template <typename T>
283   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
284       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
285       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
286       ast::CompoundStmt *body);
287 
288   /// Try to create a variable decl with the given components, returning the
289   /// Variable on success.
290   FailureOr<ast::VariableDecl *>
291   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
292                      ArrayRef<ast::ConstraintRef> constraints);
293 
294   /// Create a variable for an argument or result defined as part of the
295   /// signature of a UserConstraintDecl/UserRewriteDecl.
296   FailureOr<ast::VariableDecl *>
297   createArgOrResultVariableDecl(StringRef name, SMRange loc,
298                                 const ast::ConstraintRef &constraint);
299 
300   /// Validate the constraints used to constraint a variable decl.
301   /// `inferredType` is the type of the variable inferred by the constraints
302   /// within the list, and is updated to the most refined type as determined by
303   /// the constraints. Returns success if the constraint list is valid, failure
304   /// otherwise.
305   LogicalResult
306   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
307                               ast::Type &inferredType);
308   /// Validate a single reference to a constraint. `inferredType` contains the
309   /// currently inferred variabled type and is refined within the type defined
310   /// by the constraint. Returns success if the constraint is valid, failure
311   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
312   /// defined constraints) may be used with the variable.
313   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
314                                            ast::Type &inferredType,
315                                            bool allowNonCoreConstraints = true);
316   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
317   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
318 
319   //===--------------------------------------------------------------------===//
320   // Exprs
321 
322   FailureOr<ast::CallExpr *>
323   createCallExpr(SMRange loc, ast::Expr *parentExpr,
324                  MutableArrayRef<ast::Expr *> arguments);
325   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
326   FailureOr<ast::DeclRefExpr *>
327   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
328                            ArrayRef<ast::ConstraintRef> constraints);
329   FailureOr<ast::MemberAccessExpr *>
330   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
331 
332   /// Validate the member access `name` into the given parent expression. On
333   /// success, this also returns the type of the member accessed.
334   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
335                                             StringRef name, SMRange loc);
336   FailureOr<ast::OperationExpr *>
337   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
338                       MutableArrayRef<ast::Expr *> operands,
339                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
340                       MutableArrayRef<ast::Expr *> results);
341   LogicalResult
342   validateOperationOperands(SMRange loc, Optional<StringRef> name,
343                             MutableArrayRef<ast::Expr *> operands);
344   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
345                                          MutableArrayRef<ast::Expr *> results);
346   LogicalResult
347   validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
348                                      MutableArrayRef<ast::Expr *> values,
349                                      ast::Type singleTy, ast::Type rangeTy);
350   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
351                                               ArrayRef<ast::Expr *> elements,
352                                               ArrayRef<StringRef> elementNames);
353 
354   //===--------------------------------------------------------------------===//
355   // Stmts
356 
357   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
358   FailureOr<ast::ReplaceStmt *>
359   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
360                     MutableArrayRef<ast::Expr *> replValues);
361   FailureOr<ast::RewriteStmt *>
362   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
363                     ast::CompoundStmt *rewriteBody);
364 
365   //===--------------------------------------------------------------------===//
366   // Lexer Utilities
367   //===--------------------------------------------------------------------===//
368 
369   /// If the current token has the specified kind, consume it and return true.
370   /// If not, return false.
371   bool consumeIf(Token::Kind kind) {
372     if (curToken.isNot(kind))
373       return false;
374     consumeToken(kind);
375     return true;
376   }
377 
378   /// Advance the current lexer onto the next token.
379   void consumeToken() {
380     assert(curToken.isNot(Token::eof, Token::error) &&
381            "shouldn't advance past EOF or errors");
382     curToken = lexer.lexToken();
383   }
384 
385   /// Advance the current lexer onto the next token, asserting what the expected
386   /// current token is. This is preferred to the above method because it leads
387   /// to more self-documenting code with better checking.
388   void consumeToken(Token::Kind kind) {
389     assert(curToken.is(kind) && "consumed an unexpected token");
390     consumeToken();
391   }
392 
393   /// Reset the lexer to the location at the given position.
394   void resetToken(SMRange tokLoc) {
395     lexer.resetPointer(tokLoc.Start.getPointer());
396     curToken = lexer.lexToken();
397   }
398 
399   /// Consume the specified token if present and return success. On failure,
400   /// output a diagnostic and return failure.
401   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
402     if (curToken.getKind() != kind)
403       return emitError(curToken.getLoc(), msg);
404     consumeToken();
405     return success();
406   }
407   LogicalResult emitError(SMRange loc, const Twine &msg) {
408     lexer.emitError(loc, msg);
409     return failure();
410   }
411   LogicalResult emitError(const Twine &msg) {
412     return emitError(curToken.getLoc(), msg);
413   }
414   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
415                                  const Twine &note) {
416     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
417     return failure();
418   }
419 
420   //===--------------------------------------------------------------------===//
421   // Fields
422   //===--------------------------------------------------------------------===//
423 
424   /// The owning AST context.
425   ast::Context &ctx;
426 
427   /// The lexer of this parser.
428   Lexer lexer;
429 
430   /// The current token within the lexer.
431   Token curToken;
432 
433   /// The most recently defined decl scope.
434   ast::DeclScope *curDeclScope;
435   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
436 
437   /// The current context of the parser.
438   ParserContext parserContext = ParserContext::Global;
439 
440   /// Cached types to simplify verification and expression creation.
441   ast::Type valueTy, valueRangeTy;
442   ast::Type typeTy, typeRangeTy;
443 
444   /// A counter used when naming anonymous constraints and rewrites.
445   unsigned anonymousDeclNameCounter = 0;
446 };
447 } // namespace
448 
449 FailureOr<ast::Module *> Parser::parseModule() {
450   SMLoc moduleLoc = curToken.getStartLoc();
451   pushDeclScope();
452 
453   // Parse the top-level decls of the module.
454   SmallVector<ast::Decl *> decls;
455   if (failed(parseModuleBody(decls)))
456     return popDeclScope(), failure();
457 
458   popDeclScope();
459   return ast::Module::create(ctx, moduleLoc, decls);
460 }
461 
462 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
463   while (curToken.isNot(Token::eof)) {
464     if (curToken.is(Token::directive)) {
465       if (failed(parseDirective(decls)))
466         return failure();
467       continue;
468     }
469 
470     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
471     if (failed(decl))
472       return failure();
473     decls.push_back(*decl);
474   }
475   return success();
476 }
477 
478 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
479   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
480                                                  valueRangeTy);
481 }
482 
483 LogicalResult Parser::convertExpressionTo(
484     ast::Expr *&expr, ast::Type type,
485     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
486   ast::Type exprType = expr->getType();
487   if (exprType == type)
488     return success();
489 
490   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
491     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
492         expr->getLoc(), llvm::formatv("unable to convert expression of type "
493                                       "`{0}` to the expected type of "
494                                       "`{1}`",
495                                       exprType, type));
496     if (noteAttachFn)
497       noteAttachFn(*diag);
498     return diag;
499   };
500 
501   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
502     // Two operation types are compatible if they have the same name, or if the
503     // expected type is more general.
504     if (auto opType = type.dyn_cast<ast::OperationType>()) {
505       if (opType.getName())
506         return emitConvertError();
507       return success();
508     }
509 
510     // An operation can always convert to a ValueRange.
511     if (type == valueRangeTy) {
512       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
513                                                      valueRangeTy);
514       return success();
515     }
516 
517     // Allow conversion to a single value by constraining the result range.
518     if (type == valueTy) {
519       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
520                                                      valueTy);
521       return success();
522     }
523     return emitConvertError();
524   }
525 
526   // FIXME: Decide how to allow/support converting a single result to multiple,
527   // and multiple to a single result. For now, we just allow Single->Range,
528   // but this isn't something really supported in the PDL dialect. We should
529   // figure out some way to support both.
530   if ((exprType == valueTy || exprType == valueRangeTy) &&
531       (type == valueTy || type == valueRangeTy))
532     return success();
533   if ((exprType == typeTy || exprType == typeRangeTy) &&
534       (type == typeTy || type == typeRangeTy))
535     return success();
536 
537   // Handle tuple types.
538   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
539     auto tupleType = type.dyn_cast<ast::TupleType>();
540     if (!tupleType || tupleType.size() != exprTupleType.size())
541       return emitConvertError();
542 
543     // Build a new tuple expression using each of the elements of the current
544     // tuple.
545     SmallVector<ast::Expr *> newExprs;
546     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
547       newExprs.push_back(ast::MemberAccessExpr::create(
548           ctx, expr->getLoc(), expr, llvm::to_string(i),
549           exprTupleType.getElementTypes()[i]));
550 
551       auto diagFn = [&](ast::Diagnostic &diag) {
552         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
553                                       i, exprTupleType));
554         if (noteAttachFn)
555           noteAttachFn(diag);
556       };
557       if (failed(convertExpressionTo(newExprs.back(),
558                                      tupleType.getElementTypes()[i], diagFn)))
559         return failure();
560     }
561     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
562                                   tupleType.getElementNames());
563     return success();
564   }
565 
566   return emitConvertError();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Directives
571 
572 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
573   StringRef directive = curToken.getSpelling();
574   if (directive == "#include")
575     return parseInclude(decls);
576 
577   return emitError("unknown directive `" + directive + "`");
578 }
579 
580 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
581   SMRange loc = curToken.getLoc();
582   consumeToken(Token::directive);
583 
584   // Parse the file being included.
585   if (!curToken.isString())
586     return emitError(loc,
587                      "expected string file name after `include` directive");
588   SMRange fileLoc = curToken.getLoc();
589   std::string filenameStr = curToken.getStringValue();
590   StringRef filename = filenameStr;
591   consumeToken();
592 
593   // Check the type of include. If ending with `.pdll`, this is another pdl file
594   // to be parsed along with the current module.
595   if (filename.endswith(".pdll")) {
596     if (failed(lexer.pushInclude(filename)))
597       return emitError(fileLoc,
598                        "unable to open include file `" + filename + "`");
599 
600     // If we added the include successfully, parse it into the current module.
601     // Make sure to save the current token so that we can restore it when we
602     // finish parsing the nested file.
603     Token oldToken = curToken;
604     curToken = lexer.lexToken();
605     LogicalResult result = parseModuleBody(decls);
606     curToken = oldToken;
607     return result;
608   }
609 
610   return emitError(fileLoc, "expected include filename to end with `.pdll`");
611 }
612 
613 //===----------------------------------------------------------------------===//
614 // Decls
615 
616 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
617   FailureOr<ast::Decl *> decl;
618   switch (curToken.getKind()) {
619   case Token::kw_Constraint:
620     decl = parseUserConstraintDecl();
621     break;
622   case Token::kw_Pattern:
623     decl = parsePatternDecl();
624     break;
625   case Token::kw_Rewrite:
626     decl = parseUserRewriteDecl();
627     break;
628   default:
629     return emitError("expected top-level declaration, such as a `Pattern`");
630   }
631   if (failed(decl))
632     return failure();
633 
634   // If the decl has a name, add it to the current scope.
635   if (const ast::Name *name = (*decl)->getName()) {
636     if (failed(checkDefineNamedDecl(*name)))
637       return failure();
638     curDeclScope->add(*decl);
639   }
640   return decl;
641 }
642 
643 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
644   std::string attrNameStr;
645   if (curToken.isString())
646     attrNameStr = curToken.getStringValue();
647   else if (curToken.is(Token::identifier) || curToken.isKeyword())
648     attrNameStr = curToken.getSpelling().str();
649   else
650     return emitError("expected identifier or string attribute name");
651   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
652   consumeToken();
653 
654   // Check for a value of the attribute.
655   ast::Expr *attrValue = nullptr;
656   if (consumeIf(Token::equal)) {
657     FailureOr<ast::Expr *> attrExpr = parseExpr();
658     if (failed(attrExpr))
659       return failure();
660     attrValue = *attrExpr;
661   } else {
662     // If there isn't a concrete value, create an expression representing a
663     // UnitAttr.
664     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
665   }
666 
667   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
668 }
669 
670 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
671     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
672     bool expectTerminalSemicolon) {
673   consumeToken(Token::equal_arrow);
674 
675   // Parse the single statement of the lambda body.
676   SMLoc bodyStartLoc = curToken.getStartLoc();
677   pushDeclScope();
678   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
679   bool failedToParse =
680       failed(singleStatement) || failed(processStatementFn(*singleStatement));
681   popDeclScope();
682   if (failedToParse)
683     return failure();
684 
685   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
686   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
687 }
688 
689 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
690   // Ensure that the argument is named.
691   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
692     return emitError("expected identifier argument name");
693 
694   // Parse the argument similarly to a normal variable.
695   StringRef name = curToken.getSpelling();
696   SMRange nameLoc = curToken.getLoc();
697   consumeToken();
698 
699   if (failed(
700           parseToken(Token::colon, "expected `:` before argument constraint")))
701     return failure();
702 
703   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
704   if (failed(cst))
705     return failure();
706 
707   return createArgOrResultVariableDecl(name, nameLoc, *cst);
708 }
709 
710 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
711   // Check to see if this result is named.
712   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
713     // Check to see if this name actually refers to a Constraint.
714     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
715     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
716       // If yes, and this is a Rewrite, give a nice error message as non-Core
717       // constraints are not supported on Rewrite results.
718       if (parserContext == ParserContext::Rewrite) {
719         return emitError(
720             "`Rewrite` results are only permitted to use core constraints, "
721             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
722       }
723 
724       // Otherwise, parse this as an unnamed result variable.
725     } else {
726       // If it wasn't a constraint, parse the result similarly to a variable. If
727       // there is already an existing decl, we will emit an error when defining
728       // this variable later.
729       StringRef name = curToken.getSpelling();
730       SMRange nameLoc = curToken.getLoc();
731       consumeToken();
732 
733       if (failed(parseToken(Token::colon,
734                             "expected `:` before result constraint")))
735         return failure();
736 
737       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
738       if (failed(cst))
739         return failure();
740 
741       return createArgOrResultVariableDecl(name, nameLoc, *cst);
742     }
743   }
744 
745   // If it isn't named, we parse the constraint directly and create an unnamed
746   // result variable.
747   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
748   if (failed(cst))
749     return failure();
750 
751   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
752 }
753 
754 FailureOr<ast::UserConstraintDecl *>
755 Parser::parseUserConstraintDecl(bool isInline) {
756   // Constraints and rewrites have very similar formats, dispatch to a shared
757   // interface for parsing.
758   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
759       [&](auto &&...args) {
760         return this->parseUserPDLLConstraintDecl(args...);
761       },
762       ParserContext::Constraint, "constraint", isInline);
763 }
764 
765 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
766   FailureOr<ast::UserConstraintDecl *> decl =
767       parseUserConstraintDecl(/*isInline=*/true);
768   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
769     return failure();
770 
771   curDeclScope->add(*decl);
772   return decl;
773 }
774 
775 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
776     const ast::Name &name, bool isInline,
777     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
778     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
779   // Push the argument scope back onto the list, so that the body can
780   // reference arguments.
781   pushDeclScope(argumentScope);
782 
783   // Parse the body of the constraint. The body is either defined as a compound
784   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
785   ast::CompoundStmt *body;
786   if (curToken.is(Token::equal_arrow)) {
787     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
788         [&](ast::Stmt *&stmt) -> LogicalResult {
789           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
790           if (!stmtExpr) {
791             return emitError(stmt->getLoc(),
792                              "expected `Constraint` lambda body to contain a "
793                              "single expression");
794           }
795           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
796           return success();
797         },
798         /*expectTerminalSemicolon=*/!isInline);
799     if (failed(bodyResult))
800       return failure();
801     body = *bodyResult;
802   } else {
803     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
804     if (failed(bodyResult))
805       return failure();
806     body = *bodyResult;
807 
808     // Verify the structure of the body.
809     auto bodyIt = body->begin(), bodyE = body->end();
810     for (; bodyIt != bodyE; ++bodyIt)
811       if (isa<ast::ReturnStmt>(*bodyIt))
812         break;
813     if (failed(validateUserConstraintOrRewriteReturn(
814             "Constraint", body, bodyIt, bodyE, results, resultType)))
815       return failure();
816   }
817   popDeclScope();
818 
819   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
820       name, arguments, results, resultType, body);
821 }
822 
823 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
824   // Constraints and rewrites have very similar formats, dispatch to a shared
825   // interface for parsing.
826   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
827       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
828       ParserContext::Rewrite, "rewrite", isInline);
829 }
830 
831 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
832   FailureOr<ast::UserRewriteDecl *> decl =
833       parseUserRewriteDecl(/*isInline=*/true);
834   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
835     return failure();
836 
837   curDeclScope->add(*decl);
838   return decl;
839 }
840 
841 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
842     const ast::Name &name, bool isInline,
843     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
844     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
845   // Push the argument scope back onto the list, so that the body can
846   // reference arguments.
847   curDeclScope = argumentScope;
848   ast::CompoundStmt *body;
849   if (curToken.is(Token::equal_arrow)) {
850     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
851         [&](ast::Stmt *&statement) -> LogicalResult {
852           if (isa<ast::OpRewriteStmt>(statement))
853             return success();
854 
855           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
856           if (!statementExpr) {
857             return emitError(
858                 statement->getLoc(),
859                 "expected `Rewrite` lambda body to contain a single expression "
860                 "or an operation rewrite statement; such as `erase`, "
861                 "`replace`, or `rewrite`");
862           }
863           statement =
864               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
865           return success();
866         },
867         /*expectTerminalSemicolon=*/!isInline);
868     if (failed(bodyResult))
869       return failure();
870     body = *bodyResult;
871   } else {
872     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
873     if (failed(bodyResult))
874       return failure();
875     body = *bodyResult;
876   }
877   popDeclScope();
878 
879   // Verify the structure of the body.
880   auto bodyIt = body->begin(), bodyE = body->end();
881   for (; bodyIt != bodyE; ++bodyIt)
882     if (isa<ast::ReturnStmt>(*bodyIt))
883       break;
884   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
885                                                    bodyE, results, resultType)))
886     return failure();
887   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
888       name, arguments, results, resultType, body);
889 }
890 
891 template <typename T, typename ParseUserPDLLDeclFnT>
892 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
893     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
894     StringRef anonymousNamePrefix, bool isInline) {
895   SMRange loc = curToken.getLoc();
896   consumeToken();
897   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
898 
899   // Parse the name of the decl.
900   const ast::Name *name = nullptr;
901   if (curToken.isNot(Token::identifier)) {
902     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
903     // in C++, so being unnamed is fine.
904     if (!isInline)
905       return emitError("expected identifier name");
906 
907     // Create a unique anonymous name to use, as the name for this decl is not
908     // important.
909     std::string anonName =
910         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
911                       anonymousDeclNameCounter++)
912             .str();
913     name = &ast::Name::create(ctx, anonName, loc);
914   } else {
915     // If a name was provided, we can use it directly.
916     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
917     consumeToken(Token::identifier);
918   }
919 
920   // Parse the functional signature of the decl.
921   SmallVector<ast::VariableDecl *> arguments, results;
922   ast::DeclScope *argumentScope;
923   ast::Type resultType;
924   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
925                                                    argumentScope, resultType)))
926     return failure();
927 
928   // Check to see which type of constraint this is. If the constraint contains a
929   // compound body, this is a PDLL decl.
930   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
931     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
932                            resultType);
933 
934   // Otherwise, this is a native decl.
935   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
936                                                    results, resultType);
937 }
938 
939 template <typename T>
940 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
941     const ast::Name &name, bool isInline,
942     ArrayRef<ast::VariableDecl *> arguments,
943     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
944   // If followed by a string, the native code body has also been specified.
945   std::string codeStrStorage;
946   Optional<StringRef> optCodeStr;
947   if (curToken.isString()) {
948     codeStrStorage = curToken.getStringValue();
949     optCodeStr = codeStrStorage;
950     consumeToken();
951   } else if (isInline) {
952     return emitError(name.getLoc(),
953                      "external declarations must be declared in global scope");
954   }
955   if (failed(parseToken(Token::semicolon,
956                         "expected `;` after native declaration")))
957     return failure();
958   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
959 }
960 
961 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
962     SmallVectorImpl<ast::VariableDecl *> &arguments,
963     SmallVectorImpl<ast::VariableDecl *> &results,
964     ast::DeclScope *&argumentScope, ast::Type &resultType) {
965   // Parse the argument list of the decl.
966   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
967     return failure();
968 
969   argumentScope = pushDeclScope();
970   if (curToken.isNot(Token::r_paren)) {
971     do {
972       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
973       if (failed(argument))
974         return failure();
975       arguments.emplace_back(*argument);
976     } while (consumeIf(Token::comma));
977   }
978   popDeclScope();
979   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
980     return failure();
981 
982   // Parse the results of the decl.
983   pushDeclScope();
984   if (consumeIf(Token::arrow)) {
985     auto parseResultFn = [&]() -> LogicalResult {
986       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
987       if (failed(result))
988         return failure();
989       results.emplace_back(*result);
990       return success();
991     };
992 
993     // Check for a list of results.
994     if (consumeIf(Token::l_paren)) {
995       do {
996         if (failed(parseResultFn()))
997           return failure();
998       } while (consumeIf(Token::comma));
999       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1000         return failure();
1001 
1002       // Otherwise, there is only one result.
1003     } else if (failed(parseResultFn())) {
1004       return failure();
1005     }
1006   }
1007   popDeclScope();
1008 
1009   // Compute the result type of the decl.
1010   resultType = createUserConstraintRewriteResultType(results);
1011 
1012   // Verify that results are only named if there are more than one.
1013   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1014     return emitError(
1015         results.front()->getLoc(),
1016         "cannot create a single-element tuple with an element label");
1017   }
1018   return success();
1019 }
1020 
1021 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1022     StringRef declType, ast::CompoundStmt *body,
1023     ArrayRef<ast::Stmt *>::iterator bodyIt,
1024     ArrayRef<ast::Stmt *>::iterator bodyE,
1025     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1026   // Handle if a `return` was provided.
1027   if (bodyIt != bodyE) {
1028     // Emit an error if we have trailing statements after the return.
1029     if (std::next(bodyIt) != bodyE) {
1030       return emitError(
1031           (*std::next(bodyIt))->getLoc(),
1032           llvm::formatv("`return` terminated the `{0}` body, but found "
1033                         "trailing statements afterwards",
1034                         declType));
1035     }
1036 
1037     // Otherwise if a return wasn't provided, check that no results are
1038     // expected.
1039   } else if (!results.empty()) {
1040     return emitError(
1041         {body->getLoc().End, body->getLoc().End},
1042         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1043                       declType, resultType));
1044   }
1045   return success();
1046 }
1047 
1048 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1049   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1050     if (isa<ast::OpRewriteStmt>(statement))
1051       return success();
1052     return emitError(
1053         statement->getLoc(),
1054         "expected Pattern lambda body to contain a single operation "
1055         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1056   });
1057 }
1058 
1059 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1060   SMRange loc = curToken.getLoc();
1061   consumeToken(Token::kw_Pattern);
1062   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1063                                               ParserContext::PatternMatch);
1064 
1065   // Check for an optional identifier for the pattern name.
1066   const ast::Name *name = nullptr;
1067   if (curToken.is(Token::identifier)) {
1068     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1069     consumeToken(Token::identifier);
1070   }
1071 
1072   // Parse any pattern metadata.
1073   ParsedPatternMetadata metadata;
1074   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1075     return failure();
1076 
1077   // Parse the pattern body.
1078   ast::CompoundStmt *body;
1079 
1080   // Handle a lambda body.
1081   if (curToken.is(Token::equal_arrow)) {
1082     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1083     if (failed(bodyResult))
1084       return failure();
1085     body = *bodyResult;
1086   } else {
1087     if (curToken.isNot(Token::l_brace))
1088       return emitError("expected `{` or `=>` to start pattern body");
1089     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1090     if (failed(bodyResult))
1091       return failure();
1092     body = *bodyResult;
1093 
1094     // Verify the body of the pattern.
1095     auto bodyIt = body->begin(), bodyE = body->end();
1096     for (; bodyIt != bodyE; ++bodyIt) {
1097       if (isa<ast::ReturnStmt>(*bodyIt)) {
1098         return emitError((*bodyIt)->getLoc(),
1099                          "`return` statements are only permitted within a "
1100                          "`Constraint` or `Rewrite` body");
1101       }
1102       // Break when we've found the rewrite statement.
1103       if (isa<ast::OpRewriteStmt>(*bodyIt))
1104         break;
1105     }
1106     if (bodyIt == bodyE) {
1107       return emitError(loc,
1108                        "expected Pattern body to terminate with an operation "
1109                        "rewrite statement, such as `erase`");
1110     }
1111     if (std::next(bodyIt) != bodyE) {
1112       return emitError((*std::next(bodyIt))->getLoc(),
1113                        "Pattern body was terminated by an operation "
1114                        "rewrite statement, but found trailing statements");
1115     }
1116   }
1117 
1118   return createPatternDecl(loc, name, metadata, body);
1119 }
1120 
1121 LogicalResult
1122 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1123   Optional<SMRange> benefitLoc;
1124   Optional<SMRange> hasBoundedRecursionLoc;
1125 
1126   do {
1127     if (curToken.isNot(Token::identifier))
1128       return emitError("expected pattern metadata identifier");
1129     StringRef metadataStr = curToken.getSpelling();
1130     SMRange metadataLoc = curToken.getLoc();
1131     consumeToken(Token::identifier);
1132 
1133     // Parse the benefit metadata: benefit(<integer-value>)
1134     if (metadataStr == "benefit") {
1135       if (benefitLoc) {
1136         return emitErrorAndNote(metadataLoc,
1137                                 "pattern benefit has already been specified",
1138                                 *benefitLoc, "see previous definition here");
1139       }
1140       if (failed(parseToken(Token::l_paren,
1141                             "expected `(` before pattern benefit")))
1142         return failure();
1143 
1144       uint16_t benefitValue = 0;
1145       if (curToken.isNot(Token::integer))
1146         return emitError("expected integral pattern benefit");
1147       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1148         return emitError(
1149             "expected pattern benefit to fit within a 16-bit integer");
1150       consumeToken(Token::integer);
1151 
1152       metadata.benefit = benefitValue;
1153       benefitLoc = metadataLoc;
1154 
1155       if (failed(
1156               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1157         return failure();
1158       continue;
1159     }
1160 
1161     // Parse the bounded recursion metadata: recursion
1162     if (metadataStr == "recursion") {
1163       if (hasBoundedRecursionLoc) {
1164         return emitErrorAndNote(
1165             metadataLoc,
1166             "pattern recursion metadata has already been specified",
1167             *hasBoundedRecursionLoc, "see previous definition here");
1168       }
1169       metadata.hasBoundedRecursion = true;
1170       hasBoundedRecursionLoc = metadataLoc;
1171       continue;
1172     }
1173 
1174     return emitError(metadataLoc, "unknown pattern metadata");
1175   } while (consumeIf(Token::comma));
1176 
1177   return success();
1178 }
1179 
1180 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1181   consumeToken(Token::less);
1182 
1183   FailureOr<ast::Expr *> typeExpr = parseExpr();
1184   if (failed(typeExpr) ||
1185       failed(parseToken(Token::greater,
1186                         "expected `>` after variable type constraint")))
1187     return failure();
1188   return typeExpr;
1189 }
1190 
1191 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1192   assert(curDeclScope && "defining decl outside of a decl scope");
1193   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1194     return emitErrorAndNote(
1195         name.getLoc(), "`" + name.getName() + "` has already been defined",
1196         lastDecl->getName()->getLoc(), "see previous definition here");
1197   }
1198   return success();
1199 }
1200 
1201 FailureOr<ast::VariableDecl *>
1202 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1203                            ast::Expr *initExpr,
1204                            ArrayRef<ast::ConstraintRef> constraints) {
1205   assert(curDeclScope && "defining variable outside of decl scope");
1206   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1207 
1208   // If the name of the variable indicates a special variable, we don't add it
1209   // to the scope. This variable is local to the definition point.
1210   if (name.empty() || name == "_") {
1211     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1212                                      constraints);
1213   }
1214   if (failed(checkDefineNamedDecl(nameDecl)))
1215     return failure();
1216 
1217   auto *varDecl =
1218       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1219   curDeclScope->add(varDecl);
1220   return varDecl;
1221 }
1222 
1223 FailureOr<ast::VariableDecl *>
1224 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1225                            ArrayRef<ast::ConstraintRef> constraints) {
1226   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1227                             constraints);
1228 }
1229 
1230 LogicalResult Parser::parseVariableDeclConstraintList(
1231     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1232   Optional<SMRange> typeConstraint;
1233   auto parseSingleConstraint = [&] {
1234     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1235         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1236     if (failed(constraint))
1237       return failure();
1238     constraints.push_back(*constraint);
1239     return success();
1240   };
1241 
1242   // Check to see if this is a single constraint, or a list.
1243   if (!consumeIf(Token::l_square))
1244     return parseSingleConstraint();
1245 
1246   do {
1247     if (failed(parseSingleConstraint()))
1248       return failure();
1249   } while (consumeIf(Token::comma));
1250   return parseToken(Token::r_square, "expected `]` after constraint list");
1251 }
1252 
1253 FailureOr<ast::ConstraintRef>
1254 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1255                         ArrayRef<ast::ConstraintRef> existingConstraints,
1256                         bool allowInlineTypeConstraints) {
1257   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1258     if (!allowInlineTypeConstraints) {
1259       return emitError(
1260           curToken.getLoc(),
1261           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1262           "permitted on arguments or results");
1263     }
1264     if (typeConstraint)
1265       return emitErrorAndNote(
1266           curToken.getLoc(),
1267           "the type of this variable has already been constrained",
1268           *typeConstraint, "see previous constraint location here");
1269     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1270     if (failed(constraintExpr))
1271       return failure();
1272     typeExpr = *constraintExpr;
1273     typeConstraint = typeExpr->getLoc();
1274     return success();
1275   };
1276 
1277   SMRange loc = curToken.getLoc();
1278   switch (curToken.getKind()) {
1279   case Token::kw_Attr: {
1280     consumeToken(Token::kw_Attr);
1281 
1282     // Check for a type constraint.
1283     ast::Expr *typeExpr = nullptr;
1284     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1285       return failure();
1286     return ast::ConstraintRef(
1287         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1288   }
1289   case Token::kw_Op: {
1290     consumeToken(Token::kw_Op);
1291 
1292     // Parse an optional operation name. If the name isn't provided, this refers
1293     // to "any" operation.
1294     FailureOr<ast::OpNameDecl *> opName =
1295         parseWrappedOperationName(/*allowEmptyName=*/true);
1296     if (failed(opName))
1297       return failure();
1298 
1299     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1300                               loc);
1301   }
1302   case Token::kw_Type:
1303     consumeToken(Token::kw_Type);
1304     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1305   case Token::kw_TypeRange:
1306     consumeToken(Token::kw_TypeRange);
1307     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1308                               loc);
1309   case Token::kw_Value: {
1310     consumeToken(Token::kw_Value);
1311 
1312     // Check for a type constraint.
1313     ast::Expr *typeExpr = nullptr;
1314     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1315       return failure();
1316 
1317     return ast::ConstraintRef(
1318         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1319   }
1320   case Token::kw_ValueRange: {
1321     consumeToken(Token::kw_ValueRange);
1322 
1323     // Check for a type constraint.
1324     ast::Expr *typeExpr = nullptr;
1325     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1326       return failure();
1327 
1328     return ast::ConstraintRef(
1329         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1330   }
1331 
1332   case Token::kw_Constraint: {
1333     // Handle an inline constraint.
1334     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1335     if (failed(decl))
1336       return failure();
1337     return ast::ConstraintRef(*decl, loc);
1338   }
1339   case Token::identifier: {
1340     StringRef constraintName = curToken.getSpelling();
1341     consumeToken(Token::identifier);
1342 
1343     // Lookup the referenced constraint.
1344     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1345     if (!cstDecl) {
1346       return emitError(loc, "unknown reference to constraint `" +
1347                                 constraintName + "`");
1348     }
1349 
1350     // Handle a reference to a proper constraint.
1351     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1352       return ast::ConstraintRef(cst, loc);
1353 
1354     return emitErrorAndNote(
1355         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1356         "see the definition of `" + constraintName + "` here");
1357   }
1358   default:
1359     break;
1360   }
1361   return emitError(loc, "expected identifier constraint");
1362 }
1363 
1364 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1365   Optional<SMRange> typeConstraint;
1366   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1367                          /*allowInlineTypeConstraints=*/false);
1368 }
1369 
1370 //===----------------------------------------------------------------------===//
1371 // Exprs
1372 
1373 FailureOr<ast::Expr *> Parser::parseExpr() {
1374   if (curToken.is(Token::underscore))
1375     return parseUnderscoreExpr();
1376 
1377   // Parse the LHS expression.
1378   FailureOr<ast::Expr *> lhsExpr;
1379   switch (curToken.getKind()) {
1380   case Token::kw_attr:
1381     lhsExpr = parseAttributeExpr();
1382     break;
1383   case Token::kw_Constraint:
1384     lhsExpr = parseInlineConstraintLambdaExpr();
1385     break;
1386   case Token::identifier:
1387     lhsExpr = parseIdentifierExpr();
1388     break;
1389   case Token::kw_op:
1390     lhsExpr = parseOperationExpr();
1391     break;
1392   case Token::kw_Rewrite:
1393     lhsExpr = parseInlineRewriteLambdaExpr();
1394     break;
1395   case Token::kw_type:
1396     lhsExpr = parseTypeExpr();
1397     break;
1398   case Token::l_paren:
1399     lhsExpr = parseTupleExpr();
1400     break;
1401   default:
1402     return emitError("expected expression");
1403   }
1404   if (failed(lhsExpr))
1405     return failure();
1406 
1407   // Check for an operator expression.
1408   while (true) {
1409     switch (curToken.getKind()) {
1410     case Token::dot:
1411       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1412       break;
1413     case Token::l_paren:
1414       lhsExpr = parseCallExpr(*lhsExpr);
1415       break;
1416     default:
1417       return lhsExpr;
1418     }
1419     if (failed(lhsExpr))
1420       return failure();
1421   }
1422 }
1423 
1424 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1425   SMRange loc = curToken.getLoc();
1426   consumeToken(Token::kw_attr);
1427 
1428   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1429   // identifier.
1430   if (!consumeIf(Token::less)) {
1431     resetToken(loc);
1432     return parseIdentifierExpr();
1433   }
1434 
1435   if (!curToken.isString())
1436     return emitError("expected string literal containing MLIR attribute");
1437   std::string attrExpr = curToken.getStringValue();
1438   consumeToken();
1439 
1440   if (failed(
1441           parseToken(Token::greater, "expected `>` after attribute literal")))
1442     return failure();
1443   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1444 }
1445 
1446 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1447   SMRange loc = curToken.getLoc();
1448   consumeToken(Token::l_paren);
1449 
1450   // Parse the arguments of the call.
1451   SmallVector<ast::Expr *> arguments;
1452   if (curToken.isNot(Token::r_paren)) {
1453     do {
1454       FailureOr<ast::Expr *> argument = parseExpr();
1455       if (failed(argument))
1456         return failure();
1457       arguments.push_back(*argument);
1458     } while (consumeIf(Token::comma));
1459   }
1460   loc.End = curToken.getEndLoc();
1461   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1462     return failure();
1463 
1464   return createCallExpr(loc, parentExpr, arguments);
1465 }
1466 
1467 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1468   ast::Decl *decl = curDeclScope->lookup(name);
1469   if (!decl)
1470     return emitError(loc, "undefined reference to `" + name + "`");
1471 
1472   return createDeclRefExpr(loc, decl);
1473 }
1474 
1475 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1476   StringRef name = curToken.getSpelling();
1477   SMRange nameLoc = curToken.getLoc();
1478   consumeToken();
1479 
1480   // Check to see if this is a decl ref expression that defines a variable
1481   // inline.
1482   if (consumeIf(Token::colon)) {
1483     SmallVector<ast::ConstraintRef> constraints;
1484     if (failed(parseVariableDeclConstraintList(constraints)))
1485       return failure();
1486     ast::Type type;
1487     if (failed(validateVariableConstraints(constraints, type)))
1488       return failure();
1489     return createInlineVariableExpr(type, name, nameLoc, constraints);
1490   }
1491 
1492   return parseDeclRefExpr(name, nameLoc);
1493 }
1494 
1495 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1496   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1497   if (failed(decl))
1498     return failure();
1499 
1500   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1501                                   ast::ConstraintType::get(ctx));
1502 }
1503 
1504 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1505   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1506   if (failed(decl))
1507     return failure();
1508 
1509   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1510                                   ast::RewriteType::get(ctx));
1511 }
1512 
1513 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1514   SMRange loc = curToken.getLoc();
1515   consumeToken(Token::dot);
1516 
1517   // Parse the member name.
1518   Token memberNameTok = curToken;
1519   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1520       !memberNameTok.isKeyword())
1521     return emitError(loc, "expected identifier or numeric member name");
1522   StringRef memberName = memberNameTok.getSpelling();
1523   consumeToken();
1524 
1525   return createMemberAccessExpr(parentExpr, memberName, loc);
1526 }
1527 
/// Parse an operation name of the form `<dialect-namespace>.<name>`. If the
/// current token cannot start a name (it is neither an identifier nor a
/// keyword), an unnamed OpNameDecl is returned when `allowEmptyName` is set,
/// and an error is emitted otherwise.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Handle the case of no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  // The dialect namespace; keywords are accepted here as well as identifiers.
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Grow `name` in place instead of copying: first by one character for the
  // `.` consumed above, then by the spelling length of each following
  // identifier, keyword, or `.` token. `loc` is extended to the end of every
  // token consumed this way.
  // NOTE(review): this presumably relies on the tokens being adjacent in the
  // source buffer (no whitespace or comments in between) -- confirm against
  // the lexer.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
1556 
1557 FailureOr<ast::OpNameDecl *>
1558 Parser::parseWrappedOperationName(bool allowEmptyName) {
1559   if (!consumeIf(Token::less))
1560     return ast::OpNameDecl::create(ctx, SMRange());
1561 
1562   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1563   if (failed(opNameDecl))
1564     return failure();
1565 
1566   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1567     return failure();
1568   return opNameDecl;
1569 }
1570 
1571 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1572   SMRange loc = curToken.getLoc();
1573   consumeToken(Token::kw_op);
1574 
1575   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1576   // identifier.
1577   if (curToken.isNot(Token::less)) {
1578     resetToken(loc);
1579     return parseIdentifierExpr();
1580   }
1581 
1582   // Parse the operation name. The name may be elided, in which case the
1583   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1584   // `Operation*`). Operation names within a rewrite context must be named.
1585   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1586   FailureOr<ast::OpNameDecl *> opNameDecl =
1587       parseWrappedOperationName(allowEmptyName);
1588   if (failed(opNameDecl))
1589     return failure();
1590 
1591   // Functor used to create an implicit range variable, used for implicit "all"
1592   // operand or results variables.
1593   auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
1594     FailureOr<ast::VariableDecl *> rangeVar =
1595         defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
1596     assert(succeeded(rangeVar) && "expected range variable to be valid");
1597     return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
1598   };
1599 
1600   // Check for the optional list of operands.
1601   SmallVector<ast::Expr *> operands;
1602   if (!consumeIf(Token::l_paren)) {
1603     // If the operand list isn't specified and we are in a match context, define
1604     // an inplace unconstrained operand range corresponding to all of the
1605     // operands of the operation. This avoids treating zero operands the same
1606     // way as "unconstrained operands".
1607     if (parserContext != ParserContext::Rewrite) {
1608       operands.push_back(createImplicitRangeVar(
1609           ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
1610     }
1611   } else if (!consumeIf(Token::r_paren)) {
1612     // If the operand list was specified and non-empty, parse the operands.
1613     do {
1614       FailureOr<ast::Expr *> operand = parseExpr();
1615       if (failed(operand))
1616         return failure();
1617       operands.push_back(*operand);
1618     } while (consumeIf(Token::comma));
1619 
1620     if (failed(parseToken(Token::r_paren,
1621                           "expected `)` after operation operand list")))
1622       return failure();
1623   }
1624 
1625   // Check for the optional list of attributes.
1626   SmallVector<ast::NamedAttributeDecl *> attributes;
1627   if (consumeIf(Token::l_brace)) {
1628     do {
1629       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1630       if (failed(decl))
1631         return failure();
1632       attributes.emplace_back(*decl);
1633     } while (consumeIf(Token::comma));
1634 
1635     if (failed(parseToken(Token::r_brace,
1636                           "expected `}` after operation attribute list")))
1637       return failure();
1638   }
1639 
1640   // Check for the optional list of result types.
1641   SmallVector<ast::Expr *> resultTypes;
1642   if (consumeIf(Token::arrow)) {
1643     if (failed(parseToken(Token::l_paren,
1644                           "expected `(` before operation result type list")))
1645       return failure();
1646 
1647     // Handle the case of an empty result list.
1648     if (!consumeIf(Token::r_paren)) {
1649       do {
1650         FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1651         if (failed(resultTypeExpr))
1652           return failure();
1653         resultTypes.push_back(*resultTypeExpr);
1654       } while (consumeIf(Token::comma));
1655 
1656       if (failed(parseToken(Token::r_paren,
1657                             "expected `)` after operation result type list")))
1658         return failure();
1659     }
1660   } else if (parserContext != ParserContext::Rewrite) {
1661     // If the result list isn't specified and we are in a match context, define
1662     // an inplace unconstrained result range corresponding to all of the results
1663     // of the operation. This avoids treating zero results the same way as
1664     // "unconstrained results".
1665     resultTypes.push_back(createImplicitRangeVar(
1666         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
1667   }
1668 
1669   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1670                              resultTypes);
1671 }
1672 
1673 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
1674   SMRange loc = curToken.getLoc();
1675   consumeToken(Token::l_paren);
1676 
1677   DenseMap<StringRef, SMRange> usedNames;
1678   SmallVector<StringRef> elementNames;
1679   SmallVector<ast::Expr *> elements;
1680   if (curToken.isNot(Token::r_paren)) {
1681     do {
1682       // Check for the optional element name assignment before the value.
1683       StringRef elementName;
1684       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1685         Token elementNameTok = curToken;
1686         consumeToken();
1687 
1688         // The element name is only present if followed by an `=`.
1689         if (consumeIf(Token::equal)) {
1690           elementName = elementNameTok.getSpelling();
1691 
1692           // Check to see if this name is already used.
1693           auto elementNameIt =
1694               usedNames.try_emplace(elementName, elementNameTok.getLoc());
1695           if (!elementNameIt.second) {
1696             return emitErrorAndNote(
1697                 elementNameTok.getLoc(),
1698                 llvm::formatv("duplicate tuple element label `{0}`",
1699                               elementName),
1700                 elementNameIt.first->getSecond(),
1701                 "see previous label use here");
1702           }
1703         } else {
1704           // Otherwise, we treat this as part of an expression so reset the
1705           // lexer.
1706           resetToken(elementNameTok.getLoc());
1707         }
1708       }
1709       elementNames.push_back(elementName);
1710 
1711       // Parse the tuple element value.
1712       FailureOr<ast::Expr *> element = parseExpr();
1713       if (failed(element))
1714         return failure();
1715       elements.push_back(*element);
1716     } while (consumeIf(Token::comma));
1717   }
1718   loc.End = curToken.getEndLoc();
1719   if (failed(
1720           parseToken(Token::r_paren, "expected `)` after tuple element list")))
1721     return failure();
1722   return createTupleExpr(loc, elements, elementNames);
1723 }
1724 
1725 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1726   SMRange loc = curToken.getLoc();
1727   consumeToken(Token::kw_type);
1728 
1729   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1730   // identifier.
1731   if (!consumeIf(Token::less)) {
1732     resetToken(loc);
1733     return parseIdentifierExpr();
1734   }
1735 
1736   if (!curToken.isString())
1737     return emitError("expected string literal containing MLIR type");
1738   std::string attrExpr = curToken.getStringValue();
1739   consumeToken();
1740 
1741   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1742     return failure();
1743   return ast::TypeExpr::create(ctx, loc, attrExpr);
1744 }
1745 
1746 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1747   StringRef name = curToken.getSpelling();
1748   SMRange nameLoc = curToken.getLoc();
1749   consumeToken(Token::underscore);
1750 
1751   // Underscore expressions require a constraint list.
1752   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1753     return failure();
1754 
1755   // Parse the constraints for the expression.
1756   SmallVector<ast::ConstraintRef> constraints;
1757   if (failed(parseVariableDeclConstraintList(constraints)))
1758     return failure();
1759 
1760   ast::Type type;
1761   if (failed(validateVariableConstraints(constraints, type)))
1762     return failure();
1763   return createInlineVariableExpr(type, name, nameLoc, constraints);
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // Stmts
1768 
1769 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1770   FailureOr<ast::Stmt *> stmt;
1771   switch (curToken.getKind()) {
1772   case Token::kw_erase:
1773     stmt = parseEraseStmt();
1774     break;
1775   case Token::kw_let:
1776     stmt = parseLetStmt();
1777     break;
1778   case Token::kw_replace:
1779     stmt = parseReplaceStmt();
1780     break;
1781   case Token::kw_return:
1782     stmt = parseReturnStmt();
1783     break;
1784   case Token::kw_rewrite:
1785     stmt = parseRewriteStmt();
1786     break;
1787   default:
1788     stmt = parseExpr();
1789     break;
1790   }
1791   if (failed(stmt) ||
1792       (expectTerminalSemicolon &&
1793        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1794     return failure();
1795   return stmt;
1796 }
1797 
1798 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1799   SMLoc startLoc = curToken.getStartLoc();
1800   consumeToken(Token::l_brace);
1801 
1802   // Push a new block scope and parse any nested statements.
1803   pushDeclScope();
1804   SmallVector<ast::Stmt *> statements;
1805   while (curToken.isNot(Token::r_brace)) {
1806     FailureOr<ast::Stmt *> statement = parseStmt();
1807     if (failed(statement))
1808       return popDeclScope(), failure();
1809     statements.push_back(*statement);
1810   }
1811   popDeclScope();
1812 
1813   // Consume the end brace.
1814   SMRange location(startLoc, curToken.getEndLoc());
1815   consumeToken(Token::r_brace);
1816 
1817   return ast::CompoundStmt::create(ctx, location, statements);
1818 }
1819 
1820 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1821   if (parserContext == ParserContext::Constraint)
1822     return emitError("`erase` cannot be used within a Constraint");
1823   SMRange loc = curToken.getLoc();
1824   consumeToken(Token::kw_erase);
1825 
1826   // Parse the root operation expression.
1827   FailureOr<ast::Expr *> rootOp = parseExpr();
1828   if (failed(rootOp))
1829     return failure();
1830 
1831   return createEraseStmt(loc, *rootOp);
1832 }
1833 
1834 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1835   SMRange loc = curToken.getLoc();
1836   consumeToken(Token::kw_let);
1837 
1838   // Parse the name of the new variable.
1839   SMRange varLoc = curToken.getLoc();
1840   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1841     // `_` is a reserved variable name.
1842     if (curToken.is(Token::underscore)) {
1843       return emitError(varLoc,
1844                        "`_` may only be used to define \"inline\" variables");
1845     }
1846     return emitError(varLoc,
1847                      "expected identifier after `let` to name a new variable");
1848   }
1849   StringRef varName = curToken.getSpelling();
1850   consumeToken();
1851 
1852   // Parse the optional set of constraints.
1853   SmallVector<ast::ConstraintRef> constraints;
1854   if (consumeIf(Token::colon) &&
1855       failed(parseVariableDeclConstraintList(constraints)))
1856     return failure();
1857 
1858   // Parse the optional initializer expression.
1859   ast::Expr *initializer = nullptr;
1860   if (consumeIf(Token::equal)) {
1861     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1862     if (failed(initOrFailure))
1863       return failure();
1864     initializer = *initOrFailure;
1865 
1866     // Check that the constraints are compatible with having an initializer,
1867     // e.g. type constraints cannot be used with initializers.
1868     for (ast::ConstraintRef constraint : constraints) {
1869       LogicalResult result =
1870           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1871               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1872                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1873                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1874                   return this->emitError(
1875                       constraint.referenceLoc,
1876                       "type constraints are not permitted on variables with "
1877                       "initializers");
1878                 }
1879                 return success();
1880               })
1881               .Default(success());
1882       if (failed(result))
1883         return failure();
1884     }
1885   }
1886 
1887   FailureOr<ast::VariableDecl *> varDecl =
1888       createVariableDecl(varName, varLoc, initializer, constraints);
1889   if (failed(varDecl))
1890     return failure();
1891   return ast::LetStmt::create(ctx, loc, *varDecl);
1892 }
1893 
1894 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1895   if (parserContext == ParserContext::Constraint)
1896     return emitError("`replace` cannot be used within a Constraint");
1897   SMRange loc = curToken.getLoc();
1898   consumeToken(Token::kw_replace);
1899 
1900   // Parse the root operation expression.
1901   FailureOr<ast::Expr *> rootOp = parseExpr();
1902   if (failed(rootOp))
1903     return failure();
1904 
1905   if (failed(
1906           parseToken(Token::kw_with, "expected `with` after root operation")))
1907     return failure();
1908 
1909   // The replacement portion of this statement is within a rewrite context.
1910   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1911                                               ParserContext::Rewrite);
1912 
1913   // Parse the replacement values.
1914   SmallVector<ast::Expr *> replValues;
1915   if (consumeIf(Token::l_paren)) {
1916     if (consumeIf(Token::r_paren)) {
1917       return emitError(
1918           loc, "expected at least one replacement value, consider using "
1919                "`erase` if no replacement values are desired");
1920     }
1921 
1922     do {
1923       FailureOr<ast::Expr *> replExpr = parseExpr();
1924       if (failed(replExpr))
1925         return failure();
1926       replValues.emplace_back(*replExpr);
1927     } while (consumeIf(Token::comma));
1928 
1929     if (failed(parseToken(Token::r_paren,
1930                           "expected `)` after replacement values")))
1931       return failure();
1932   } else {
1933     FailureOr<ast::Expr *> replExpr = parseExpr();
1934     if (failed(replExpr))
1935       return failure();
1936     replValues.emplace_back(*replExpr);
1937   }
1938 
1939   return createReplaceStmt(loc, *rootOp, replValues);
1940 }
1941 
1942 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
1943   SMRange loc = curToken.getLoc();
1944   consumeToken(Token::kw_return);
1945 
1946   // Parse the result value.
1947   FailureOr<ast::Expr *> resultExpr = parseExpr();
1948   if (failed(resultExpr))
1949     return failure();
1950 
1951   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
1952 }
1953 
1954 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1955   if (parserContext == ParserContext::Constraint)
1956     return emitError("`rewrite` cannot be used within a Constraint");
1957   SMRange loc = curToken.getLoc();
1958   consumeToken(Token::kw_rewrite);
1959 
1960   // Parse the root operation.
1961   FailureOr<ast::Expr *> rootOp = parseExpr();
1962   if (failed(rootOp))
1963     return failure();
1964 
1965   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1966     return failure();
1967 
1968   if (curToken.isNot(Token::l_brace))
1969     return emitError("expected `{` to start rewrite body");
1970 
1971   // The rewrite body of this statement is within a rewrite context.
1972   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1973                                               ParserContext::Rewrite);
1974 
1975   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1976   if (failed(rewriteBody))
1977     return failure();
1978 
1979   // Verify the rewrite body.
1980   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
1981     if (isa<ast::ReturnStmt>(stmt)) {
1982       return emitError(stmt->getLoc(),
1983                        "`return` statements are only permitted within a "
1984                        "`Constraint` or `Rewrite` body");
1985     }
1986   }
1987 
1988   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1989 }
1990 
1991 //===----------------------------------------------------------------------===//
1992 // Creation+Analysis
1993 //===----------------------------------------------------------------------===//
1994 
1995 //===----------------------------------------------------------------------===//
1996 // Decls
1997 
1998 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
1999   // Unwrap reference expressions.
2000   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2001     node = init->getDecl();
2002   return dyn_cast<ast::CallableDecl>(node);
2003 }
2004 
2005 FailureOr<ast::PatternDecl *>
2006 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2007                           const ParsedPatternMetadata &metadata,
2008                           ast::CompoundStmt *body) {
2009   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2010                                   metadata.hasBoundedRecursion, body);
2011 }
2012 
2013 ast::Type Parser::createUserConstraintRewriteResultType(
2014     ArrayRef<ast::VariableDecl *> results) {
2015   // Single result decls use the type of the single result.
2016   if (results.size() == 1)
2017     return results[0]->getType();
2018 
2019   // Multiple results use a tuple type, with the types and names grabbed from
2020   // the result variable decls.
2021   auto resultTypes = llvm::map_range(
2022       results, [&](const auto *result) { return result->getType(); });
2023   auto resultNames = llvm::map_range(
2024       results, [&](const auto *result) { return result->getName().getName(); });
2025   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2026                              llvm::to_vector(resultNames));
2027 }
2028 
2029 template <typename T>
2030 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2031     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2032     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2033     ast::CompoundStmt *body) {
2034   if (!body->getChildren().empty()) {
2035     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2036       ast::Expr *resultExpr = retStmt->getResultExpr();
2037 
2038       // Process the result of the decl. If no explicit signature results
2039       // were provided, check for return type inference. Otherwise, check that
2040       // the return expression can be converted to the expected type.
2041       if (results.empty())
2042         resultType = resultExpr->getType();
2043       else if (failed(convertExpressionTo(resultExpr, resultType)))
2044         return failure();
2045       else
2046         retStmt->setResultExpr(resultExpr);
2047     }
2048   }
2049   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2050 }
2051 
2052 FailureOr<ast::VariableDecl *>
2053 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2054                            ArrayRef<ast::ConstraintRef> constraints) {
2055   // The type of the variable, which is expected to be inferred by either a
2056   // constraint or an initializer expression.
2057   ast::Type type;
2058   if (failed(validateVariableConstraints(constraints, type)))
2059     return failure();
2060 
2061   if (initializer) {
2062     // Update the variable type based on the initializer, or try to convert the
2063     // initializer to the existing type.
2064     if (!type)
2065       type = initializer->getType();
2066     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2067       type = mergedType;
2068     else if (failed(convertExpressionTo(initializer, type)))
2069       return failure();
2070 
2071     // Otherwise, if there is no initializer check that the type has already
2072     // been resolved from the constraint list.
2073   } else if (!type) {
2074     return emitErrorAndNote(
2075         loc, "unable to infer type for variable `" + name + "`", loc,
2076         "the type of a variable must be inferable from the constraint "
2077         "list or the initializer");
2078   }
2079 
2080   // Constraint types cannot be used when defining variables.
2081   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2082     return emitError(
2083         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2084   }
2085 
2086   // Try to define a variable with the given name.
2087   FailureOr<ast::VariableDecl *> varDecl =
2088       defineVariableDecl(name, loc, type, initializer, constraints);
2089   if (failed(varDecl))
2090     return failure();
2091 
2092   return *varDecl;
2093 }
2094 
2095 FailureOr<ast::VariableDecl *>
2096 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2097                                       const ast::ConstraintRef &constraint) {
2098   // Constraint arguments may apply more complex constraints via the arguments.
2099   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2100   ast::Type argType;
2101   if (failed(validateVariableConstraint(constraint, argType,
2102                                         allowNonCoreConstraints)))
2103     return failure();
2104   return defineVariableDecl(name, loc, argType, constraint);
2105 }
2106 
2107 LogicalResult
2108 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2109                                     ast::Type &inferredType) {
2110   for (const ast::ConstraintRef &ref : constraints)
2111     if (failed(validateVariableConstraint(ref, inferredType)))
2112       return failure();
2113   return success();
2114 }
2115 
2116 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2117                                                  ast::Type &inferredType,
2118                                                  bool allowNonCoreConstraints) {
2119   ast::Type constraintType;
2120   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2121     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2122       if (failed(validateTypeConstraintExpr(typeExpr)))
2123         return failure();
2124     }
2125     constraintType = ast::AttributeType::get(ctx);
2126   } else if (const auto *cst =
2127                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2128     constraintType = ast::OperationType::get(ctx, cst->getName());
2129   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2130     constraintType = typeTy;
2131   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2132     constraintType = typeRangeTy;
2133   } else if (const auto *cst =
2134                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2135     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2136       if (failed(validateTypeConstraintExpr(typeExpr)))
2137         return failure();
2138     }
2139     constraintType = valueTy;
2140   } else if (const auto *cst =
2141                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2142     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2143       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2144         return failure();
2145     }
2146     constraintType = valueRangeTy;
2147   } else if (const auto *cst =
2148                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2149     if (!allowNonCoreConstraints) {
2150       return emitError(ref.referenceLoc,
2151                        "`Rewrite` arguments and results are only permitted to "
2152                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2153                        "`TypeRange`, `Value`, `ValueRange`");
2154     }
2155 
2156     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2157     if (inputs.size() != 1) {
2158       return emitErrorAndNote(ref.referenceLoc,
2159                               "`Constraint`s applied via a variable constraint "
2160                               "list must take a single input, but got " +
2161                                   Twine(inputs.size()),
2162                               cst->getLoc(),
2163                               "see definition of constraint here");
2164     }
2165     constraintType = inputs.front()->getType();
2166   } else {
2167     llvm_unreachable("unknown constraint type");
2168   }
2169 
2170   // Check that the constraint type is compatible with the current inferred
2171   // type.
2172   if (!inferredType) {
2173     inferredType = constraintType;
2174   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2175     inferredType = mergedTy;
2176   } else {
2177     return emitError(ref.referenceLoc,
2178                      llvm::formatv("constraint type `{0}` is incompatible "
2179                                    "with the previously inferred type `{1}`",
2180                                    constraintType, inferredType));
2181   }
2182   return success();
2183 }
2184 
2185 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2186   ast::Type typeExprType = typeExpr->getType();
2187   if (typeExprType != typeTy) {
2188     return emitError(typeExpr->getLoc(),
2189                      "expected expression of `Type` in type constraint");
2190   }
2191   return success();
2192 }
2193 
2194 LogicalResult
2195 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2196   ast::Type typeExprType = typeExpr->getType();
2197   if (typeExprType != typeRangeTy) {
2198     return emitError(typeExpr->getLoc(),
2199                      "expected expression of `TypeRange` in type constraint");
2200   }
2201   return success();
2202 }
2203 
2204 //===----------------------------------------------------------------------===//
2205 // Exprs
2206 
2207 FailureOr<ast::CallExpr *>
2208 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2209                        MutableArrayRef<ast::Expr *> arguments) {
2210   ast::Type parentType = parentExpr->getType();
2211 
2212   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2213   if (!callableDecl) {
2214     return emitError(loc,
2215                      llvm::formatv("expected a reference to a callable "
2216                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2217                                    parentType));
2218   }
2219   if (parserContext == ParserContext::Rewrite) {
2220     if (isa<ast::UserConstraintDecl>(callableDecl))
2221       return emitError(
2222           loc, "unable to invoke `Constraint` within a rewrite section");
2223   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2224     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2225   }
2226 
2227   // Verify the arguments of the call.
2228   /// Handle size mismatch.
2229   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2230   if (callArgs.size() != arguments.size()) {
2231     return emitErrorAndNote(
2232         loc,
2233         llvm::formatv("invalid number of arguments for {0} call; expected "
2234                       "{1}, but got {2}",
2235                       callableDecl->getCallableType(), callArgs.size(),
2236                       arguments.size()),
2237         callableDecl->getLoc(),
2238         llvm::formatv("see the definition of {0} here",
2239                       callableDecl->getName()->getName()));
2240   }
2241 
2242   /// Handle argument type mismatch.
2243   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2244     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2245                                   callableDecl->getName()->getName()),
2246                     callableDecl->getLoc());
2247   };
2248   for (auto it : llvm::zip(callArgs, arguments)) {
2249     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2250                                    attachDiagFn)))
2251       return failure();
2252   }
2253 
2254   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2255                                callableDecl->getResultType());
2256 }
2257 
2258 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2259                                                         ast::Decl *decl) {
2260   // Check the type of decl being referenced.
2261   ast::Type declType;
2262   if (isa<ast::ConstraintDecl>(decl))
2263     declType = ast::ConstraintType::get(ctx);
2264   else if (isa<ast::UserRewriteDecl>(decl))
2265     declType = ast::RewriteType::get(ctx);
2266   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2267     declType = varDecl->getType();
2268   else
2269     return emitError(loc, "invalid reference to `" +
2270                               decl->getName()->getName() + "`");
2271 
2272   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2273 }
2274 
2275 FailureOr<ast::DeclRefExpr *>
2276 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2277                                  ArrayRef<ast::ConstraintRef> constraints) {
2278   FailureOr<ast::VariableDecl *> decl =
2279       defineVariableDecl(name, loc, type, constraints);
2280   if (failed(decl))
2281     return failure();
2282   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2283 }
2284 
2285 FailureOr<ast::MemberAccessExpr *>
2286 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2287                                SMRange loc) {
2288   // Validate the member name for the given parent expression.
2289   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2290   if (failed(memberType))
2291     return failure();
2292 
2293   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2294 }
2295 
2296 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2297                                                   StringRef name, SMRange loc) {
2298   ast::Type parentType = parentExpr->getType();
2299   if (parentType.isa<ast::OperationType>()) {
2300     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2301       return valueRangeTy;
2302   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2303     // Handle indexed results.
2304     unsigned index = 0;
2305     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2306         index < tupleType.size()) {
2307       return tupleType.getElementTypes()[index];
2308     }
2309 
2310     // Handle named results.
2311     auto elementNames = tupleType.getElementNames();
2312     const auto *it = llvm::find(elementNames, name);
2313     if (it != elementNames.end())
2314       return tupleType.getElementTypes()[it - elementNames.begin()];
2315   }
2316   return emitError(
2317       loc,
2318       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2319                     name, parentType));
2320 }
2321 
2322 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2323     SMRange loc, const ast::OpNameDecl *name,
2324     MutableArrayRef<ast::Expr *> operands,
2325     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2326     MutableArrayRef<ast::Expr *> results) {
2327   Optional<StringRef> opNameRef = name->getName();
2328 
2329   // Verify the inputs operands.
2330   if (failed(validateOperationOperands(loc, opNameRef, operands)))
2331     return failure();
2332 
2333   // Verify the attribute list.
2334   for (ast::NamedAttributeDecl *attr : attributes) {
2335     // Check for an attribute type, or a type awaiting resolution.
2336     ast::Type attrType = attr->getValue()->getType();
2337     if (!attrType.isa<ast::AttributeType>()) {
2338       return emitError(
2339           attr->getValue()->getLoc(),
2340           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2341     }
2342   }
2343 
2344   // Verify the result types.
2345   if (failed(validateOperationResults(loc, opNameRef, results)))
2346     return failure();
2347 
2348   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2349                                     attributes);
2350 }
2351 
2352 LogicalResult
2353 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2354                                   MutableArrayRef<ast::Expr *> operands) {
2355   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
2356                                             valueRangeTy);
2357 }
2358 
2359 LogicalResult
2360 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2361                                  MutableArrayRef<ast::Expr *> results) {
2362   return validateOperationOperandsOrResults(loc, name, results, typeTy,
2363                                             typeRangeTy);
2364 }
2365 
2366 LogicalResult Parser::validateOperationOperandsOrResults(
2367     SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2368     ast::Type singleTy, ast::Type rangeTy) {
2369   // All operation types accept a single range parameter.
2370   if (values.size() == 1) {
2371     if (failed(convertExpressionTo(values[0], rangeTy)))
2372       return failure();
2373     return success();
2374   }
2375 
2376   // Otherwise, accept the value groups as they have been defined and just
2377   // ensure they are one of the expected types.
2378   for (ast::Expr *&valueExpr : values) {
2379     ast::Type valueExprType = valueExpr->getType();
2380 
2381     // Check if this is one of the expected types.
2382     if (valueExprType == rangeTy || valueExprType == singleTy)
2383       continue;
2384 
2385     // If the operand is an Operation, allow converting to a Value or
2386     // ValueRange. This situations arises quite often with nested operation
2387     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2388     if (singleTy == valueTy) {
2389       if (valueExprType.isa<ast::OperationType>()) {
2390         valueExpr = convertOpToValue(valueExpr);
2391         continue;
2392       }
2393     }
2394 
2395     return emitError(
2396         valueExpr->getLoc(),
2397         llvm::formatv(
2398             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2399             singleTy, rangeTy, valueExprType));
2400   }
2401   return success();
2402 }
2403 
2404 FailureOr<ast::TupleExpr *>
2405 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2406                         ArrayRef<StringRef> elementNames) {
2407   for (const ast::Expr *element : elements) {
2408     ast::Type eleTy = element->getType();
2409     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2410       return emitError(
2411           element->getLoc(),
2412           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2413     }
2414   }
2415   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2416 }
2417 
2418 //===----------------------------------------------------------------------===//
2419 // Stmts
2420 
2421 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2422                                                     ast::Expr *rootOp) {
2423   // Check that root is an Operation.
2424   ast::Type rootType = rootOp->getType();
2425   if (!rootType.isa<ast::OperationType>())
2426     return emitError(rootOp->getLoc(), "expected `Op` expression");
2427 
2428   return ast::EraseStmt::create(ctx, loc, rootOp);
2429 }
2430 
2431 FailureOr<ast::ReplaceStmt *>
2432 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2433                           MutableArrayRef<ast::Expr *> replValues) {
2434   // Check that root is an Operation.
2435   ast::Type rootType = rootOp->getType();
2436   if (!rootType.isa<ast::OperationType>()) {
2437     return emitError(
2438         rootOp->getLoc(),
2439         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2440   }
2441 
2442   // If there are multiple replacement values, we implicitly convert any Op
2443   // expressions to the value form.
2444   bool shouldConvertOpToValues = replValues.size() > 1;
2445   for (ast::Expr *&replExpr : replValues) {
2446     ast::Type replType = replExpr->getType();
2447 
2448     // Check that replExpr is an Operation, Value, or ValueRange.
2449     if (replType.isa<ast::OperationType>()) {
2450       if (shouldConvertOpToValues)
2451         replExpr = convertOpToValue(replExpr);
2452       continue;
2453     }
2454 
2455     if (replType != valueTy && replType != valueRangeTy) {
2456       return emitError(replExpr->getLoc(),
2457                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2458                                      "expression, but got `{0}`",
2459                                      replType));
2460     }
2461   }
2462 
2463   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2464 }
2465 
2466 FailureOr<ast::RewriteStmt *>
2467 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2468                           ast::CompoundStmt *rewriteBody) {
2469   // Check that root is an Operation.
2470   ast::Type rootType = rootOp->getType();
2471   if (!rootType.isa<ast::OperationType>()) {
2472     return emitError(
2473         rootOp->getLoc(),
2474         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2475   }
2476 
2477   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2478 }
2479 
2480 //===----------------------------------------------------------------------===//
2481 // Parser
2482 //===----------------------------------------------------------------------===//
2483 
2484 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2485                                                  llvm::SourceMgr &sourceMgr) {
2486   Parser parser(ctx, sourceMgr);
2487   return parser.parseModule();
2488 }
2489