1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ManagedStatic.h"
28 #include "llvm/Support/SaveAndRestore.h"
29 #include "llvm/Support/ScopedPrinter.h"
30 #include "llvm/TableGen/Error.h"
31 #include "llvm/TableGen/Parser.h"
32 #include <string>
33 
34 using namespace mlir;
35 using namespace mlir::pdll;
36 
37 //===----------------------------------------------------------------------===//
38 // Parser
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 class Parser {
43 public:
44   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
45       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
46         curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
47         valueRangeTy(ast::ValueRangeType::get(ctx)),
48         typeTy(ast::TypeType::get(ctx)),
49         typeRangeTy(ast::TypeRangeType::get(ctx)),
50         attrTy(ast::AttributeType::get(ctx)) {}
51 
52   /// Try to parse a new module. Returns nullptr in the case of failure.
53   FailureOr<ast::Module *> parseModule();
54 
55 private:
56   /// The current context of the parser. It allows for the parser to know a bit
57   /// about the construct it is nested within during parsing. This is used
58   /// specifically to provide additional verification during parsing, e.g. to
59   /// prevent using rewrites within a match context, matcher constraints within
60   /// a rewrite section, etc.
61   enum class ParserContext {
62     /// The parser is in the global context.
63     Global,
64     /// The parser is currently within a Constraint, which disallows all types
65     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
66     Constraint,
67     /// The parser is currently within the matcher portion of a Pattern, which
68     /// is allows a terminal operation rewrite statement but no other rewrite
69     /// transformations.
70     PatternMatch,
71     /// The parser is currently within a Rewrite, which disallows calls to
72     /// constraints, requires operation expressions to have names, etc.
73     Rewrite,
74   };
75 
76   //===--------------------------------------------------------------------===//
77   // Parsing
78   //===--------------------------------------------------------------------===//
79 
80   /// Push a new decl scope onto the lexer.
81   ast::DeclScope *pushDeclScope() {
82     ast::DeclScope *newScope =
83         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
84     return (curDeclScope = newScope);
85   }
86   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
87 
88   /// Pop the last decl scope from the lexer.
89   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
90 
91   /// Parse the body of an AST module.
92   LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
93 
94   /// Try to convert the given expression to `type`. Returns failure and emits
95   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
96   /// invoked to attach notes to the emitted error diagnostic. On success,
97   /// `expr` is updated to the expression used to convert to `type`.
98   LogicalResult convertExpressionTo(
99       ast::Expr *&expr, ast::Type type,
100       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
101 
102   /// Given an operation expression, convert it to a Value or ValueRange
103   /// typed expression.
104   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
105 
106   /// Lookup ODS information for the given operation, returns nullptr if no
107   /// information is found.
108   const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
109     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
110   }
111 
112   //===--------------------------------------------------------------------===//
113   // Directives
114 
115   LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
116   LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
117   LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
118                                SmallVectorImpl<ast::Decl *> &decls);
119 
120   /// Process the records of a parsed tablegen include file.
121   void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
122                                SmallVectorImpl<ast::Decl *> &decls);
123 
124   /// Create a user defined native constraint for a constraint imported from
125   /// ODS.
126   template <typename ConstraintT>
127   ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
128                                                StringRef codeBlock, SMRange loc,
129                                                ast::Type type);
130   template <typename ConstraintT>
131   ast::Decl *
132   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
133                                     SMRange loc, ast::Type type);
134 
135   //===--------------------------------------------------------------------===//
136   // Decls
137 
138   /// This structure contains the set of pattern metadata that may be parsed.
139   struct ParsedPatternMetadata {
140     Optional<uint16_t> benefit;
141     bool hasBoundedRecursion = false;
142   };
143 
144   FailureOr<ast::Decl *> parseTopLevelDecl();
145   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
146 
147   /// Parse an argument variable as part of the signature of a
148   /// UserConstraintDecl or UserRewriteDecl.
149   FailureOr<ast::VariableDecl *> parseArgumentDecl();
150 
151   /// Parse a result variable as part of the signature of a UserConstraintDecl
152   /// or UserRewriteDecl.
153   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
154 
155   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
156   /// defined in a non-global context.
157   FailureOr<ast::UserConstraintDecl *>
158   parseUserConstraintDecl(bool isInline = false);
159 
160   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
161   /// non-global context, such as within a Pattern/Constraint/etc.
162   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
163 
164   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
165   /// PDLL constructs.
166   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
167       const ast::Name &name, bool isInline,
168       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
169       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
170 
171   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
172   /// defined in a non-global context.
173   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
174 
175   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
176   /// non-global context, such as within a Pattern/Rewrite/etc.
177   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
178 
179   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
180   /// PDLL constructs.
181   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
182       const ast::Name &name, bool isInline,
183       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
184       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
185 
186   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
187   /// effectively the same syntax, and only differ on slight semantics (given
188   /// the different parsing contexts).
189   template <typename T, typename ParseUserPDLLDeclFnT>
190   FailureOr<T *> parseUserConstraintOrRewriteDecl(
191       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
192       StringRef anonymousNamePrefix, bool isInline);
193 
194   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
195   /// These decls have effectively the same syntax.
196   template <typename T>
197   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
198       const ast::Name &name, bool isInline,
199       ArrayRef<ast::VariableDecl *> arguments,
200       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
201 
202   /// Parse the functional signature (i.e. the arguments and results) of a
203   /// UserConstraintDecl or UserRewriteDecl.
204   LogicalResult parseUserConstraintOrRewriteSignature(
205       SmallVectorImpl<ast::VariableDecl *> &arguments,
206       SmallVectorImpl<ast::VariableDecl *> &results,
207       ast::DeclScope *&argumentScope, ast::Type &resultType);
208 
209   /// Validate the return (which if present is specified by bodyIt) of a
210   /// UserConstraintDecl or UserRewriteDecl.
211   LogicalResult validateUserConstraintOrRewriteReturn(
212       StringRef declType, ast::CompoundStmt *body,
213       ArrayRef<ast::Stmt *>::iterator bodyIt,
214       ArrayRef<ast::Stmt *>::iterator bodyE,
215       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
216 
217   FailureOr<ast::CompoundStmt *>
218   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
219                   bool expectTerminalSemicolon = true);
220   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
221   FailureOr<ast::Decl *> parsePatternDecl();
222   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
223 
224   /// Check to see if a decl has already been defined with the given name, if
225   /// one has emit and error and return failure. Returns success otherwise.
226   LogicalResult checkDefineNamedDecl(const ast::Name &name);
227 
228   /// Try to define a variable decl with the given components, returns the
229   /// variable on success.
230   FailureOr<ast::VariableDecl *>
231   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
232                      ast::Expr *initExpr,
233                      ArrayRef<ast::ConstraintRef> constraints);
234   FailureOr<ast::VariableDecl *>
235   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
236                      ArrayRef<ast::ConstraintRef> constraints);
237 
238   /// Parse the constraint reference list for a variable decl.
239   LogicalResult parseVariableDeclConstraintList(
240       SmallVectorImpl<ast::ConstraintRef> &constraints);
241 
242   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
243   FailureOr<ast::Expr *> parseTypeConstraintExpr();
244 
245   /// Try to parse a single reference to a constraint. `typeConstraint` is the
246   /// location of a previously parsed type constraint for the entity that will
247   /// be constrained by the parsed constraint. `existingConstraints` are any
248   /// existing constraints that have already been parsed for the same entity
249   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
250   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
251   FailureOr<ast::ConstraintRef>
252   parseConstraint(Optional<SMRange> &typeConstraint,
253                   ArrayRef<ast::ConstraintRef> existingConstraints,
254                   bool allowInlineTypeConstraints);
255 
256   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
257   /// argument or result variable. The constraints for these variables do not
258   /// allow inline type constraints, and only permit a single constraint.
259   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
260 
261   //===--------------------------------------------------------------------===//
262   // Exprs
263 
264   FailureOr<ast::Expr *> parseExpr();
265 
266   /// Identifier expressions.
267   FailureOr<ast::Expr *> parseAttributeExpr();
268   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
269   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
270   FailureOr<ast::Expr *> parseIdentifierExpr();
271   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
272   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
273   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
274   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
275   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
276   FailureOr<ast::Expr *> parseOperationExpr();
277   FailureOr<ast::Expr *> parseTupleExpr();
278   FailureOr<ast::Expr *> parseTypeExpr();
279   FailureOr<ast::Expr *> parseUnderscoreExpr();
280 
281   //===--------------------------------------------------------------------===//
282   // Stmts
283 
284   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
285   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
286   FailureOr<ast::EraseStmt *> parseEraseStmt();
287   FailureOr<ast::LetStmt *> parseLetStmt();
288   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
289   FailureOr<ast::ReturnStmt *> parseReturnStmt();
290   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
291 
292   //===--------------------------------------------------------------------===//
293   // Creation+Analysis
294   //===--------------------------------------------------------------------===//
295 
296   //===--------------------------------------------------------------------===//
297   // Decls
298 
299   /// Try to extract a callable from the given AST node. Returns nullptr on
300   /// failure.
301   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
302 
303   /// Try to create a pattern decl with the given components, returning the
304   /// Pattern on success.
305   FailureOr<ast::PatternDecl *>
306   createPatternDecl(SMRange loc, const ast::Name *name,
307                     const ParsedPatternMetadata &metadata,
308                     ast::CompoundStmt *body);
309 
310   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
311   /// of results, defined as part of the signature.
312   ast::Type
313   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
314 
315   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
316   template <typename T>
317   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
318       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
319       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
320       ast::CompoundStmt *body);
321 
322   /// Try to create a variable decl with the given components, returning the
323   /// Variable on success.
324   FailureOr<ast::VariableDecl *>
325   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
326                      ArrayRef<ast::ConstraintRef> constraints);
327 
328   /// Create a variable for an argument or result defined as part of the
329   /// signature of a UserConstraintDecl/UserRewriteDecl.
330   FailureOr<ast::VariableDecl *>
331   createArgOrResultVariableDecl(StringRef name, SMRange loc,
332                                 const ast::ConstraintRef &constraint);
333 
334   /// Validate the constraints used to constraint a variable decl.
335   /// `inferredType` is the type of the variable inferred by the constraints
336   /// within the list, and is updated to the most refined type as determined by
337   /// the constraints. Returns success if the constraint list is valid, failure
338   /// otherwise.
339   LogicalResult
340   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
341                               ast::Type &inferredType);
342   /// Validate a single reference to a constraint. `inferredType` contains the
343   /// currently inferred variabled type and is refined within the type defined
344   /// by the constraint. Returns success if the constraint is valid, failure
345   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
346   /// defined constraints) may be used with the variable.
347   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
348                                            ast::Type &inferredType,
349                                            bool allowNonCoreConstraints = true);
350   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
351   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
352 
353   //===--------------------------------------------------------------------===//
354   // Exprs
355 
356   FailureOr<ast::CallExpr *>
357   createCallExpr(SMRange loc, ast::Expr *parentExpr,
358                  MutableArrayRef<ast::Expr *> arguments);
359   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
360   FailureOr<ast::DeclRefExpr *>
361   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
362                            ArrayRef<ast::ConstraintRef> constraints);
363   FailureOr<ast::MemberAccessExpr *>
364   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
365 
366   /// Validate the member access `name` into the given parent expression. On
367   /// success, this also returns the type of the member accessed.
368   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
369                                             StringRef name, SMRange loc);
370   FailureOr<ast::OperationExpr *>
371   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
372                       MutableArrayRef<ast::Expr *> operands,
373                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
374                       MutableArrayRef<ast::Expr *> results);
375   LogicalResult
376   validateOperationOperands(SMRange loc, Optional<StringRef> name,
377                             const ods::Operation *odsOp,
378                             MutableArrayRef<ast::Expr *> operands);
379   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
380                                          const ods::Operation *odsOp,
381                                          MutableArrayRef<ast::Expr *> results);
382   LogicalResult validateOperationOperandsOrResults(
383       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
384       Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
385       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
386       ast::Type rangeTy);
387   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
388                                               ArrayRef<ast::Expr *> elements,
389                                               ArrayRef<StringRef> elementNames);
390 
391   //===--------------------------------------------------------------------===//
392   // Stmts
393 
394   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
395   FailureOr<ast::ReplaceStmt *>
396   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
397                     MutableArrayRef<ast::Expr *> replValues);
398   FailureOr<ast::RewriteStmt *>
399   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
400                     ast::CompoundStmt *rewriteBody);
401 
402   //===--------------------------------------------------------------------===//
403   // Lexer Utilities
404   //===--------------------------------------------------------------------===//
405 
406   /// If the current token has the specified kind, consume it and return true.
407   /// If not, return false.
408   bool consumeIf(Token::Kind kind) {
409     if (curToken.isNot(kind))
410       return false;
411     consumeToken(kind);
412     return true;
413   }
414 
415   /// Advance the current lexer onto the next token.
416   void consumeToken() {
417     assert(curToken.isNot(Token::eof, Token::error) &&
418            "shouldn't advance past EOF or errors");
419     curToken = lexer.lexToken();
420   }
421 
422   /// Advance the current lexer onto the next token, asserting what the expected
423   /// current token is. This is preferred to the above method because it leads
424   /// to more self-documenting code with better checking.
425   void consumeToken(Token::Kind kind) {
426     assert(curToken.is(kind) && "consumed an unexpected token");
427     consumeToken();
428   }
429 
430   /// Reset the lexer to the location at the given position.
431   void resetToken(SMRange tokLoc) {
432     lexer.resetPointer(tokLoc.Start.getPointer());
433     curToken = lexer.lexToken();
434   }
435 
436   /// Consume the specified token if present and return success. On failure,
437   /// output a diagnostic and return failure.
438   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
439     if (curToken.getKind() != kind)
440       return emitError(curToken.getLoc(), msg);
441     consumeToken();
442     return success();
443   }
444   LogicalResult emitError(SMRange loc, const Twine &msg) {
445     lexer.emitError(loc, msg);
446     return failure();
447   }
448   LogicalResult emitError(const Twine &msg) {
449     return emitError(curToken.getLoc(), msg);
450   }
451   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
452                                  const Twine &note) {
453     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
454     return failure();
455   }
456 
457   //===--------------------------------------------------------------------===//
458   // Fields
459   //===--------------------------------------------------------------------===//
460 
461   /// The owning AST context.
462   ast::Context &ctx;
463 
464   /// The lexer of this parser.
465   Lexer lexer;
466 
467   /// The current token within the lexer.
468   Token curToken;
469 
470   /// The most recently defined decl scope.
471   ast::DeclScope *curDeclScope = nullptr;
472   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
473 
474   /// The current context of the parser.
475   ParserContext parserContext = ParserContext::Global;
476 
477   /// Cached types to simplify verification and expression creation.
478   ast::Type valueTy, valueRangeTy;
479   ast::Type typeTy, typeRangeTy;
480   ast::Type attrTy;
481 
482   /// A counter used when naming anonymous constraints and rewrites.
483   unsigned anonymousDeclNameCounter = 0;
484 };
485 } // namespace
486 
487 FailureOr<ast::Module *> Parser::parseModule() {
488   SMLoc moduleLoc = curToken.getStartLoc();
489   pushDeclScope();
490 
491   // Parse the top-level decls of the module.
492   SmallVector<ast::Decl *> decls;
493   if (failed(parseModuleBody(decls)))
494     return popDeclScope(), failure();
495 
496   popDeclScope();
497   return ast::Module::create(ctx, moduleLoc, decls);
498 }
499 
500 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
501   while (curToken.isNot(Token::eof)) {
502     if (curToken.is(Token::directive)) {
503       if (failed(parseDirective(decls)))
504         return failure();
505       continue;
506     }
507 
508     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
509     if (failed(decl))
510       return failure();
511     decls.push_back(*decl);
512   }
513   return success();
514 }
515 
516 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
517   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
518                                                  valueRangeTy);
519 }
520 
521 LogicalResult Parser::convertExpressionTo(
522     ast::Expr *&expr, ast::Type type,
523     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
524   ast::Type exprType = expr->getType();
525   if (exprType == type)
526     return success();
527 
528   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
529     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
530         expr->getLoc(), llvm::formatv("unable to convert expression of type "
531                                       "`{0}` to the expected type of "
532                                       "`{1}`",
533                                       exprType, type));
534     if (noteAttachFn)
535       noteAttachFn(*diag);
536     return diag;
537   };
538 
539   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
540     // Two operation types are compatible if they have the same name, or if the
541     // expected type is more general.
542     if (auto opType = type.dyn_cast<ast::OperationType>()) {
543       if (opType.getName())
544         return emitConvertError();
545       return success();
546     }
547 
548     // An operation can always convert to a ValueRange.
549     if (type == valueRangeTy) {
550       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
551                                                      valueRangeTy);
552       return success();
553     }
554 
555     // Allow conversion to a single value by constraining the result range.
556     if (type == valueTy) {
557       // If the operation is registered, we can verify if it can ever have a
558       // single result.
559       Optional<StringRef> opName = exprOpType.getName();
560       if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
561         if (odsOp->getResults().empty()) {
562           return emitConvertError()->attachNote(
563               llvm::formatv("see the definition of `{0}`, which was defined "
564                             "with zero results",
565                             odsOp->getName()),
566               odsOp->getLoc());
567         }
568 
569         unsigned numSingleResults = llvm::count_if(
570             odsOp->getResults(), [](const ods::OperandOrResult &result) {
571               return result.getVariableLengthKind() ==
572                      ods::VariableLengthKind::Single;
573             });
574         if (numSingleResults > 1) {
575           return emitConvertError()->attachNote(
576               llvm::formatv("see the definition of `{0}`, which was defined "
577                             "with at least {1} results",
578                             odsOp->getName(), numSingleResults),
579               odsOp->getLoc());
580         }
581       }
582 
583       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
584                                                      valueTy);
585       return success();
586     }
587     return emitConvertError();
588   }
589 
590   // FIXME: Decide how to allow/support converting a single result to multiple,
591   // and multiple to a single result. For now, we just allow Single->Range,
592   // but this isn't something really supported in the PDL dialect. We should
593   // figure out some way to support both.
594   if ((exprType == valueTy || exprType == valueRangeTy) &&
595       (type == valueTy || type == valueRangeTy))
596     return success();
597   if ((exprType == typeTy || exprType == typeRangeTy) &&
598       (type == typeTy || type == typeRangeTy))
599     return success();
600 
601   // Handle tuple types.
602   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
603     auto tupleType = type.dyn_cast<ast::TupleType>();
604     if (!tupleType || tupleType.size() != exprTupleType.size())
605       return emitConvertError();
606 
607     // Build a new tuple expression using each of the elements of the current
608     // tuple.
609     SmallVector<ast::Expr *> newExprs;
610     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
611       newExprs.push_back(ast::MemberAccessExpr::create(
612           ctx, expr->getLoc(), expr, llvm::to_string(i),
613           exprTupleType.getElementTypes()[i]));
614 
615       auto diagFn = [&](ast::Diagnostic &diag) {
616         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
617                                       i, exprTupleType));
618         if (noteAttachFn)
619           noteAttachFn(diag);
620       };
621       if (failed(convertExpressionTo(newExprs.back(),
622                                      tupleType.getElementTypes()[i], diagFn)))
623         return failure();
624     }
625     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
626                                   tupleType.getElementNames());
627     return success();
628   }
629 
630   return emitConvertError();
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // Directives
635 
636 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
637   StringRef directive = curToken.getSpelling();
638   if (directive == "#include")
639     return parseInclude(decls);
640 
641   return emitError("unknown directive `" + directive + "`");
642 }
643 
644 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
645   SMRange loc = curToken.getLoc();
646   consumeToken(Token::directive);
647 
648   // Parse the file being included.
649   if (!curToken.isString())
650     return emitError(loc,
651                      "expected string file name after `include` directive");
652   SMRange fileLoc = curToken.getLoc();
653   std::string filenameStr = curToken.getStringValue();
654   StringRef filename = filenameStr;
655   consumeToken();
656 
657   // Check the type of include. If ending with `.pdll`, this is another pdl file
658   // to be parsed along with the current module.
659   if (filename.endswith(".pdll")) {
660     if (failed(lexer.pushInclude(filename)))
661       return emitError(fileLoc,
662                        "unable to open include file `" + filename + "`");
663 
664     // If we added the include successfully, parse it into the current module.
665     // Make sure to save the current token so that we can restore it when we
666     // finish parsing the nested file.
667     Token oldToken = curToken;
668     curToken = lexer.lexToken();
669     LogicalResult result = parseModuleBody(decls);
670     curToken = oldToken;
671     return result;
672   }
673 
674   // Otherwise, this must be a `.td` include.
675   if (filename.endswith(".td"))
676     return parseTdInclude(filename, fileLoc, decls);
677 
678   return emitError(fileLoc,
679                    "expected include filename to end with `.pdll` or `.td`");
680 }
681 
682 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
683                                      SmallVectorImpl<ast::Decl *> &decls) {
684   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
685 
686   // This class provides a context argument for the llvm::SourceMgr diagnostic
687   // handler.
688   struct DiagHandlerContext {
689     Parser &parser;
690     StringRef filename;
691     llvm::SMRange loc;
692   } handlerContext{*this, filename, fileLoc};
693 
694   // Set the diagnostic handler for the tablegen source manager.
695   llvm::SrcMgr.setDiagHandler(
696       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
697         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
698         (void)ctx->parser.emitError(
699             ctx->loc,
700             llvm::formatv("error while processing include file `{0}`: {1}",
701                           ctx->filename, diag.getMessage()));
702       },
703       &handlerContext);
704 
705   // Use the source manager to open the file, but don't yet add it.
706   std::string includedFile;
707   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
708       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
709   if (!includeBuffer)
710     return emitError(fileLoc, "unable to open include file `" + filename + "`");
711 
712   auto processFn = [&](llvm::RecordKeeper &records) {
713     processTdIncludeRecords(records, decls);
714 
715     // After we are done processing, move all of the tablegen source buffers to
716     // the main parser source mgr. This allows for directly using source
717     // locations from the .td files without needing to remap them.
718     parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
719     return false;
720   };
721   if (llvm::TableGenParseFile(std::move(*includeBuffer),
722                               parserSrcMgr.getIncludeDirs(), processFn))
723     return failure();
724 
725   return success();
726 }
727 
/// Import the records of a parsed tablegen file into the parser.
///
/// `Op` records are registered in the ODS context together with their
/// attributes, operands, and results. Named `Attr`, `Type`, and
/// `OpInterface`/`AttrInterface`/`TypeInterface` records whose name is not
/// already visible in the current decl scope are turned into native PDLL
/// constraint decls. Those decls are added to the current scope and appended
/// to `decls`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value. Optionality is checked before
  // variadicity, so an optional value always maps to `Optional`.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Tablegen records only carry a single location; widen it into a
  // one-character range so it can be used as a PDLL decl location.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) =
        odsContext.insertOperation(op.getOperationName(), op.getSummary(),
                                   op.getDescription(), op.getLoc().front());

    // Ignore operations that have already been added. This way an operation's
    // attributes, operands, and results are only populated once.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  /// Attr constraints.
  // Anonymous defs are skipped, as are names already visible in the current
  // scope (including ones created earlier in this function).
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              tblgen::AttrConstraint(def),
              convertLocToRange(def->getLoc().front()), attrTy));
    }
  }
  /// Type constraints.
  // Same filtering as for attribute constraints above.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              tblgen::TypeConstraint(def),
              convertLocToRange(def->getLoc().front()), typeTy));
    }
  }
  /// Interfaces.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    // Skip anonymous defs, names already visible in scope, and records that
    // derive from `DeclareInterfaceMethods`.
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    // The generated native constraint is an `llvm::isa` check of `self`
    // against the interface's namespace-qualified C++ class.
    StringRef className = def->getValueAsString("cppClassName");
    StringRef cppNamespace = def->getValueAsString("cppNamespace");
    std::string codeBlock =
        llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
            .str();

    // Pick the constraint kind from the interface flavor. Interfaces that are
    // none of these three produce no decl.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy));
    }
  }
}
828 
829 template <typename ConstraintT>
830 ast::Decl *
831 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
832                                           SMRange loc, ast::Type type) {
833   // Build the single input parameter.
834   ast::DeclScope *argScope = pushDeclScope();
835   auto *paramVar = ast::VariableDecl::create(
836       ctx, ast::Name::create(ctx, "self", loc), type,
837       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
838   argScope->add(paramVar);
839   popDeclScope();
840 
841   // Build the native constraint.
842   auto *constraintDecl = ast::UserConstraintDecl::createNative(
843       ctx, ast::Name::create(ctx, name, loc), paramVar,
844       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
845   curDeclScope->add(constraintDecl);
846   return constraintDecl;
847 }
848 
849 template <typename ConstraintT>
850 ast::Decl *
851 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
852                                           SMRange loc, ast::Type type) {
853   // Format the condition template.
854   tblgen::FmtContext fmtContext;
855   fmtContext.withSelf("self");
856   std::string codeBlock =
857       tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);
858 
859   return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
860                                                         codeBlock, loc, type);
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // Decls
865 
866 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
867   FailureOr<ast::Decl *> decl;
868   switch (curToken.getKind()) {
869   case Token::kw_Constraint:
870     decl = parseUserConstraintDecl();
871     break;
872   case Token::kw_Pattern:
873     decl = parsePatternDecl();
874     break;
875   case Token::kw_Rewrite:
876     decl = parseUserRewriteDecl();
877     break;
878   default:
879     return emitError("expected top-level declaration, such as a `Pattern`");
880   }
881   if (failed(decl))
882     return failure();
883 
884   // If the decl has a name, add it to the current scope.
885   if (const ast::Name *name = (*decl)->getName()) {
886     if (failed(checkDefineNamedDecl(*name)))
887       return failure();
888     curDeclScope->add(*decl);
889   }
890   return decl;
891 }
892 
893 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
894   std::string attrNameStr;
895   if (curToken.isString())
896     attrNameStr = curToken.getStringValue();
897   else if (curToken.is(Token::identifier) || curToken.isKeyword())
898     attrNameStr = curToken.getSpelling().str();
899   else
900     return emitError("expected identifier or string attribute name");
901   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
902   consumeToken();
903 
904   // Check for a value of the attribute.
905   ast::Expr *attrValue = nullptr;
906   if (consumeIf(Token::equal)) {
907     FailureOr<ast::Expr *> attrExpr = parseExpr();
908     if (failed(attrExpr))
909       return failure();
910     attrValue = *attrExpr;
911   } else {
912     // If there isn't a concrete value, create an expression representing a
913     // UnitAttr.
914     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
915   }
916 
917   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
918 }
919 
920 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
921     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
922     bool expectTerminalSemicolon) {
923   consumeToken(Token::equal_arrow);
924 
925   // Parse the single statement of the lambda body.
926   SMLoc bodyStartLoc = curToken.getStartLoc();
927   pushDeclScope();
928   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
929   bool failedToParse =
930       failed(singleStatement) || failed(processStatementFn(*singleStatement));
931   popDeclScope();
932   if (failedToParse)
933     return failure();
934 
935   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
936   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
937 }
938 
939 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
940   // Ensure that the argument is named.
941   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
942     return emitError("expected identifier argument name");
943 
944   // Parse the argument similarly to a normal variable.
945   StringRef name = curToken.getSpelling();
946   SMRange nameLoc = curToken.getLoc();
947   consumeToken();
948 
949   if (failed(
950           parseToken(Token::colon, "expected `:` before argument constraint")))
951     return failure();
952 
953   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
954   if (failed(cst))
955     return failure();
956 
957   return createArgOrResultVariableDecl(name, nameLoc, *cst);
958 }
959 
960 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
961   // Check to see if this result is named.
962   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
963     // Check to see if this name actually refers to a Constraint.
964     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
965     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
966       // If yes, and this is a Rewrite, give a nice error message as non-Core
967       // constraints are not supported on Rewrite results.
968       if (parserContext == ParserContext::Rewrite) {
969         return emitError(
970             "`Rewrite` results are only permitted to use core constraints, "
971             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
972       }
973 
974       // Otherwise, parse this as an unnamed result variable.
975     } else {
976       // If it wasn't a constraint, parse the result similarly to a variable. If
977       // there is already an existing decl, we will emit an error when defining
978       // this variable later.
979       StringRef name = curToken.getSpelling();
980       SMRange nameLoc = curToken.getLoc();
981       consumeToken();
982 
983       if (failed(parseToken(Token::colon,
984                             "expected `:` before result constraint")))
985         return failure();
986 
987       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
988       if (failed(cst))
989         return failure();
990 
991       return createArgOrResultVariableDecl(name, nameLoc, *cst);
992     }
993   }
994 
995   // If it isn't named, we parse the constraint directly and create an unnamed
996   // result variable.
997   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
998   if (failed(cst))
999     return failure();
1000 
1001   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1002 }
1003 
1004 FailureOr<ast::UserConstraintDecl *>
1005 Parser::parseUserConstraintDecl(bool isInline) {
1006   // Constraints and rewrites have very similar formats, dispatch to a shared
1007   // interface for parsing.
1008   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1009       [&](auto &&...args) {
1010         return this->parseUserPDLLConstraintDecl(args...);
1011       },
1012       ParserContext::Constraint, "constraint", isInline);
1013 }
1014 
1015 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1016   FailureOr<ast::UserConstraintDecl *> decl =
1017       parseUserConstraintDecl(/*isInline=*/true);
1018   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1019     return failure();
1020 
1021   curDeclScope->add(*decl);
1022   return decl;
1023 }
1024 
1025 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1026     const ast::Name &name, bool isInline,
1027     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1028     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1029   // Push the argument scope back onto the list, so that the body can
1030   // reference arguments.
1031   pushDeclScope(argumentScope);
1032 
1033   // Parse the body of the constraint. The body is either defined as a compound
1034   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1035   ast::CompoundStmt *body;
1036   if (curToken.is(Token::equal_arrow)) {
1037     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1038         [&](ast::Stmt *&stmt) -> LogicalResult {
1039           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1040           if (!stmtExpr) {
1041             return emitError(stmt->getLoc(),
1042                              "expected `Constraint` lambda body to contain a "
1043                              "single expression");
1044           }
1045           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1046           return success();
1047         },
1048         /*expectTerminalSemicolon=*/!isInline);
1049     if (failed(bodyResult))
1050       return failure();
1051     body = *bodyResult;
1052   } else {
1053     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1054     if (failed(bodyResult))
1055       return failure();
1056     body = *bodyResult;
1057 
1058     // Verify the structure of the body.
1059     auto bodyIt = body->begin(), bodyE = body->end();
1060     for (; bodyIt != bodyE; ++bodyIt)
1061       if (isa<ast::ReturnStmt>(*bodyIt))
1062         break;
1063     if (failed(validateUserConstraintOrRewriteReturn(
1064             "Constraint", body, bodyIt, bodyE, results, resultType)))
1065       return failure();
1066   }
1067   popDeclScope();
1068 
1069   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1070       name, arguments, results, resultType, body);
1071 }
1072 
1073 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1074   // Constraints and rewrites have very similar formats, dispatch to a shared
1075   // interface for parsing.
1076   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1077       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1078       ParserContext::Rewrite, "rewrite", isInline);
1079 }
1080 
1081 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1082   FailureOr<ast::UserRewriteDecl *> decl =
1083       parseUserRewriteDecl(/*isInline=*/true);
1084   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1085     return failure();
1086 
1087   curDeclScope->add(*decl);
1088   return decl;
1089 }
1090 
1091 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1092     const ast::Name &name, bool isInline,
1093     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1094     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1095   // Push the argument scope back onto the list, so that the body can
1096   // reference arguments.
1097   curDeclScope = argumentScope;
1098   ast::CompoundStmt *body;
1099   if (curToken.is(Token::equal_arrow)) {
1100     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1101         [&](ast::Stmt *&statement) -> LogicalResult {
1102           if (isa<ast::OpRewriteStmt>(statement))
1103             return success();
1104 
1105           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1106           if (!statementExpr) {
1107             return emitError(
1108                 statement->getLoc(),
1109                 "expected `Rewrite` lambda body to contain a single expression "
1110                 "or an operation rewrite statement; such as `erase`, "
1111                 "`replace`, or `rewrite`");
1112           }
1113           statement =
1114               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1115           return success();
1116         },
1117         /*expectTerminalSemicolon=*/!isInline);
1118     if (failed(bodyResult))
1119       return failure();
1120     body = *bodyResult;
1121   } else {
1122     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1123     if (failed(bodyResult))
1124       return failure();
1125     body = *bodyResult;
1126   }
1127   popDeclScope();
1128 
1129   // Verify the structure of the body.
1130   auto bodyIt = body->begin(), bodyE = body->end();
1131   for (; bodyIt != bodyE; ++bodyIt)
1132     if (isa<ast::ReturnStmt>(*bodyIt))
1133       break;
1134   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1135                                                    bodyE, results, resultType)))
1136     return failure();
1137   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1138       name, arguments, results, resultType, body);
1139 }
1140 
1141 template <typename T, typename ParseUserPDLLDeclFnT>
1142 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1143     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1144     StringRef anonymousNamePrefix, bool isInline) {
1145   SMRange loc = curToken.getLoc();
1146   consumeToken();
1147   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1148 
1149   // Parse the name of the decl.
1150   const ast::Name *name = nullptr;
1151   if (curToken.isNot(Token::identifier)) {
1152     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1153     // in C++, so being unnamed is fine.
1154     if (!isInline)
1155       return emitError("expected identifier name");
1156 
1157     // Create a unique anonymous name to use, as the name for this decl is not
1158     // important.
1159     std::string anonName =
1160         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1161                       anonymousDeclNameCounter++)
1162             .str();
1163     name = &ast::Name::create(ctx, anonName, loc);
1164   } else {
1165     // If a name was provided, we can use it directly.
1166     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1167     consumeToken(Token::identifier);
1168   }
1169 
1170   // Parse the functional signature of the decl.
1171   SmallVector<ast::VariableDecl *> arguments, results;
1172   ast::DeclScope *argumentScope;
1173   ast::Type resultType;
1174   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1175                                                    argumentScope, resultType)))
1176     return failure();
1177 
1178   // Check to see which type of constraint this is. If the constraint contains a
1179   // compound body, this is a PDLL decl.
1180   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1181     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1182                            resultType);
1183 
1184   // Otherwise, this is a native decl.
1185   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1186                                                    results, resultType);
1187 }
1188 
1189 template <typename T>
1190 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1191     const ast::Name &name, bool isInline,
1192     ArrayRef<ast::VariableDecl *> arguments,
1193     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1194   // If followed by a string, the native code body has also been specified.
1195   std::string codeStrStorage;
1196   Optional<StringRef> optCodeStr;
1197   if (curToken.isString()) {
1198     codeStrStorage = curToken.getStringValue();
1199     optCodeStr = codeStrStorage;
1200     consumeToken();
1201   } else if (isInline) {
1202     return emitError(name.getLoc(),
1203                      "external declarations must be declared in global scope");
1204   }
1205   if (failed(parseToken(Token::semicolon,
1206                         "expected `;` after native declaration")))
1207     return failure();
1208   // TODO: PDL should be able to support constraint results in certain
1209   // situations, we should revise this.
1210   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1211     return emitError(
1212         "native Constraints currently do not support returning results");
1213   }
1214   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1215 }
1216 
/// Parse the signature of a user constraint or rewrite:
///   `(` arg-list? `)` (`->` (result | `(` result-list `)`))?
///
/// Arguments are parsed in a new decl scope, which is handed back through
/// `argumentScope` (and popped once the argument list is parsed) so that the
/// body parser can re-enter it. Results are parsed in a separate, temporary
/// scope. `resultType` is computed from the parsed results. A lone result may
/// not be named.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
    SmallVectorImpl<ast::VariableDecl *> &arguments,
    SmallVectorImpl<ast::VariableDecl *> &results,
    ast::DeclScope *&argumentScope, ast::Type &resultType) {
  // Parse the argument list of the decl.
  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
    return failure();

  // Arguments get their own scope; the caller later re-enters it to parse the
  // body.
  argumentScope = pushDeclScope();
  if (curToken.isNot(Token::r_paren)) {
    do {
      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
      if (failed(argument))
        return failure();
      arguments.emplace_back(*argument);
    } while (consumeIf(Token::comma));
  }
  popDeclScope();
  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
    return failure();

  // Parse the results of the decl. They are parsed in their own temporary
  // scope, which is popped right after.
  pushDeclScope();
  if (consumeIf(Token::arrow)) {
    // The index of the result being parsed is passed along to the parser.
    auto parseResultFn = [&]() -> LogicalResult {
      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
      if (failed(result))
        return failure();
      results.emplace_back(*result);
      return success();
    };

    // Check for a list of results.
    if (consumeIf(Token::l_paren)) {
      do {
        if (failed(parseResultFn()))
          return failure();
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
        return failure();

      // Otherwise, there is only one result.
    } else if (failed(parseResultFn())) {
      return failure();
    }
  }
  popDeclScope();

  // Compute the result type of the decl.
  resultType = createUserConstraintRewriteResultType(results);

  // Verify that results are only named if there are more than one.
  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
    return emitError(
        results.front()->getLoc(),
        "cannot create a single-element tuple with an element label");
  }
  return success();
}
1276 
1277 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1278     StringRef declType, ast::CompoundStmt *body,
1279     ArrayRef<ast::Stmt *>::iterator bodyIt,
1280     ArrayRef<ast::Stmt *>::iterator bodyE,
1281     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1282   // Handle if a `return` was provided.
1283   if (bodyIt != bodyE) {
1284     // Emit an error if we have trailing statements after the return.
1285     if (std::next(bodyIt) != bodyE) {
1286       return emitError(
1287           (*std::next(bodyIt))->getLoc(),
1288           llvm::formatv("`return` terminated the `{0}` body, but found "
1289                         "trailing statements afterwards",
1290                         declType));
1291     }
1292 
1293     // Otherwise if a return wasn't provided, check that no results are
1294     // expected.
1295   } else if (!results.empty()) {
1296     return emitError(
1297         {body->getLoc().End, body->getLoc().End},
1298         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1299                       declType, resultType));
1300   }
1301   return success();
1302 }
1303 
1304 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1305   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1306     if (isa<ast::OpRewriteStmt>(statement))
1307       return success();
1308     return emitError(
1309         statement->getLoc(),
1310         "expected Pattern lambda body to contain a single operation "
1311         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1312   });
1313 }
1314 
1315 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1316   SMRange loc = curToken.getLoc();
1317   consumeToken(Token::kw_Pattern);
1318   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1319                                               ParserContext::PatternMatch);
1320 
1321   // Check for an optional identifier for the pattern name.
1322   const ast::Name *name = nullptr;
1323   if (curToken.is(Token::identifier)) {
1324     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1325     consumeToken(Token::identifier);
1326   }
1327 
1328   // Parse any pattern metadata.
1329   ParsedPatternMetadata metadata;
1330   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1331     return failure();
1332 
1333   // Parse the pattern body.
1334   ast::CompoundStmt *body;
1335 
1336   // Handle a lambda body.
1337   if (curToken.is(Token::equal_arrow)) {
1338     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1339     if (failed(bodyResult))
1340       return failure();
1341     body = *bodyResult;
1342   } else {
1343     if (curToken.isNot(Token::l_brace))
1344       return emitError("expected `{` or `=>` to start pattern body");
1345     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1346     if (failed(bodyResult))
1347       return failure();
1348     body = *bodyResult;
1349 
1350     // Verify the body of the pattern.
1351     auto bodyIt = body->begin(), bodyE = body->end();
1352     for (; bodyIt != bodyE; ++bodyIt) {
1353       if (isa<ast::ReturnStmt>(*bodyIt)) {
1354         return emitError((*bodyIt)->getLoc(),
1355                          "`return` statements are only permitted within a "
1356                          "`Constraint` or `Rewrite` body");
1357       }
1358       // Break when we've found the rewrite statement.
1359       if (isa<ast::OpRewriteStmt>(*bodyIt))
1360         break;
1361     }
1362     if (bodyIt == bodyE) {
1363       return emitError(loc,
1364                        "expected Pattern body to terminate with an operation "
1365                        "rewrite statement, such as `erase`");
1366     }
1367     if (std::next(bodyIt) != bodyE) {
1368       return emitError((*std::next(bodyIt))->getLoc(),
1369                        "Pattern body was terminated by an operation "
1370                        "rewrite statement, but found trailing statements");
1371     }
1372   }
1373 
1374   return createPatternDecl(loc, name, metadata, body);
1375 }
1376 
/// Parse the comma-separated metadata list that follows `with` in a Pattern
/// decl into `metadata`. Supported entries are `benefit(<integer>)` and
/// `recursion`; each may be specified at most once.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  // Locations of entries already seen, used to diagnose duplicates.
  Optional<SMRange> benefitLoc;
  Optional<SMRange> hasBoundedRecursionLoc;

  do {
    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef metadataStr = curToken.getSpelling();
    SMRange metadataLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    // Parse the benefit metadata: benefit(<integer-value>)
    if (metadataStr == "benefit") {
      if (benefitLoc) {
        return emitErrorAndNote(metadataLoc,
                                "pattern benefit has already been specified",
                                *benefitLoc, "see previous definition here");
      }
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      uint16_t benefitValue = 0;
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      // `getAsInteger` returns true on failure. The token is already known to
      // be an integer, so a failure here means it doesn't fit in 16 bits.
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = benefitValue;
      benefitLoc = metadataLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
      // In a do/while, `continue` jumps to the loop condition, i.e. the check
      // for a following comma.
      continue;
    }

    // Parse the bounded recursion metadata: recursion
    if (metadataStr == "recursion") {
      if (hasBoundedRecursionLoc) {
        return emitErrorAndNote(
            metadataLoc,
            "pattern recursion metadata has already been specified",
            *hasBoundedRecursionLoc, "see previous definition here");
      }
      metadata.hasBoundedRecursion = true;
      hasBoundedRecursionLoc = metadataLoc;
      continue;
    }

    return emitError(metadataLoc, "unknown pattern metadata");
  } while (consumeIf(Token::comma));

  return success();
}
1435 
1436 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1437   consumeToken(Token::less);
1438 
1439   FailureOr<ast::Expr *> typeExpr = parseExpr();
1440   if (failed(typeExpr) ||
1441       failed(parseToken(Token::greater,
1442                         "expected `>` after variable type constraint")))
1443     return failure();
1444   return typeExpr;
1445 }
1446 
1447 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1448   assert(curDeclScope && "defining decl outside of a decl scope");
1449   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1450     return emitErrorAndNote(
1451         name.getLoc(), "`" + name.getName() + "` has already been defined",
1452         lastDecl->getName()->getLoc(), "see previous definition here");
1453   }
1454   return success();
1455 }
1456 
1457 FailureOr<ast::VariableDecl *>
1458 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1459                            ast::Expr *initExpr,
1460                            ArrayRef<ast::ConstraintRef> constraints) {
1461   assert(curDeclScope && "defining variable outside of decl scope");
1462   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1463 
1464   // If the name of the variable indicates a special variable, we don't add it
1465   // to the scope. This variable is local to the definition point.
1466   if (name.empty() || name == "_") {
1467     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1468                                      constraints);
1469   }
1470   if (failed(checkDefineNamedDecl(nameDecl)))
1471     return failure();
1472 
1473   auto *varDecl =
1474       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1475   curDeclScope->add(varDecl);
1476   return varDecl;
1477 }
1478 
1479 FailureOr<ast::VariableDecl *>
1480 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1481                            ArrayRef<ast::ConstraintRef> constraints) {
1482   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1483                             constraints);
1484 }
1485 
1486 LogicalResult Parser::parseVariableDeclConstraintList(
1487     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1488   Optional<SMRange> typeConstraint;
1489   auto parseSingleConstraint = [&] {
1490     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1491         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1492     if (failed(constraint))
1493       return failure();
1494     constraints.push_back(*constraint);
1495     return success();
1496   };
1497 
1498   // Check to see if this is a single constraint, or a list.
1499   if (!consumeIf(Token::l_square))
1500     return parseSingleConstraint();
1501 
1502   do {
1503     if (failed(parseSingleConstraint()))
1504       return failure();
1505   } while (consumeIf(Token::comma));
1506   return parseToken(Token::r_square, "expected `]` after constraint list");
1507 }
1508 
1509 FailureOr<ast::ConstraintRef>
1510 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1511                         ArrayRef<ast::ConstraintRef> existingConstraints,
1512                         bool allowInlineTypeConstraints) {
1513   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1514     if (!allowInlineTypeConstraints) {
1515       return emitError(
1516           curToken.getLoc(),
1517           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1518           "permitted on arguments or results");
1519     }
1520     if (typeConstraint)
1521       return emitErrorAndNote(
1522           curToken.getLoc(),
1523           "the type of this variable has already been constrained",
1524           *typeConstraint, "see previous constraint location here");
1525     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1526     if (failed(constraintExpr))
1527       return failure();
1528     typeExpr = *constraintExpr;
1529     typeConstraint = typeExpr->getLoc();
1530     return success();
1531   };
1532 
1533   SMRange loc = curToken.getLoc();
1534   switch (curToken.getKind()) {
1535   case Token::kw_Attr: {
1536     consumeToken(Token::kw_Attr);
1537 
1538     // Check for a type constraint.
1539     ast::Expr *typeExpr = nullptr;
1540     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1541       return failure();
1542     return ast::ConstraintRef(
1543         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1544   }
1545   case Token::kw_Op: {
1546     consumeToken(Token::kw_Op);
1547 
1548     // Parse an optional operation name. If the name isn't provided, this refers
1549     // to "any" operation.
1550     FailureOr<ast::OpNameDecl *> opName =
1551         parseWrappedOperationName(/*allowEmptyName=*/true);
1552     if (failed(opName))
1553       return failure();
1554 
1555     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1556                               loc);
1557   }
1558   case Token::kw_Type:
1559     consumeToken(Token::kw_Type);
1560     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1561   case Token::kw_TypeRange:
1562     consumeToken(Token::kw_TypeRange);
1563     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1564                               loc);
1565   case Token::kw_Value: {
1566     consumeToken(Token::kw_Value);
1567 
1568     // Check for a type constraint.
1569     ast::Expr *typeExpr = nullptr;
1570     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1571       return failure();
1572 
1573     return ast::ConstraintRef(
1574         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1575   }
1576   case Token::kw_ValueRange: {
1577     consumeToken(Token::kw_ValueRange);
1578 
1579     // Check for a type constraint.
1580     ast::Expr *typeExpr = nullptr;
1581     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1582       return failure();
1583 
1584     return ast::ConstraintRef(
1585         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1586   }
1587 
1588   case Token::kw_Constraint: {
1589     // Handle an inline constraint.
1590     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1591     if (failed(decl))
1592       return failure();
1593     return ast::ConstraintRef(*decl, loc);
1594   }
1595   case Token::identifier: {
1596     StringRef constraintName = curToken.getSpelling();
1597     consumeToken(Token::identifier);
1598 
1599     // Lookup the referenced constraint.
1600     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1601     if (!cstDecl) {
1602       return emitError(loc, "unknown reference to constraint `" +
1603                                 constraintName + "`");
1604     }
1605 
1606     // Handle a reference to a proper constraint.
1607     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1608       return ast::ConstraintRef(cst, loc);
1609 
1610     return emitErrorAndNote(
1611         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1612         "see the definition of `" + constraintName + "` here");
1613   }
1614   default:
1615     break;
1616   }
1617   return emitError(loc, "expected identifier constraint");
1618 }
1619 
1620 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1621   Optional<SMRange> typeConstraint;
1622   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1623                          /*allowInlineTypeConstraints=*/false);
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // Exprs
1628 
1629 FailureOr<ast::Expr *> Parser::parseExpr() {
1630   if (curToken.is(Token::underscore))
1631     return parseUnderscoreExpr();
1632 
1633   // Parse the LHS expression.
1634   FailureOr<ast::Expr *> lhsExpr;
1635   switch (curToken.getKind()) {
1636   case Token::kw_attr:
1637     lhsExpr = parseAttributeExpr();
1638     break;
1639   case Token::kw_Constraint:
1640     lhsExpr = parseInlineConstraintLambdaExpr();
1641     break;
1642   case Token::identifier:
1643     lhsExpr = parseIdentifierExpr();
1644     break;
1645   case Token::kw_op:
1646     lhsExpr = parseOperationExpr();
1647     break;
1648   case Token::kw_Rewrite:
1649     lhsExpr = parseInlineRewriteLambdaExpr();
1650     break;
1651   case Token::kw_type:
1652     lhsExpr = parseTypeExpr();
1653     break;
1654   case Token::l_paren:
1655     lhsExpr = parseTupleExpr();
1656     break;
1657   default:
1658     return emitError("expected expression");
1659   }
1660   if (failed(lhsExpr))
1661     return failure();
1662 
1663   // Check for an operator expression.
1664   while (true) {
1665     switch (curToken.getKind()) {
1666     case Token::dot:
1667       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1668       break;
1669     case Token::l_paren:
1670       lhsExpr = parseCallExpr(*lhsExpr);
1671       break;
1672     default:
1673       return lhsExpr;
1674     }
1675     if (failed(lhsExpr))
1676       return failure();
1677   }
1678 }
1679 
1680 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1681   SMRange loc = curToken.getLoc();
1682   consumeToken(Token::kw_attr);
1683 
1684   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1685   // identifier.
1686   if (!consumeIf(Token::less)) {
1687     resetToken(loc);
1688     return parseIdentifierExpr();
1689   }
1690 
1691   if (!curToken.isString())
1692     return emitError("expected string literal containing MLIR attribute");
1693   std::string attrExpr = curToken.getStringValue();
1694   consumeToken();
1695 
1696   if (failed(
1697           parseToken(Token::greater, "expected `>` after attribute literal")))
1698     return failure();
1699   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1700 }
1701 
1702 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1703   SMRange loc = curToken.getLoc();
1704   consumeToken(Token::l_paren);
1705 
1706   // Parse the arguments of the call.
1707   SmallVector<ast::Expr *> arguments;
1708   if (curToken.isNot(Token::r_paren)) {
1709     do {
1710       FailureOr<ast::Expr *> argument = parseExpr();
1711       if (failed(argument))
1712         return failure();
1713       arguments.push_back(*argument);
1714     } while (consumeIf(Token::comma));
1715   }
1716   loc.End = curToken.getEndLoc();
1717   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1718     return failure();
1719 
1720   return createCallExpr(loc, parentExpr, arguments);
1721 }
1722 
1723 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1724   ast::Decl *decl = curDeclScope->lookup(name);
1725   if (!decl)
1726     return emitError(loc, "undefined reference to `" + name + "`");
1727 
1728   return createDeclRefExpr(loc, decl);
1729 }
1730 
1731 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1732   StringRef name = curToken.getSpelling();
1733   SMRange nameLoc = curToken.getLoc();
1734   consumeToken();
1735 
1736   // Check to see if this is a decl ref expression that defines a variable
1737   // inline.
1738   if (consumeIf(Token::colon)) {
1739     SmallVector<ast::ConstraintRef> constraints;
1740     if (failed(parseVariableDeclConstraintList(constraints)))
1741       return failure();
1742     ast::Type type;
1743     if (failed(validateVariableConstraints(constraints, type)))
1744       return failure();
1745     return createInlineVariableExpr(type, name, nameLoc, constraints);
1746   }
1747 
1748   return parseDeclRefExpr(name, nameLoc);
1749 }
1750 
1751 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1752   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1753   if (failed(decl))
1754     return failure();
1755 
1756   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1757                                   ast::ConstraintType::get(ctx));
1758 }
1759 
1760 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1761   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1762   if (failed(decl))
1763     return failure();
1764 
1765   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1766                                   ast::RewriteType::get(ctx));
1767 }
1768 
1769 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1770   SMRange loc = curToken.getLoc();
1771   consumeToken(Token::dot);
1772 
1773   // Parse the member name.
1774   Token memberNameTok = curToken;
1775   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1776       !memberNameTok.isKeyword())
1777     return emitError(loc, "expected identifier or numeric member name");
1778   StringRef memberName = memberNameTok.getSpelling();
1779   consumeToken();
1780 
1781   return createMemberAccessExpr(parentExpr, memberName, loc);
1782 }
1783 
/// Parse an operation name of the form `dialect.name(.name)*`, where each
/// component is an identifier or keyword. If the current token cannot start a
/// name, an empty OpNameDecl (the "any operation" case) is returned without
/// consuming anything when `allowEmptyName` is true. Otherwise an error is
/// emitted.
///
/// NOTE(review): the returned name is built by widening a StringRef over the
/// source buffer token by token (including a fixed 1-character step for the
/// first `.`). This assumes the name's tokens are adjacent in the source with
/// no intervening whitespace. TODO confirm the grammar guarantees this.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Handle the case of no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Grow `name` over the `.` just parsed, then over each following
  // identifier, keyword, or `.` token. `loc` is extended to match.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
1812 
1813 FailureOr<ast::OpNameDecl *>
1814 Parser::parseWrappedOperationName(bool allowEmptyName) {
1815   if (!consumeIf(Token::less))
1816     return ast::OpNameDecl::create(ctx, SMRange());
1817 
1818   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1819   if (failed(opNameDecl))
1820     return failure();
1821 
1822   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1823     return failure();
1824   return opNameDecl;
1825 }
1826 
/// Parse an operation expression:
///   `op` `<` op-name? `>` (`(` operands? `)`)? (`{` named-attributes `}`)?
///   (`->` `(` result-types? `)`)?
/// If `op` is not followed by `<`, the keyword is re-parsed as a plain
/// identifier. Outside a rewrite context, an omitted operand list or omitted
/// result-type list is replaced by an implicit unconstrained `ValueRange` or
/// `TypeRange` variable. This keeps "unspecified" distinct from "empty". In a
/// rewrite context the operation name must be present.
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The variable is named `_`, so
  // defineVariableDecl does not add it to the current scope.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an inplace unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes. An attribute list that is
  // present must contain at least one named attribute.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Check for the optional list of result types.
  SmallVector<ast::Expr *> resultTypes;
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();

    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));

      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an inplace unconstrained result range corresponding to all of the results
    // of the operation. This avoids treating zero results the same way as
    // "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  }

  return createOperationExpr(loc, *opNameDecl, operands, attributes,
                             resultTypes);
}
1928 
/// Parse a tuple expression: `(` (element (`,` element)*)? `)`, where each
/// element is `(label =)? expr`. Labels are optional per element and must be
/// unique within the tuple. Unlabeled elements are recorded with an empty
/// name. The resulting location spans from `(` through `)`.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Labels seen so far, mapped to their location for duplicate diagnostics.
  DenseMap<StringRef, SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer.
          resetToken(elementNameTok.getLoc());
        }
      }
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
1980 
1981 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1982   SMRange loc = curToken.getLoc();
1983   consumeToken(Token::kw_type);
1984 
1985   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1986   // identifier.
1987   if (!consumeIf(Token::less)) {
1988     resetToken(loc);
1989     return parseIdentifierExpr();
1990   }
1991 
1992   if (!curToken.isString())
1993     return emitError("expected string literal containing MLIR type");
1994   std::string attrExpr = curToken.getStringValue();
1995   consumeToken();
1996 
1997   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1998     return failure();
1999   return ast::TypeExpr::create(ctx, loc, attrExpr);
2000 }
2001 
2002 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2003   StringRef name = curToken.getSpelling();
2004   SMRange nameLoc = curToken.getLoc();
2005   consumeToken(Token::underscore);
2006 
2007   // Underscore expressions require a constraint list.
2008   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2009     return failure();
2010 
2011   // Parse the constraints for the expression.
2012   SmallVector<ast::ConstraintRef> constraints;
2013   if (failed(parseVariableDeclConstraintList(constraints)))
2014     return failure();
2015 
2016   ast::Type type;
2017   if (failed(validateVariableConstraints(constraints, type)))
2018     return failure();
2019   return createInlineVariableExpr(type, name, nameLoc, constraints);
2020 }
2021 
2022 //===----------------------------------------------------------------------===//
2023 // Stmts
2024 
2025 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2026   FailureOr<ast::Stmt *> stmt;
2027   switch (curToken.getKind()) {
2028   case Token::kw_erase:
2029     stmt = parseEraseStmt();
2030     break;
2031   case Token::kw_let:
2032     stmt = parseLetStmt();
2033     break;
2034   case Token::kw_replace:
2035     stmt = parseReplaceStmt();
2036     break;
2037   case Token::kw_return:
2038     stmt = parseReturnStmt();
2039     break;
2040   case Token::kw_rewrite:
2041     stmt = parseRewriteStmt();
2042     break;
2043   default:
2044     stmt = parseExpr();
2045     break;
2046   }
2047   if (failed(stmt) ||
2048       (expectTerminalSemicolon &&
2049        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2050     return failure();
2051   return stmt;
2052 }
2053 
2054 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2055   SMLoc startLoc = curToken.getStartLoc();
2056   consumeToken(Token::l_brace);
2057 
2058   // Push a new block scope and parse any nested statements.
2059   pushDeclScope();
2060   SmallVector<ast::Stmt *> statements;
2061   while (curToken.isNot(Token::r_brace)) {
2062     FailureOr<ast::Stmt *> statement = parseStmt();
2063     if (failed(statement))
2064       return popDeclScope(), failure();
2065     statements.push_back(*statement);
2066   }
2067   popDeclScope();
2068 
2069   // Consume the end brace.
2070   SMRange location(startLoc, curToken.getEndLoc());
2071   consumeToken(Token::r_brace);
2072 
2073   return ast::CompoundStmt::create(ctx, location, statements);
2074 }
2075 
2076 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2077   if (parserContext == ParserContext::Constraint)
2078     return emitError("`erase` cannot be used within a Constraint");
2079   SMRange loc = curToken.getLoc();
2080   consumeToken(Token::kw_erase);
2081 
2082   // Parse the root operation expression.
2083   FailureOr<ast::Expr *> rootOp = parseExpr();
2084   if (failed(rootOp))
2085     return failure();
2086 
2087   return createEraseStmt(loc, *rootOp);
2088 }
2089 
2090 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2091   SMRange loc = curToken.getLoc();
2092   consumeToken(Token::kw_let);
2093 
2094   // Parse the name of the new variable.
2095   SMRange varLoc = curToken.getLoc();
2096   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2097     // `_` is a reserved variable name.
2098     if (curToken.is(Token::underscore)) {
2099       return emitError(varLoc,
2100                        "`_` may only be used to define \"inline\" variables");
2101     }
2102     return emitError(varLoc,
2103                      "expected identifier after `let` to name a new variable");
2104   }
2105   StringRef varName = curToken.getSpelling();
2106   consumeToken();
2107 
2108   // Parse the optional set of constraints.
2109   SmallVector<ast::ConstraintRef> constraints;
2110   if (consumeIf(Token::colon) &&
2111       failed(parseVariableDeclConstraintList(constraints)))
2112     return failure();
2113 
2114   // Parse the optional initializer expression.
2115   ast::Expr *initializer = nullptr;
2116   if (consumeIf(Token::equal)) {
2117     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2118     if (failed(initOrFailure))
2119       return failure();
2120     initializer = *initOrFailure;
2121 
2122     // Check that the constraints are compatible with having an initializer,
2123     // e.g. type constraints cannot be used with initializers.
2124     for (ast::ConstraintRef constraint : constraints) {
2125       LogicalResult result =
2126           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2127               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2128                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2129                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2130                   return this->emitError(
2131                       constraint.referenceLoc,
2132                       "type constraints are not permitted on variables with "
2133                       "initializers");
2134                 }
2135                 return success();
2136               })
2137               .Default(success());
2138       if (failed(result))
2139         return failure();
2140     }
2141   }
2142 
2143   FailureOr<ast::VariableDecl *> varDecl =
2144       createVariableDecl(varName, varLoc, initializer, constraints);
2145   if (failed(varDecl))
2146     return failure();
2147   return ast::LetStmt::create(ctx, loc, *varDecl);
2148 }
2149 
2150 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2151   if (parserContext == ParserContext::Constraint)
2152     return emitError("`replace` cannot be used within a Constraint");
2153   SMRange loc = curToken.getLoc();
2154   consumeToken(Token::kw_replace);
2155 
2156   // Parse the root operation expression.
2157   FailureOr<ast::Expr *> rootOp = parseExpr();
2158   if (failed(rootOp))
2159     return failure();
2160 
2161   if (failed(
2162           parseToken(Token::kw_with, "expected `with` after root operation")))
2163     return failure();
2164 
2165   // The replacement portion of this statement is within a rewrite context.
2166   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2167                                               ParserContext::Rewrite);
2168 
2169   // Parse the replacement values.
2170   SmallVector<ast::Expr *> replValues;
2171   if (consumeIf(Token::l_paren)) {
2172     if (consumeIf(Token::r_paren)) {
2173       return emitError(
2174           loc, "expected at least one replacement value, consider using "
2175                "`erase` if no replacement values are desired");
2176     }
2177 
2178     do {
2179       FailureOr<ast::Expr *> replExpr = parseExpr();
2180       if (failed(replExpr))
2181         return failure();
2182       replValues.emplace_back(*replExpr);
2183     } while (consumeIf(Token::comma));
2184 
2185     if (failed(parseToken(Token::r_paren,
2186                           "expected `)` after replacement values")))
2187       return failure();
2188   } else {
2189     FailureOr<ast::Expr *> replExpr = parseExpr();
2190     if (failed(replExpr))
2191       return failure();
2192     replValues.emplace_back(*replExpr);
2193   }
2194 
2195   return createReplaceStmt(loc, *rootOp, replValues);
2196 }
2197 
2198 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2199   SMRange loc = curToken.getLoc();
2200   consumeToken(Token::kw_return);
2201 
2202   // Parse the result value.
2203   FailureOr<ast::Expr *> resultExpr = parseExpr();
2204   if (failed(resultExpr))
2205     return failure();
2206 
2207   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2208 }
2209 
2210 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2211   if (parserContext == ParserContext::Constraint)
2212     return emitError("`rewrite` cannot be used within a Constraint");
2213   SMRange loc = curToken.getLoc();
2214   consumeToken(Token::kw_rewrite);
2215 
2216   // Parse the root operation.
2217   FailureOr<ast::Expr *> rootOp = parseExpr();
2218   if (failed(rootOp))
2219     return failure();
2220 
2221   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2222     return failure();
2223 
2224   if (curToken.isNot(Token::l_brace))
2225     return emitError("expected `{` to start rewrite body");
2226 
2227   // The rewrite body of this statement is within a rewrite context.
2228   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2229                                               ParserContext::Rewrite);
2230 
2231   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2232   if (failed(rewriteBody))
2233     return failure();
2234 
2235   // Verify the rewrite body.
2236   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2237     if (isa<ast::ReturnStmt>(stmt)) {
2238       return emitError(stmt->getLoc(),
2239                        "`return` statements are only permitted within a "
2240                        "`Constraint` or `Rewrite` body");
2241     }
2242   }
2243 
2244   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2245 }
2246 
2247 //===----------------------------------------------------------------------===//
2248 // Creation+Analysis
2249 //===----------------------------------------------------------------------===//
2250 
2251 //===----------------------------------------------------------------------===//
2252 // Decls
2253 
2254 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2255   // Unwrap reference expressions.
2256   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2257     node = init->getDecl();
2258   return dyn_cast<ast::CallableDecl>(node);
2259 }
2260 
2261 FailureOr<ast::PatternDecl *>
2262 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2263                           const ParsedPatternMetadata &metadata,
2264                           ast::CompoundStmt *body) {
2265   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2266                                   metadata.hasBoundedRecursion, body);
2267 }
2268 
2269 ast::Type Parser::createUserConstraintRewriteResultType(
2270     ArrayRef<ast::VariableDecl *> results) {
2271   // Single result decls use the type of the single result.
2272   if (results.size() == 1)
2273     return results[0]->getType();
2274 
2275   // Multiple results use a tuple type, with the types and names grabbed from
2276   // the result variable decls.
2277   auto resultTypes = llvm::map_range(
2278       results, [&](const auto *result) { return result->getType(); });
2279   auto resultNames = llvm::map_range(
2280       results, [&](const auto *result) { return result->getName().getName(); });
2281   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2282                              llvm::to_vector(resultNames));
2283 }
2284 
2285 template <typename T>
2286 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2287     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2288     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2289     ast::CompoundStmt *body) {
2290   if (!body->getChildren().empty()) {
2291     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2292       ast::Expr *resultExpr = retStmt->getResultExpr();
2293 
2294       // Process the result of the decl. If no explicit signature results
2295       // were provided, check for return type inference. Otherwise, check that
2296       // the return expression can be converted to the expected type.
2297       if (results.empty())
2298         resultType = resultExpr->getType();
2299       else if (failed(convertExpressionTo(resultExpr, resultType)))
2300         return failure();
2301       else
2302         retStmt->setResultExpr(resultExpr);
2303     }
2304   }
2305   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2306 }
2307 
2308 FailureOr<ast::VariableDecl *>
2309 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2310                            ArrayRef<ast::ConstraintRef> constraints) {
2311   // The type of the variable, which is expected to be inferred by either a
2312   // constraint or an initializer expression.
2313   ast::Type type;
2314   if (failed(validateVariableConstraints(constraints, type)))
2315     return failure();
2316 
2317   if (initializer) {
2318     // Update the variable type based on the initializer, or try to convert the
2319     // initializer to the existing type.
2320     if (!type)
2321       type = initializer->getType();
2322     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2323       type = mergedType;
2324     else if (failed(convertExpressionTo(initializer, type)))
2325       return failure();
2326 
2327     // Otherwise, if there is no initializer check that the type has already
2328     // been resolved from the constraint list.
2329   } else if (!type) {
2330     return emitErrorAndNote(
2331         loc, "unable to infer type for variable `" + name + "`", loc,
2332         "the type of a variable must be inferable from the constraint "
2333         "list or the initializer");
2334   }
2335 
2336   // Constraint types cannot be used when defining variables.
2337   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2338     return emitError(
2339         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2340   }
2341 
2342   // Try to define a variable with the given name.
2343   FailureOr<ast::VariableDecl *> varDecl =
2344       defineVariableDecl(name, loc, type, initializer, constraints);
2345   if (failed(varDecl))
2346     return failure();
2347 
2348   return *varDecl;
2349 }
2350 
2351 FailureOr<ast::VariableDecl *>
2352 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2353                                       const ast::ConstraintRef &constraint) {
2354   // Constraint arguments may apply more complex constraints via the arguments.
2355   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2356   ast::Type argType;
2357   if (failed(validateVariableConstraint(constraint, argType,
2358                                         allowNonCoreConstraints)))
2359     return failure();
2360   return defineVariableDecl(name, loc, argType, constraint);
2361 }
2362 
2363 LogicalResult
2364 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2365                                     ast::Type &inferredType) {
2366   for (const ast::ConstraintRef &ref : constraints)
2367     if (failed(validateVariableConstraint(ref, inferredType)))
2368       return failure();
2369   return success();
2370 }
2371 
2372 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2373                                                  ast::Type &inferredType,
2374                                                  bool allowNonCoreConstraints) {
2375   ast::Type constraintType;
2376   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2377     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2378       if (failed(validateTypeConstraintExpr(typeExpr)))
2379         return failure();
2380     }
2381     constraintType = ast::AttributeType::get(ctx);
2382   } else if (const auto *cst =
2383                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2384     constraintType = ast::OperationType::get(ctx, cst->getName());
2385   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2386     constraintType = typeTy;
2387   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2388     constraintType = typeRangeTy;
2389   } else if (const auto *cst =
2390                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2391     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2392       if (failed(validateTypeConstraintExpr(typeExpr)))
2393         return failure();
2394     }
2395     constraintType = valueTy;
2396   } else if (const auto *cst =
2397                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2398     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2399       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2400         return failure();
2401     }
2402     constraintType = valueRangeTy;
2403   } else if (const auto *cst =
2404                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2405     if (!allowNonCoreConstraints) {
2406       return emitError(ref.referenceLoc,
2407                        "`Rewrite` arguments and results are only permitted to "
2408                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2409                        "`TypeRange`, `Value`, `ValueRange`");
2410     }
2411 
2412     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2413     if (inputs.size() != 1) {
2414       return emitErrorAndNote(ref.referenceLoc,
2415                               "`Constraint`s applied via a variable constraint "
2416                               "list must take a single input, but got " +
2417                                   Twine(inputs.size()),
2418                               cst->getLoc(),
2419                               "see definition of constraint here");
2420     }
2421     constraintType = inputs.front()->getType();
2422   } else {
2423     llvm_unreachable("unknown constraint type");
2424   }
2425 
2426   // Check that the constraint type is compatible with the current inferred
2427   // type.
2428   if (!inferredType) {
2429     inferredType = constraintType;
2430   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2431     inferredType = mergedTy;
2432   } else {
2433     return emitError(ref.referenceLoc,
2434                      llvm::formatv("constraint type `{0}` is incompatible "
2435                                    "with the previously inferred type `{1}`",
2436                                    constraintType, inferredType));
2437   }
2438   return success();
2439 }
2440 
2441 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2442   ast::Type typeExprType = typeExpr->getType();
2443   if (typeExprType != typeTy) {
2444     return emitError(typeExpr->getLoc(),
2445                      "expected expression of `Type` in type constraint");
2446   }
2447   return success();
2448 }
2449 
2450 LogicalResult
2451 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2452   ast::Type typeExprType = typeExpr->getType();
2453   if (typeExprType != typeRangeTy) {
2454     return emitError(typeExpr->getLoc(),
2455                      "expected expression of `TypeRange` in type constraint");
2456   }
2457   return success();
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // Exprs
2462 
2463 FailureOr<ast::CallExpr *>
2464 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2465                        MutableArrayRef<ast::Expr *> arguments) {
2466   ast::Type parentType = parentExpr->getType();
2467 
2468   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2469   if (!callableDecl) {
2470     return emitError(loc,
2471                      llvm::formatv("expected a reference to a callable "
2472                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2473                                    parentType));
2474   }
2475   if (parserContext == ParserContext::Rewrite) {
2476     if (isa<ast::UserConstraintDecl>(callableDecl))
2477       return emitError(
2478           loc, "unable to invoke `Constraint` within a rewrite section");
2479   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2480     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2481   }
2482 
2483   // Verify the arguments of the call.
2484   /// Handle size mismatch.
2485   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2486   if (callArgs.size() != arguments.size()) {
2487     return emitErrorAndNote(
2488         loc,
2489         llvm::formatv("invalid number of arguments for {0} call; expected "
2490                       "{1}, but got {2}",
2491                       callableDecl->getCallableType(), callArgs.size(),
2492                       arguments.size()),
2493         callableDecl->getLoc(),
2494         llvm::formatv("see the definition of {0} here",
2495                       callableDecl->getName()->getName()));
2496   }
2497 
2498   /// Handle argument type mismatch.
2499   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2500     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2501                                   callableDecl->getName()->getName()),
2502                     callableDecl->getLoc());
2503   };
2504   for (auto it : llvm::zip(callArgs, arguments)) {
2505     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2506                                    attachDiagFn)))
2507       return failure();
2508   }
2509 
2510   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2511                                callableDecl->getResultType());
2512 }
2513 
2514 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2515                                                         ast::Decl *decl) {
2516   // Check the type of decl being referenced.
2517   ast::Type declType;
2518   if (isa<ast::ConstraintDecl>(decl))
2519     declType = ast::ConstraintType::get(ctx);
2520   else if (isa<ast::UserRewriteDecl>(decl))
2521     declType = ast::RewriteType::get(ctx);
2522   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2523     declType = varDecl->getType();
2524   else
2525     return emitError(loc, "invalid reference to `" +
2526                               decl->getName()->getName() + "`");
2527 
2528   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2529 }
2530 
2531 FailureOr<ast::DeclRefExpr *>
2532 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2533                                  ArrayRef<ast::ConstraintRef> constraints) {
2534   FailureOr<ast::VariableDecl *> decl =
2535       defineVariableDecl(name, loc, type, constraints);
2536   if (failed(decl))
2537     return failure();
2538   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2539 }
2540 
2541 FailureOr<ast::MemberAccessExpr *>
2542 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2543                                SMRange loc) {
2544   // Validate the member name for the given parent expression.
2545   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2546   if (failed(memberType))
2547     return failure();
2548 
2549   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2550 }
2551 
2552 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2553                                                   StringRef name, SMRange loc) {
2554   ast::Type parentType = parentExpr->getType();
2555   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2556     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2557       return valueRangeTy;
2558 
2559     // Verify member access based on the operation type.
2560     if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
2561       auto results = odsOp->getResults();
2562 
2563       // Handle indexed results.
2564       unsigned index = 0;
2565       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2566           index < results.size()) {
2567         return results[index].isVariadic() ? valueRangeTy : valueTy;
2568       }
2569 
2570       // Handle named results.
2571       const auto *it = llvm::find_if(results, [&](const auto &result) {
2572         return result.getName() == name;
2573       });
2574       if (it != results.end())
2575         return it->isVariadic() ? valueRangeTy : valueTy;
2576     }
2577 
2578   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2579     // Handle indexed results.
2580     unsigned index = 0;
2581     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2582         index < tupleType.size()) {
2583       return tupleType.getElementTypes()[index];
2584     }
2585 
2586     // Handle named results.
2587     auto elementNames = tupleType.getElementNames();
2588     const auto *it = llvm::find(elementNames, name);
2589     if (it != elementNames.end())
2590       return tupleType.getElementTypes()[it - elementNames.begin()];
2591   }
2592   return emitError(
2593       loc,
2594       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2595                     name, parentType));
2596 }
2597 
2598 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2599     SMRange loc, const ast::OpNameDecl *name,
2600     MutableArrayRef<ast::Expr *> operands,
2601     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2602     MutableArrayRef<ast::Expr *> results) {
2603   Optional<StringRef> opNameRef = name->getName();
2604   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2605 
2606   // Verify the inputs operands.
2607   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2608     return failure();
2609 
2610   // Verify the attribute list.
2611   for (ast::NamedAttributeDecl *attr : attributes) {
2612     // Check for an attribute type, or a type awaiting resolution.
2613     ast::Type attrType = attr->getValue()->getType();
2614     if (!attrType.isa<ast::AttributeType>()) {
2615       return emitError(
2616           attr->getValue()->getLoc(),
2617           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2618     }
2619   }
2620 
2621   // Verify the result types.
2622   if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2623     return failure();
2624 
2625   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2626                                     attributes);
2627 }
2628 
2629 LogicalResult
2630 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2631                                   const ods::Operation *odsOp,
2632                                   MutableArrayRef<ast::Expr *> operands) {
2633   return validateOperationOperandsOrResults(
2634       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2635       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2636       valueRangeTy);
2637 }
2638 
2639 LogicalResult
2640 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2641                                  const ods::Operation *odsOp,
2642                                  MutableArrayRef<ast::Expr *> results) {
2643   return validateOperationOperandsOrResults(
2644       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2645       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2646 }
2647 
2648 LogicalResult Parser::validateOperationOperandsOrResults(
2649     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2650     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2651     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2652     ast::Type rangeTy) {
2653   // All operation types accept a single range parameter.
2654   if (values.size() == 1) {
2655     if (failed(convertExpressionTo(values[0], rangeTy)))
2656       return failure();
2657     return success();
2658   }
2659 
2660   /// If the operation has ODS information, we can more accurately verify the
2661   /// values.
2662   if (odsOpLoc) {
2663     if (odsValues.size() != values.size()) {
2664       return emitErrorAndNote(
2665           loc,
2666           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2667                         "{2}, but got {3}",
2668                         groupName, *name, odsValues.size(), values.size()),
2669           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2670     }
2671     auto diagFn = [&](ast::Diagnostic &diag) {
2672       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2673                       *odsOpLoc);
2674     };
2675     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2676       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2677       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2678         return failure();
2679     }
2680     return success();
2681   }
2682 
2683   // Otherwise, accept the value groups as they have been defined and just
2684   // ensure they are one of the expected types.
2685   for (ast::Expr *&valueExpr : values) {
2686     ast::Type valueExprType = valueExpr->getType();
2687 
2688     // Check if this is one of the expected types.
2689     if (valueExprType == rangeTy || valueExprType == singleTy)
2690       continue;
2691 
2692     // If the operand is an Operation, allow converting to a Value or
2693     // ValueRange. This situations arises quite often with nested operation
2694     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2695     if (singleTy == valueTy) {
2696       if (valueExprType.isa<ast::OperationType>()) {
2697         valueExpr = convertOpToValue(valueExpr);
2698         continue;
2699       }
2700     }
2701 
2702     return emitError(
2703         valueExpr->getLoc(),
2704         llvm::formatv(
2705             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2706             singleTy, rangeTy, valueExprType));
2707   }
2708   return success();
2709 }
2710 
2711 FailureOr<ast::TupleExpr *>
2712 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2713                         ArrayRef<StringRef> elementNames) {
2714   for (const ast::Expr *element : elements) {
2715     ast::Type eleTy = element->getType();
2716     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2717       return emitError(
2718           element->getLoc(),
2719           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2720     }
2721   }
2722   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2723 }
2724 
2725 //===----------------------------------------------------------------------===//
2726 // Stmts
2727 
2728 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2729                                                     ast::Expr *rootOp) {
2730   // Check that root is an Operation.
2731   ast::Type rootType = rootOp->getType();
2732   if (!rootType.isa<ast::OperationType>())
2733     return emitError(rootOp->getLoc(), "expected `Op` expression");
2734 
2735   return ast::EraseStmt::create(ctx, loc, rootOp);
2736 }
2737 
2738 FailureOr<ast::ReplaceStmt *>
2739 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2740                           MutableArrayRef<ast::Expr *> replValues) {
2741   // Check that root is an Operation.
2742   ast::Type rootType = rootOp->getType();
2743   if (!rootType.isa<ast::OperationType>()) {
2744     return emitError(
2745         rootOp->getLoc(),
2746         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2747   }
2748 
2749   // If there are multiple replacement values, we implicitly convert any Op
2750   // expressions to the value form.
2751   bool shouldConvertOpToValues = replValues.size() > 1;
2752   for (ast::Expr *&replExpr : replValues) {
2753     ast::Type replType = replExpr->getType();
2754 
2755     // Check that replExpr is an Operation, Value, or ValueRange.
2756     if (replType.isa<ast::OperationType>()) {
2757       if (shouldConvertOpToValues)
2758         replExpr = convertOpToValue(replExpr);
2759       continue;
2760     }
2761 
2762     if (replType != valueTy && replType != valueRangeTy) {
2763       return emitError(replExpr->getLoc(),
2764                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2765                                      "expression, but got `{0}`",
2766                                      replType));
2767     }
2768   }
2769 
2770   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2771 }
2772 
2773 FailureOr<ast::RewriteStmt *>
2774 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2775                           ast::CompoundStmt *rewriteBody) {
2776   // Check that root is an Operation.
2777   ast::Type rootType = rootOp->getType();
2778   if (!rootType.isa<ast::OperationType>()) {
2779     return emitError(
2780         rootOp->getLoc(),
2781         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2782   }
2783 
2784   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2785 }
2786 
2787 //===----------------------------------------------------------------------===//
2788 // Parser
2789 //===----------------------------------------------------------------------===//
2790 
2791 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2792                                                  llvm::SourceMgr &sourceMgr) {
2793   Parser parser(ctx, sourceMgr);
2794   return parser.parseModule();
2795 }
2796