1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <string>
34 
35 using namespace mlir;
36 using namespace mlir::pdll;
37 
38 //===----------------------------------------------------------------------===//
39 // Parser
40 //===----------------------------------------------------------------------===//
41 
namespace {
/// Parser for PDLL source. It lexes the input held by the provided source
/// manager and builds AST nodes within the given `ast::Context`.
class Parser {
public:
  /// Note: `curToken` is initialized by lexing from `lexer`, so `lexer` must
  /// remain declared before `curToken` in the field list below (members are
  /// constructed in declaration order).
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
         CodeCompleteContext *codeCompleteContext)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
        curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)),
        attrTy(ast::AttributeType::get(ctx)),
        codeCompleteContext(codeCompleteContext) {}

  /// Try to parse a new module. Returns failure if an error is encountered
  /// while parsing.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within a Constraint, which disallows all types
    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
    Constraint,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  /// The current specification context of an operation's result types. This
  /// indicates how the result types of an operation may be inferred.
  enum class OpResultTypeContext {
    /// The result types of the operation are not known to be inferred.
    Explicit,
    /// The result types of the operation are inferred from the root input of a
    /// `replace` statement.
    Replacement,
    /// The result types of the operation are inferred by using the
    /// `InferTypeOpInterface` interface provided by the operation.
    Interface,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Allocate a new decl scope (from `scopeAllocator`) whose parent is the
  /// current scope, make it the current scope, and return it.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make the given, already existing, scope the current scope.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent the current scope again.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module.
  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a Value or ValueRange
  /// typed expression.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  /// Lookup ODS information for the given operation, returns nullptr if no
  /// information is found (or if no operation name is provided).
  const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
    return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
  }

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token (e.g. `#include`).
  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse an `#include` directive.
  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Process the records of a parsed tablegen include file.
  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Create a user defined native constraint for a constraint imported from
  /// ODS.
  template <typename ConstraintT>
  ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
                                               StringRef codeBlock, SMRange loc,
                                               ast::Type type,
                                               StringRef nativeType);
  template <typename ConstraintT>
  ast::Decl *
  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                    SMRange loc, ast::Type type,
                                    StringRef nativeType);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  FailureOr<ast::Decl *> parseTopLevelDecl();
  FailureOr<ast::NamedAttributeDecl *>
  parseNamedAttributeDecl(Optional<StringRef> parentOpName);

  /// Parse an argument variable as part of the signature of a
  /// UserConstraintDecl or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseArgumentDecl();

  /// Parse a result variable as part of the signature of a UserConstraintDecl
  /// or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);

  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
  /// defined in a non-global context.
  FailureOr<ast::UserConstraintDecl *>
  parseUserConstraintDecl(bool isInline = false);

  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Constraint/etc.
  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();

  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
  /// using PDLL constructs.
  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
  /// defined in a non-global context.
  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);

  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Rewrite/etc.
  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();

  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
  /// PDLL constructs.
  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
  /// effectively the same syntax, and only differ on slight semantics (given
  /// the different parsing contexts).
  template <typename T, typename ParseUserPDLLDeclFnT>
  FailureOr<T *> parseUserConstraintOrRewriteDecl(
      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
      StringRef anonymousNamePrefix, bool isInline);

  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
  /// These decls have effectively the same syntax.
  template <typename T>
  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse the functional signature (i.e. the arguments and results) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult parseUserConstraintOrRewriteSignature(
      SmallVectorImpl<ast::VariableDecl *> &arguments,
      SmallVectorImpl<ast::VariableDecl *> &results,
      ast::DeclScope *&argumentScope, ast::Type &resultType);

  /// Validate the return (which if present is specified by bodyIt) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult validateUserConstraintOrRewriteReturn(
      StringRef declType, ast::CompoundStmt *body,
      ArrayRef<ast::Stmt *>::iterator bodyIt,
      ArrayRef<ast::Stmt *>::iterator bodyE,
      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);

  FailureOr<ast::CompoundStmt *>
  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                  bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name. If
  /// one has, emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
  /// constraints) may be used with the variable.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints,
                  bool allowInlineTypeConstraints,
                  bool allowNonCoreConstraints);

  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
  /// argument or result variable. The constraints for these variables do not
  /// allow inline type constraints, and only permit a single constraint.
  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *>
  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
                         OpResultTypeContext::Explicit);
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::ReturnStmt *> parseReturnStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to extract a callable from the given AST node. Returns nullptr on
  /// failure.
  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
  /// of results, defined as part of the signature.
  ast::Type
  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);

  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
  template <typename T>
  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
      ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Create a variable for an argument or result defined as part of the
  /// signature of a UserConstraintDecl/UserRewriteDecl.
  FailureOr<ast::VariableDecl *>
  createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                const ast::ConstraintRef &constraint);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType,
                              bool allowNonCoreConstraints = true);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType,
                                           bool allowNonCoreConstraints = true);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::CallExpr *>
  createCallExpr(SMRange loc, ast::Expr *parentExpr,
                 MutableArrayRef<ast::Expr *> arguments);
  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                      OpResultTypeContext resultTypeContext,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(SMRange loc, Optional<StringRef> name,
                            const ods::Operation *odsOp,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                         const ods::Operation *odsOp,
                                         MutableArrayRef<ast::Expr *> results);
  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
                                          const ods::Operation *odsOp);
  LogicalResult validateOperationOperandsOrResults(
      StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
      Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
      ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
      ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// The set of various code completion methods. Every completion method that
  /// returns `LogicalResult` returns `failure` to stop the parsing process
  /// after providing completion results. The signature completion methods
  /// (`codeComplete*Signature`) return nothing.

  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
  LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
                                           bool allowNonCoreConstraints,
                                           bool allowInlineTypeConstraints);
  LogicalResult codeCompleteDialectName();
  LogicalResult codeCompleteOperationName(StringRef dialectName);
  LogicalResult codeCompletePatternMetadata();
  LogicalResult codeCompleteIncludeFilename(StringRef curPath);

  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
  void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
                                              unsigned currentNumOperands);
  void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
                                             unsigned currentNumResults);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of `tokLoc`, and re-lex the current token
  /// from that position.
  void resetToken(SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error diagnostic at `loc` and return failure.
  LogicalResult emitError(SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error diagnostic at the current token and return failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error at `loc` along with a note at `noteLoc`, and return
  /// failure.
  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser. Must be declared before `curToken`, which is
  /// initialized from it in the constructor.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The current (innermost) decl scope; null until the first scope is pushed.
  ast::DeclScope *curDeclScope = nullptr;
  /// Allocator for the decl scopes created by `pushDeclScope()`.
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;
  ast::Type attrTy;

  /// A counter used when naming anonymous constraints and rewrites.
  unsigned anonymousDeclNameCounter = 0;

  /// The optional code completion context.
  CodeCompleteContext *codeCompleteContext;
};
} // namespace
542 
543 FailureOr<ast::Module *> Parser::parseModule() {
544   SMLoc moduleLoc = curToken.getStartLoc();
545   pushDeclScope();
546 
547   // Parse the top-level decls of the module.
548   SmallVector<ast::Decl *> decls;
549   if (failed(parseModuleBody(decls)))
550     return popDeclScope(), failure();
551 
552   popDeclScope();
553   return ast::Module::create(ctx, moduleLoc, decls);
554 }
555 
556 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
557   while (curToken.isNot(Token::eof)) {
558     if (curToken.is(Token::directive)) {
559       if (failed(parseDirective(decls)))
560         return failure();
561       continue;
562     }
563 
564     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
565     if (failed(decl))
566       return failure();
567     decls.push_back(*decl);
568   }
569   return success();
570 }
571 
572 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
573   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
574                                                  valueRangeTy);
575 }
576 
577 LogicalResult Parser::convertExpressionTo(
578     ast::Expr *&expr, ast::Type type,
579     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
580   ast::Type exprType = expr->getType();
581   if (exprType == type)
582     return success();
583 
584   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
585     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
586         expr->getLoc(), llvm::formatv("unable to convert expression of type "
587                                       "`{0}` to the expected type of "
588                                       "`{1}`",
589                                       exprType, type));
590     if (noteAttachFn)
591       noteAttachFn(*diag);
592     return diag;
593   };
594 
595   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
596     // Two operation types are compatible if they have the same name, or if the
597     // expected type is more general.
598     if (auto opType = type.dyn_cast<ast::OperationType>()) {
599       if (opType.getName())
600         return emitConvertError();
601       return success();
602     }
603 
604     // An operation can always convert to a ValueRange.
605     if (type == valueRangeTy) {
606       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
607                                                      valueRangeTy);
608       return success();
609     }
610 
611     // Allow conversion to a single value by constraining the result range.
612     if (type == valueTy) {
613       // If the operation is registered, we can verify if it can ever have a
614       // single result.
615       if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
616         if (odsOp->getResults().empty()) {
617           return emitConvertError()->attachNote(
618               llvm::formatv("see the definition of `{0}`, which was defined "
619                             "with zero results",
620                             odsOp->getName()),
621               odsOp->getLoc());
622         }
623 
624         unsigned numSingleResults = llvm::count_if(
625             odsOp->getResults(), [](const ods::OperandOrResult &result) {
626               return result.getVariableLengthKind() ==
627                      ods::VariableLengthKind::Single;
628             });
629         if (numSingleResults > 1) {
630           return emitConvertError()->attachNote(
631               llvm::formatv("see the definition of `{0}`, which was defined "
632                             "with at least {1} results",
633                             odsOp->getName(), numSingleResults),
634               odsOp->getLoc());
635         }
636       }
637 
638       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
639                                                      valueTy);
640       return success();
641     }
642     return emitConvertError();
643   }
644 
645   // FIXME: Decide how to allow/support converting a single result to multiple,
646   // and multiple to a single result. For now, we just allow Single->Range,
647   // but this isn't something really supported in the PDL dialect. We should
648   // figure out some way to support both.
649   if ((exprType == valueTy || exprType == valueRangeTy) &&
650       (type == valueTy || type == valueRangeTy))
651     return success();
652   if ((exprType == typeTy || exprType == typeRangeTy) &&
653       (type == typeTy || type == typeRangeTy))
654     return success();
655 
656   // Handle tuple types.
657   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
658     auto tupleType = type.dyn_cast<ast::TupleType>();
659     if (!tupleType || tupleType.size() != exprTupleType.size())
660       return emitConvertError();
661 
662     // Build a new tuple expression using each of the elements of the current
663     // tuple.
664     SmallVector<ast::Expr *> newExprs;
665     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
666       newExprs.push_back(ast::MemberAccessExpr::create(
667           ctx, expr->getLoc(), expr, llvm::to_string(i),
668           exprTupleType.getElementTypes()[i]));
669 
670       auto diagFn = [&](ast::Diagnostic &diag) {
671         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
672                                       i, exprTupleType));
673         if (noteAttachFn)
674           noteAttachFn(diag);
675       };
676       if (failed(convertExpressionTo(newExprs.back(),
677                                      tupleType.getElementTypes()[i], diagFn)))
678         return failure();
679     }
680     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
681                                   tupleType.getElementNames());
682     return success();
683   }
684 
685   return emitConvertError();
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Directives
690 
691 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
692   StringRef directive = curToken.getSpelling();
693   if (directive == "#include")
694     return parseInclude(decls);
695 
696   return emitError("unknown directive `" + directive + "`");
697 }
698 
699 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
700   SMRange loc = curToken.getLoc();
701   consumeToken(Token::directive);
702 
703   // Handle code completion of the include file path.
704   if (curToken.is(Token::code_complete_string))
705     return codeCompleteIncludeFilename(curToken.getStringValue());
706 
707   // Parse the file being included.
708   if (!curToken.isString())
709     return emitError(loc,
710                      "expected string file name after `include` directive");
711   SMRange fileLoc = curToken.getLoc();
712   std::string filenameStr = curToken.getStringValue();
713   StringRef filename = filenameStr;
714   consumeToken();
715 
716   // Check the type of include. If ending with `.pdll`, this is another pdl file
717   // to be parsed along with the current module.
718   if (filename.endswith(".pdll")) {
719     if (failed(lexer.pushInclude(filename, fileLoc)))
720       return emitError(fileLoc,
721                        "unable to open include file `" + filename + "`");
722 
723     // If we added the include successfully, parse it into the current module.
724     // Make sure to update to the next token after we finish parsing the nested
725     // file.
726     curToken = lexer.lexToken();
727     LogicalResult result = parseModuleBody(decls);
728     curToken = lexer.lexToken();
729     return result;
730   }
731 
732   // Otherwise, this must be a `.td` include.
733   if (filename.endswith(".td"))
734     return parseTdInclude(filename, fileLoc, decls);
735 
736   return emitError(fileLoc,
737                    "expected include filename to end with `.pdll` or `.td`");
738 }
739 
740 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
741                                      SmallVectorImpl<ast::Decl *> &decls) {
742   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
743 
744   // Use the source manager to open the file, but don't yet add it.
745   std::string includedFile;
746   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
747       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
748   if (!includeBuffer)
749     return emitError(fileLoc, "unable to open include file `" + filename + "`");
750 
751   // Setup the source manager for parsing the tablegen file.
752   llvm::SourceMgr tdSrcMgr;
753   tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
754   tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
755 
756   // This class provides a context argument for the llvm::SourceMgr diagnostic
757   // handler.
758   struct DiagHandlerContext {
759     Parser &parser;
760     StringRef filename;
761     llvm::SMRange loc;
762   } handlerContext{*this, filename, fileLoc};
763 
764   // Set the diagnostic handler for the tablegen source manager.
765   tdSrcMgr.setDiagHandler(
766       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
767         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
768         (void)ctx->parser.emitError(
769             ctx->loc,
770             llvm::formatv("error while processing include file `{0}`: {1}",
771                           ctx->filename, diag.getMessage()));
772       },
773       &handlerContext);
774 
775   // Parse the tablegen file.
776   llvm::RecordKeeper tdRecords;
777   if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
778     return failure();
779 
780   // Process the parsed records.
781   processTdIncludeRecords(tdRecords, decls);
782 
783   // After we are done processing, move all of the tablegen source buffers to
784   // the main parser source mgr. This allows for directly using source locations
785   // from the .td files without needing to remap them.
786   parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
787   return success();
788 }
789 
/// Convert the records of a parsed tablegen file into PDLL information.
///
/// * Every `Op` def not already present in the ODS context is inserted there,
///   together with its attributes, operands and results. Operations that were
///   already inserted are skipped entirely.
/// * Every non-anonymous `Attr` or `Type` def, and every non-anonymous
///   `OpInterface`, `AttrInterface` or `TypeInterface` def (excluding
///   subclasses of `DeclareInterfaceMethods`), whose name is not already
///   visible in the current decl scope becomes a native constraint decl.
///   That decl is appended to `decls`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given operand/result value: optional,
  // variadic, or a single value.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Records provide a single SMLoc. Widen it into a one-character SMRange for
  // the PDLL decls created below.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  // Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support result type
    // inference (i.e. it has the InferTypeOpInterface trait).
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) = odsContext.insertOperation(
        op.getOperationName(), op.getSummary(), op.getDescription(),
        op.getQualCppClassName(), supportsResultTypeInferrence,
        op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  // Attr constraints: named defs not already defined in the current scope
  // become native `Attr` constraints whose native type is the storage type.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      tblgen::Attribute constraint(def);
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              constraint, convertLocToRange(def->getLoc().front()), attrTy,
              constraint.getStorageType()));
    }
  }
  // Type constraints: named defs not already defined in the current scope
  // become native `Type` constraints whose native type is the C++ class name.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      tblgen::TypeConstraint constraint(def);
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              constraint, convertLocToRange(def->getLoc().front()), typeTy,
              constraint.getCPPClassName());
    }
  }
  // Interfaces: each Op/Attr/Type interface becomes a native constraint whose
  // body checks `llvm::isa<{cppNamespace}::{cppClassName}>(self)`. Interface
  // defs of any other kind are ignored.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    std::string cppClassName =
        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
                      def->getValueAsString("cppClassName"))
            .str();
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
                      cppClassName)
            .str();

    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy, cppClassName));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy, cppClassName));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy, cppClassName));
    }
  }
}
900 
901 template <typename ConstraintT>
902 ast::Decl *
903 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
904                                           SMRange loc, ast::Type type,
905                                           StringRef nativeType) {
906   // Build the single input parameter.
907   ast::DeclScope *argScope = pushDeclScope();
908   auto *paramVar = ast::VariableDecl::create(
909       ctx, ast::Name::create(ctx, "self", loc), type,
910       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
911   argScope->add(paramVar);
912   popDeclScope();
913 
914   // Build the native constraint.
915   auto *constraintDecl = ast::UserConstraintDecl::createNative(
916       ctx, ast::Name::create(ctx, name, loc), paramVar,
917       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
918   curDeclScope->add(constraintDecl);
919   return constraintDecl;
920 }
921 
922 template <typename ConstraintT>
923 ast::Decl *
924 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
925                                           SMRange loc, ast::Type type,
926                                           StringRef nativeType) {
927   // Format the condition template.
928   tblgen::FmtContext fmtContext;
929   fmtContext.withSelf("self");
930   std::string codeBlock = tblgen::tgfmt(
931       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
932       &fmtContext);
933 
934   return createODSNativePDLLConstraintDecl<ConstraintT>(
935       constraint.getUniqueDefName(), codeBlock, loc, type, nativeType);
936 }
937 
938 //===----------------------------------------------------------------------===//
939 // Decls
940 
941 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
942   FailureOr<ast::Decl *> decl;
943   switch (curToken.getKind()) {
944   case Token::kw_Constraint:
945     decl = parseUserConstraintDecl();
946     break;
947   case Token::kw_Pattern:
948     decl = parsePatternDecl();
949     break;
950   case Token::kw_Rewrite:
951     decl = parseUserRewriteDecl();
952     break;
953   default:
954     return emitError("expected top-level declaration, such as a `Pattern`");
955   }
956   if (failed(decl))
957     return failure();
958 
959   // If the decl has a name, add it to the current scope.
960   if (const ast::Name *name = (*decl)->getName()) {
961     if (failed(checkDefineNamedDecl(*name)))
962       return failure();
963     curDeclScope->add(*decl);
964   }
965   return decl;
966 }
967 
968 FailureOr<ast::NamedAttributeDecl *>
969 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
970   // Check for name code completion.
971   if (curToken.is(Token::code_complete))
972     return codeCompleteAttributeName(parentOpName);
973 
974   std::string attrNameStr;
975   if (curToken.isString())
976     attrNameStr = curToken.getStringValue();
977   else if (curToken.is(Token::identifier) || curToken.isKeyword())
978     attrNameStr = curToken.getSpelling().str();
979   else
980     return emitError("expected identifier or string attribute name");
981   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
982   consumeToken();
983 
984   // Check for a value of the attribute.
985   ast::Expr *attrValue = nullptr;
986   if (consumeIf(Token::equal)) {
987     FailureOr<ast::Expr *> attrExpr = parseExpr();
988     if (failed(attrExpr))
989       return failure();
990     attrValue = *attrExpr;
991   } else {
992     // If there isn't a concrete value, create an expression representing a
993     // UnitAttr.
994     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
995   }
996 
997   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
998 }
999 
1000 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1001     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1002     bool expectTerminalSemicolon) {
1003   consumeToken(Token::equal_arrow);
1004 
1005   // Parse the single statement of the lambda body.
1006   SMLoc bodyStartLoc = curToken.getStartLoc();
1007   pushDeclScope();
1008   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1009   bool failedToParse =
1010       failed(singleStatement) || failed(processStatementFn(*singleStatement));
1011   popDeclScope();
1012   if (failedToParse)
1013     return failure();
1014 
1015   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1016   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1017 }
1018 
1019 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1020   // Ensure that the argument is named.
1021   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1022     return emitError("expected identifier argument name");
1023 
1024   // Parse the argument similarly to a normal variable.
1025   StringRef name = curToken.getSpelling();
1026   SMRange nameLoc = curToken.getLoc();
1027   consumeToken();
1028 
1029   if (failed(
1030           parseToken(Token::colon, "expected `:` before argument constraint")))
1031     return failure();
1032 
1033   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1034   if (failed(cst))
1035     return failure();
1036 
1037   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1038 }
1039 
1040 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1041   // Check to see if this result is named.
1042   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1043     // Check to see if this name actually refers to a Constraint.
1044     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1045     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1046       // If yes, and this is a Rewrite, give a nice error message as non-Core
1047       // constraints are not supported on Rewrite results.
1048       if (parserContext == ParserContext::Rewrite) {
1049         return emitError(
1050             "`Rewrite` results are only permitted to use core constraints, "
1051             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1052       }
1053 
1054       // Otherwise, parse this as an unnamed result variable.
1055     } else {
1056       // If it wasn't a constraint, parse the result similarly to a variable. If
1057       // there is already an existing decl, we will emit an error when defining
1058       // this variable later.
1059       StringRef name = curToken.getSpelling();
1060       SMRange nameLoc = curToken.getLoc();
1061       consumeToken();
1062 
1063       if (failed(parseToken(Token::colon,
1064                             "expected `:` before result constraint")))
1065         return failure();
1066 
1067       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1068       if (failed(cst))
1069         return failure();
1070 
1071       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1072     }
1073   }
1074 
1075   // If it isn't named, we parse the constraint directly and create an unnamed
1076   // result variable.
1077   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1078   if (failed(cst))
1079     return failure();
1080 
1081   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1082 }
1083 
1084 FailureOr<ast::UserConstraintDecl *>
1085 Parser::parseUserConstraintDecl(bool isInline) {
1086   // Constraints and rewrites have very similar formats, dispatch to a shared
1087   // interface for parsing.
1088   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1089       [&](auto &&...args) {
1090         return this->parseUserPDLLConstraintDecl(args...);
1091       },
1092       ParserContext::Constraint, "constraint", isInline);
1093 }
1094 
1095 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1096   FailureOr<ast::UserConstraintDecl *> decl =
1097       parseUserConstraintDecl(/*isInline=*/true);
1098   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1099     return failure();
1100 
1101   curDeclScope->add(*decl);
1102   return decl;
1103 }
1104 
1105 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1106     const ast::Name &name, bool isInline,
1107     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1108     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1109   // Push the argument scope back onto the list, so that the body can
1110   // reference arguments.
1111   pushDeclScope(argumentScope);
1112 
1113   // Parse the body of the constraint. The body is either defined as a compound
1114   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1115   ast::CompoundStmt *body;
1116   if (curToken.is(Token::equal_arrow)) {
1117     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1118         [&](ast::Stmt *&stmt) -> LogicalResult {
1119           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1120           if (!stmtExpr) {
1121             return emitError(stmt->getLoc(),
1122                              "expected `Constraint` lambda body to contain a "
1123                              "single expression");
1124           }
1125           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1126           return success();
1127         },
1128         /*expectTerminalSemicolon=*/!isInline);
1129     if (failed(bodyResult))
1130       return failure();
1131     body = *bodyResult;
1132   } else {
1133     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1134     if (failed(bodyResult))
1135       return failure();
1136     body = *bodyResult;
1137 
1138     // Verify the structure of the body.
1139     auto bodyIt = body->begin(), bodyE = body->end();
1140     for (; bodyIt != bodyE; ++bodyIt)
1141       if (isa<ast::ReturnStmt>(*bodyIt))
1142         break;
1143     if (failed(validateUserConstraintOrRewriteReturn(
1144             "Constraint", body, bodyIt, bodyE, results, resultType)))
1145       return failure();
1146   }
1147   popDeclScope();
1148 
1149   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1150       name, arguments, results, resultType, body);
1151 }
1152 
1153 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1154   // Constraints and rewrites have very similar formats, dispatch to a shared
1155   // interface for parsing.
1156   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1157       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1158       ParserContext::Rewrite, "rewrite", isInline);
1159 }
1160 
1161 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1162   FailureOr<ast::UserRewriteDecl *> decl =
1163       parseUserRewriteDecl(/*isInline=*/true);
1164   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1165     return failure();
1166 
1167   curDeclScope->add(*decl);
1168   return decl;
1169 }
1170 
1171 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1172     const ast::Name &name, bool isInline,
1173     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1174     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1175   // Push the argument scope back onto the list, so that the body can
1176   // reference arguments.
1177   curDeclScope = argumentScope;
1178   ast::CompoundStmt *body;
1179   if (curToken.is(Token::equal_arrow)) {
1180     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1181         [&](ast::Stmt *&statement) -> LogicalResult {
1182           if (isa<ast::OpRewriteStmt>(statement))
1183             return success();
1184 
1185           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1186           if (!statementExpr) {
1187             return emitError(
1188                 statement->getLoc(),
1189                 "expected `Rewrite` lambda body to contain a single expression "
1190                 "or an operation rewrite statement; such as `erase`, "
1191                 "`replace`, or `rewrite`");
1192           }
1193           statement =
1194               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1195           return success();
1196         },
1197         /*expectTerminalSemicolon=*/!isInline);
1198     if (failed(bodyResult))
1199       return failure();
1200     body = *bodyResult;
1201   } else {
1202     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1203     if (failed(bodyResult))
1204       return failure();
1205     body = *bodyResult;
1206   }
1207   popDeclScope();
1208 
1209   // Verify the structure of the body.
1210   auto bodyIt = body->begin(), bodyE = body->end();
1211   for (; bodyIt != bodyE; ++bodyIt)
1212     if (isa<ast::ReturnStmt>(*bodyIt))
1213       break;
1214   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1215                                                    bodyE, results, resultType)))
1216     return failure();
1217   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1218       name, arguments, results, resultType, body);
1219 }
1220 
1221 template <typename T, typename ParseUserPDLLDeclFnT>
1222 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1223     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1224     StringRef anonymousNamePrefix, bool isInline) {
1225   SMRange loc = curToken.getLoc();
1226   consumeToken();
1227   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1228 
1229   // Parse the name of the decl.
1230   const ast::Name *name = nullptr;
1231   if (curToken.isNot(Token::identifier)) {
1232     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1233     // in C++, so being unnamed is fine.
1234     if (!isInline)
1235       return emitError("expected identifier name");
1236 
1237     // Create a unique anonymous name to use, as the name for this decl is not
1238     // important.
1239     std::string anonName =
1240         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1241                       anonymousDeclNameCounter++)
1242             .str();
1243     name = &ast::Name::create(ctx, anonName, loc);
1244   } else {
1245     // If a name was provided, we can use it directly.
1246     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1247     consumeToken(Token::identifier);
1248   }
1249 
1250   // Parse the functional signature of the decl.
1251   SmallVector<ast::VariableDecl *> arguments, results;
1252   ast::DeclScope *argumentScope;
1253   ast::Type resultType;
1254   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1255                                                    argumentScope, resultType)))
1256     return failure();
1257 
1258   // Check to see which type of constraint this is. If the constraint contains a
1259   // compound body, this is a PDLL decl.
1260   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1261     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1262                            resultType);
1263 
1264   // Otherwise, this is a native decl.
1265   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1266                                                    results, resultType);
1267 }
1268 
1269 template <typename T>
1270 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1271     const ast::Name &name, bool isInline,
1272     ArrayRef<ast::VariableDecl *> arguments,
1273     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1274   // If followed by a string, the native code body has also been specified.
1275   std::string codeStrStorage;
1276   Optional<StringRef> optCodeStr;
1277   if (curToken.isString()) {
1278     codeStrStorage = curToken.getStringValue();
1279     optCodeStr = codeStrStorage;
1280     consumeToken();
1281   } else if (isInline) {
1282     return emitError(name.getLoc(),
1283                      "external declarations must be declared in global scope");
1284   } else if (curToken.is(Token::error)) {
1285     return failure();
1286   }
1287   if (failed(parseToken(Token::semicolon,
1288                         "expected `;` after native declaration")))
1289     return failure();
1290   // TODO: PDL should be able to support constraint results in certain
1291   // situations, we should revise this.
1292   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1293     return emitError(
1294         "native Constraints currently do not support returning results");
1295   }
1296   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1297 }
1298 
1299 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1300     SmallVectorImpl<ast::VariableDecl *> &arguments,
1301     SmallVectorImpl<ast::VariableDecl *> &results,
1302     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1303   // Parse the argument list of the decl.
1304   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1305     return failure();
1306 
1307   argumentScope = pushDeclScope();
1308   if (curToken.isNot(Token::r_paren)) {
1309     do {
1310       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1311       if (failed(argument))
1312         return failure();
1313       arguments.emplace_back(*argument);
1314     } while (consumeIf(Token::comma));
1315   }
1316   popDeclScope();
1317   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1318     return failure();
1319 
1320   // Parse the results of the decl.
1321   pushDeclScope();
1322   if (consumeIf(Token::arrow)) {
1323     auto parseResultFn = [&]() -> LogicalResult {
1324       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1325       if (failed(result))
1326         return failure();
1327       results.emplace_back(*result);
1328       return success();
1329     };
1330 
1331     // Check for a list of results.
1332     if (consumeIf(Token::l_paren)) {
1333       do {
1334         if (failed(parseResultFn()))
1335           return failure();
1336       } while (consumeIf(Token::comma));
1337       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1338         return failure();
1339 
1340       // Otherwise, there is only one result.
1341     } else if (failed(parseResultFn())) {
1342       return failure();
1343     }
1344   }
1345   popDeclScope();
1346 
1347   // Compute the result type of the decl.
1348   resultType = createUserConstraintRewriteResultType(results);
1349 
1350   // Verify that results are only named if there are more than one.
1351   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1352     return emitError(
1353         results.front()->getLoc(),
1354         "cannot create a single-element tuple with an element label");
1355   }
1356   return success();
1357 }
1358 
1359 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1360     StringRef declType, ast::CompoundStmt *body,
1361     ArrayRef<ast::Stmt *>::iterator bodyIt,
1362     ArrayRef<ast::Stmt *>::iterator bodyE,
1363     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1364   // Handle if a `return` was provided.
1365   if (bodyIt != bodyE) {
1366     // Emit an error if we have trailing statements after the return.
1367     if (std::next(bodyIt) != bodyE) {
1368       return emitError(
1369           (*std::next(bodyIt))->getLoc(),
1370           llvm::formatv("`return` terminated the `{0}` body, but found "
1371                         "trailing statements afterwards",
1372                         declType));
1373     }
1374 
1375     // Otherwise if a return wasn't provided, check that no results are
1376     // expected.
1377   } else if (!results.empty()) {
1378     return emitError(
1379         {body->getLoc().End, body->getLoc().End},
1380         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1381                       declType, resultType));
1382   }
1383   return success();
1384 }
1385 
1386 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1387   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1388     if (isa<ast::OpRewriteStmt>(statement))
1389       return success();
1390     return emitError(
1391         statement->getLoc(),
1392         "expected Pattern lambda body to contain a single operation "
1393         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1394   });
1395 }
1396 
1397 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1398   SMRange loc = curToken.getLoc();
1399   consumeToken(Token::kw_Pattern);
1400   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1401                                               ParserContext::PatternMatch);
1402 
1403   // Check for an optional identifier for the pattern name.
1404   const ast::Name *name = nullptr;
1405   if (curToken.is(Token::identifier)) {
1406     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1407     consumeToken(Token::identifier);
1408   }
1409 
1410   // Parse any pattern metadata.
1411   ParsedPatternMetadata metadata;
1412   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1413     return failure();
1414 
1415   // Parse the pattern body.
1416   ast::CompoundStmt *body;
1417 
1418   // Handle a lambda body.
1419   if (curToken.is(Token::equal_arrow)) {
1420     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1421     if (failed(bodyResult))
1422       return failure();
1423     body = *bodyResult;
1424   } else {
1425     if (curToken.isNot(Token::l_brace))
1426       return emitError("expected `{` or `=>` to start pattern body");
1427     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1428     if (failed(bodyResult))
1429       return failure();
1430     body = *bodyResult;
1431 
1432     // Verify the body of the pattern.
1433     auto bodyIt = body->begin(), bodyE = body->end();
1434     for (; bodyIt != bodyE; ++bodyIt) {
1435       if (isa<ast::ReturnStmt>(*bodyIt)) {
1436         return emitError((*bodyIt)->getLoc(),
1437                          "`return` statements are only permitted within a "
1438                          "`Constraint` or `Rewrite` body");
1439       }
1440       // Break when we've found the rewrite statement.
1441       if (isa<ast::OpRewriteStmt>(*bodyIt))
1442         break;
1443     }
1444     if (bodyIt == bodyE) {
1445       return emitError(loc,
1446                        "expected Pattern body to terminate with an operation "
1447                        "rewrite statement, such as `erase`");
1448     }
1449     if (std::next(bodyIt) != bodyE) {
1450       return emitError((*std::next(bodyIt))->getLoc(),
1451                        "Pattern body was terminated by an operation "
1452                        "rewrite statement, but found trailing statements");
1453     }
1454   }
1455 
1456   return createPatternDecl(loc, name, metadata, body);
1457 }
1458 
1459 LogicalResult
1460 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1461   Optional<SMRange> benefitLoc;
1462   Optional<SMRange> hasBoundedRecursionLoc;
1463 
1464   do {
1465     // Handle metadata code completion.
1466     if (curToken.is(Token::code_complete))
1467       return codeCompletePatternMetadata();
1468 
1469     if (curToken.isNot(Token::identifier))
1470       return emitError("expected pattern metadata identifier");
1471     StringRef metadataStr = curToken.getSpelling();
1472     SMRange metadataLoc = curToken.getLoc();
1473     consumeToken(Token::identifier);
1474 
1475     // Parse the benefit metadata: benefit(<integer-value>)
1476     if (metadataStr == "benefit") {
1477       if (benefitLoc) {
1478         return emitErrorAndNote(metadataLoc,
1479                                 "pattern benefit has already been specified",
1480                                 *benefitLoc, "see previous definition here");
1481       }
1482       if (failed(parseToken(Token::l_paren,
1483                             "expected `(` before pattern benefit")))
1484         return failure();
1485 
1486       uint16_t benefitValue = 0;
1487       if (curToken.isNot(Token::integer))
1488         return emitError("expected integral pattern benefit");
1489       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1490         return emitError(
1491             "expected pattern benefit to fit within a 16-bit integer");
1492       consumeToken(Token::integer);
1493 
1494       metadata.benefit = benefitValue;
1495       benefitLoc = metadataLoc;
1496 
1497       if (failed(
1498               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1499         return failure();
1500       continue;
1501     }
1502 
1503     // Parse the bounded recursion metadata: recursion
1504     if (metadataStr == "recursion") {
1505       if (hasBoundedRecursionLoc) {
1506         return emitErrorAndNote(
1507             metadataLoc,
1508             "pattern recursion metadata has already been specified",
1509             *hasBoundedRecursionLoc, "see previous definition here");
1510       }
1511       metadata.hasBoundedRecursion = true;
1512       hasBoundedRecursionLoc = metadataLoc;
1513       continue;
1514     }
1515 
1516     return emitError(metadataLoc, "unknown pattern metadata");
1517   } while (consumeIf(Token::comma));
1518 
1519   return success();
1520 }
1521 
1522 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1523   consumeToken(Token::less);
1524 
1525   FailureOr<ast::Expr *> typeExpr = parseExpr();
1526   if (failed(typeExpr) ||
1527       failed(parseToken(Token::greater,
1528                         "expected `>` after variable type constraint")))
1529     return failure();
1530   return typeExpr;
1531 }
1532 
1533 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1534   assert(curDeclScope && "defining decl outside of a decl scope");
1535   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1536     return emitErrorAndNote(
1537         name.getLoc(), "`" + name.getName() + "` has already been defined",
1538         lastDecl->getName()->getLoc(), "see previous definition here");
1539   }
1540   return success();
1541 }
1542 
1543 FailureOr<ast::VariableDecl *>
1544 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1545                            ast::Expr *initExpr,
1546                            ArrayRef<ast::ConstraintRef> constraints) {
1547   assert(curDeclScope && "defining variable outside of decl scope");
1548   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1549 
1550   // If the name of the variable indicates a special variable, we don't add it
1551   // to the scope. This variable is local to the definition point.
1552   if (name.empty() || name == "_") {
1553     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1554                                      constraints);
1555   }
1556   if (failed(checkDefineNamedDecl(nameDecl)))
1557     return failure();
1558 
1559   auto *varDecl =
1560       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1561   curDeclScope->add(varDecl);
1562   return varDecl;
1563 }
1564 
1565 FailureOr<ast::VariableDecl *>
1566 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1567                            ArrayRef<ast::ConstraintRef> constraints) {
1568   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1569                             constraints);
1570 }
1571 
1572 LogicalResult Parser::parseVariableDeclConstraintList(
1573     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1574   Optional<SMRange> typeConstraint;
1575   auto parseSingleConstraint = [&] {
1576     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1577         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1578         /*allowNonCoreConstraints=*/true);
1579     if (failed(constraint))
1580       return failure();
1581     constraints.push_back(*constraint);
1582     return success();
1583   };
1584 
1585   // Check to see if this is a single constraint, or a list.
1586   if (!consumeIf(Token::l_square))
1587     return parseSingleConstraint();
1588 
1589   do {
1590     if (failed(parseSingleConstraint()))
1591       return failure();
1592   } while (consumeIf(Token::comma));
1593   return parseToken(Token::r_square, "expected `]` after constraint list");
1594 }
1595 
/// Parse a single constraint reference at the current token. The accepted
/// forms are:
///   * the core constraints `Attr`, `Op`, `Type`, `TypeRange`, `Value`, and
///     `ValueRange`. `Attr`, `Value`, and `ValueRange` may carry a `<expr>`
///     type constraint, and `Op` may carry a wrapped `<dialect.op>` name;
///   * an inline `Constraint` declaration;
///   * an identifier naming a previously defined constraint.
///
/// `typeConstraint` holds the location of an inline type constraint already
/// applied to this variable. A second one is diagnosed, and this function sets
/// it when it parses one. `existingConstraints` is used only to infer the
/// variable type for code completion. `allowInlineTypeConstraints` controls
/// whether `<expr>` type constraints are accepted. `allowNonCoreConstraints`
/// is forwarded to constraint validation and code completion.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse a `<expr>` type constraint into `typeExpr`. Diagnose it if inline
  // type constraints are disallowed here or the type was already constrained.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Record where the type was constrained so a later duplicate can point
    // back here.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a declaration that is not a constraint.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  // Any other token cannot start a constraint.
  return emitError(loc, "expected identifier constraint");
}
1718 
1719 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1720   // Constraint arguments may apply more complex constraints via the arguments.
1721   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1722 
1723   Optional<SMRange> typeConstraint;
1724   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1725                          /*allowInlineTypeConstraints=*/false,
1726                          allowNonCoreConstraints);
1727 }
1728 
1729 //===----------------------------------------------------------------------===//
1730 // Exprs
1731 
1732 FailureOr<ast::Expr *> Parser::parseExpr() {
1733   if (curToken.is(Token::underscore))
1734     return parseUnderscoreExpr();
1735 
1736   // Parse the LHS expression.
1737   FailureOr<ast::Expr *> lhsExpr;
1738   switch (curToken.getKind()) {
1739   case Token::kw_attr:
1740     lhsExpr = parseAttributeExpr();
1741     break;
1742   case Token::kw_Constraint:
1743     lhsExpr = parseInlineConstraintLambdaExpr();
1744     break;
1745   case Token::identifier:
1746     lhsExpr = parseIdentifierExpr();
1747     break;
1748   case Token::kw_op:
1749     lhsExpr = parseOperationExpr();
1750     break;
1751   case Token::kw_Rewrite:
1752     lhsExpr = parseInlineRewriteLambdaExpr();
1753     break;
1754   case Token::kw_type:
1755     lhsExpr = parseTypeExpr();
1756     break;
1757   case Token::l_paren:
1758     lhsExpr = parseTupleExpr();
1759     break;
1760   default:
1761     return emitError("expected expression");
1762   }
1763   if (failed(lhsExpr))
1764     return failure();
1765 
1766   // Check for an operator expression.
1767   while (true) {
1768     switch (curToken.getKind()) {
1769     case Token::dot:
1770       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1771       break;
1772     case Token::l_paren:
1773       lhsExpr = parseCallExpr(*lhsExpr);
1774       break;
1775     default:
1776       return lhsExpr;
1777     }
1778     if (failed(lhsExpr))
1779       return failure();
1780   }
1781 }
1782 
1783 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1784   SMRange loc = curToken.getLoc();
1785   consumeToken(Token::kw_attr);
1786 
1787   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1788   // identifier.
1789   if (!consumeIf(Token::less)) {
1790     resetToken(loc);
1791     return parseIdentifierExpr();
1792   }
1793 
1794   if (!curToken.isString())
1795     return emitError("expected string literal containing MLIR attribute");
1796   std::string attrExpr = curToken.getStringValue();
1797   consumeToken();
1798 
1799   loc.End = curToken.getEndLoc();
1800   if (failed(
1801           parseToken(Token::greater, "expected `>` after attribute literal")))
1802     return failure();
1803   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1804 }
1805 
1806 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1807   consumeToken(Token::l_paren);
1808 
1809   // Parse the arguments of the call.
1810   SmallVector<ast::Expr *> arguments;
1811   if (curToken.isNot(Token::r_paren)) {
1812     do {
1813       // Handle code completion for the call arguments.
1814       if (curToken.is(Token::code_complete)) {
1815         codeCompleteCallSignature(parentExpr, arguments.size());
1816         return failure();
1817       }
1818 
1819       FailureOr<ast::Expr *> argument = parseExpr();
1820       if (failed(argument))
1821         return failure();
1822       arguments.push_back(*argument);
1823     } while (consumeIf(Token::comma));
1824   }
1825 
1826   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1827   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1828     return failure();
1829 
1830   return createCallExpr(loc, parentExpr, arguments);
1831 }
1832 
1833 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1834   ast::Decl *decl = curDeclScope->lookup(name);
1835   if (!decl)
1836     return emitError(loc, "undefined reference to `" + name + "`");
1837 
1838   return createDeclRefExpr(loc, decl);
1839 }
1840 
1841 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1842   StringRef name = curToken.getSpelling();
1843   SMRange nameLoc = curToken.getLoc();
1844   consumeToken();
1845 
1846   // Check to see if this is a decl ref expression that defines a variable
1847   // inline.
1848   if (consumeIf(Token::colon)) {
1849     SmallVector<ast::ConstraintRef> constraints;
1850     if (failed(parseVariableDeclConstraintList(constraints)))
1851       return failure();
1852     ast::Type type;
1853     if (failed(validateVariableConstraints(constraints, type)))
1854       return failure();
1855     return createInlineVariableExpr(type, name, nameLoc, constraints);
1856   }
1857 
1858   return parseDeclRefExpr(name, nameLoc);
1859 }
1860 
1861 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1862   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1863   if (failed(decl))
1864     return failure();
1865 
1866   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1867                                   ast::ConstraintType::get(ctx));
1868 }
1869 
1870 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1871   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1872   if (failed(decl))
1873     return failure();
1874 
1875   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1876                                   ast::RewriteType::get(ctx));
1877 }
1878 
1879 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1880   SMRange dotLoc = curToken.getLoc();
1881   consumeToken(Token::dot);
1882 
1883   // Check for code completion of the member name.
1884   if (curToken.is(Token::code_complete))
1885     return codeCompleteMemberAccess(parentExpr);
1886 
1887   // Parse the member name.
1888   Token memberNameTok = curToken;
1889   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1890       !memberNameTok.isKeyword())
1891     return emitError(dotLoc, "expected identifier or numeric member name");
1892   StringRef memberName = memberNameTok.getSpelling();
1893   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1894   consumeToken();
1895 
1896   return createMemberAccessExpr(parentExpr, memberName, loc);
1897 }
1898 
/// Parse an operation name of the form `dialect.op(.more)*`. Identifiers and
/// keywords are both accepted as name components. If the current token cannot
/// start a name and `allowEmptyName` is set, an OpNameDecl with an empty range
/// is returned. Code completion tokens in the dialect or operation position
/// trigger the corresponding completion.
///
/// NOTE(review): the full name is assembled by extending a StringRef that
/// starts at the dialect token by the length of each subsequent token. This
/// assumes the `.` separator and the following tokens are contiguous in the
/// source buffer, with no intervening whitespace or comments. Confirm that the
/// grammar and lexer guarantee this.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of an no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(Token::code_complete))
    return codeCompleteOperationName(name);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Grow `name` to cover the `.` that followed the dialect namespace.
  name = StringRef(name.data(), name.size() + 1);
  // Grow `name` over each following identifier, keyword, or `.` token, and
  // extend the location to the end of the last one consumed.
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
1935 
1936 FailureOr<ast::OpNameDecl *>
1937 Parser::parseWrappedOperationName(bool allowEmptyName) {
1938   if (!consumeIf(Token::less))
1939     return ast::OpNameDecl::create(ctx, SMRange());
1940 
1941   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1942   if (failed(opNameDecl))
1943     return failure();
1944 
1945   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1946     return failure();
1947   return opNameDecl;
1948 }
1949 
/// Parse an operation expression:
///   `op` `<` name? `>` (`(` operands? `)`)? (`{` attrs `}`)?
///       (`->` `(` result-types? `)`)?
/// If `op` is not followed by `<`, the keyword is parsed as an identifier
/// expression instead. The name may be omitted outside of a Rewrite context.
///
/// Outside of a Rewrite context, an omitted operand list or result list is
/// modeled as an implicit, unconstrained range variable. This keeps "zero
/// operands/results" distinct from "unconstrained".
///
/// `inputResultTypeContext` is the initial result type context. It becomes
/// `Explicit` when a result type list is written. When the list is omitted in
/// a Rewrite context, an incoming `Explicit` is switched to `Interface`.
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();
  Optional<StringRef> opName = (*opNameDecl)->getName();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The variable is named `_`, so
  // defineVariableDecl does not add it to the decl scope or run the
  // redefinition check on it.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an inplace unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, operands.size());
        return failure();
      }

      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes. If the braces are present, the
  // list must contain at least one entry.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(opName);
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;

  // Check for an explicit list of result types.
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();

    // If result types are provided, initially assume that the operation does
    // not rely on type inferrence. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;

    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, resultTypes.size());
          return failure();
        }

        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));

      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an inplace unconstrained result range corresponding to all of the results
    // of the operation. This avoids treating zero results the same way as
    // "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead.
    resultTypeContext = OpResultTypeContext::Interface;
  }

  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
                             attributes, resultTypes);
}
2080 
2081 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2082   SMRange loc = curToken.getLoc();
2083   consumeToken(Token::l_paren);
2084 
2085   DenseMap<StringRef, SMRange> usedNames;
2086   SmallVector<StringRef> elementNames;
2087   SmallVector<ast::Expr *> elements;
2088   if (curToken.isNot(Token::r_paren)) {
2089     do {
2090       // Check for the optional element name assignment before the value.
2091       StringRef elementName;
2092       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2093         Token elementNameTok = curToken;
2094         consumeToken();
2095 
2096         // The element name is only present if followed by an `=`.
2097         if (consumeIf(Token::equal)) {
2098           elementName = elementNameTok.getSpelling();
2099 
2100           // Check to see if this name is already used.
2101           auto elementNameIt =
2102               usedNames.try_emplace(elementName, elementNameTok.getLoc());
2103           if (!elementNameIt.second) {
2104             return emitErrorAndNote(
2105                 elementNameTok.getLoc(),
2106                 llvm::formatv("duplicate tuple element label `{0}`",
2107                               elementName),
2108                 elementNameIt.first->getSecond(),
2109                 "see previous label use here");
2110           }
2111         } else {
2112           // Otherwise, we treat this as part of an expression so reset the
2113           // lexer.
2114           resetToken(elementNameTok.getLoc());
2115         }
2116       }
2117       elementNames.push_back(elementName);
2118 
2119       // Parse the tuple element value.
2120       FailureOr<ast::Expr *> element = parseExpr();
2121       if (failed(element))
2122         return failure();
2123       elements.push_back(*element);
2124     } while (consumeIf(Token::comma));
2125   }
2126   loc.End = curToken.getEndLoc();
2127   if (failed(
2128           parseToken(Token::r_paren, "expected `)` after tuple element list")))
2129     return failure();
2130   return createTupleExpr(loc, elements, elementNames);
2131 }
2132 
2133 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2134   SMRange loc = curToken.getLoc();
2135   consumeToken(Token::kw_type);
2136 
2137   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2138   // identifier.
2139   if (!consumeIf(Token::less)) {
2140     resetToken(loc);
2141     return parseIdentifierExpr();
2142   }
2143 
2144   if (!curToken.isString())
2145     return emitError("expected string literal containing MLIR type");
2146   std::string attrExpr = curToken.getStringValue();
2147   consumeToken();
2148 
2149   loc.End = curToken.getEndLoc();
2150   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2151     return failure();
2152   return ast::TypeExpr::create(ctx, loc, attrExpr);
2153 }
2154 
2155 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2156   StringRef name = curToken.getSpelling();
2157   SMRange nameLoc = curToken.getLoc();
2158   consumeToken(Token::underscore);
2159 
2160   // Underscore expressions require a constraint list.
2161   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2162     return failure();
2163 
2164   // Parse the constraints for the expression.
2165   SmallVector<ast::ConstraintRef> constraints;
2166   if (failed(parseVariableDeclConstraintList(constraints)))
2167     return failure();
2168 
2169   ast::Type type;
2170   if (failed(validateVariableConstraints(constraints, type)))
2171     return failure();
2172   return createInlineVariableExpr(type, name, nameLoc, constraints);
2173 }
2174 
2175 //===----------------------------------------------------------------------===//
2176 // Stmts
2177 
2178 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2179   FailureOr<ast::Stmt *> stmt;
2180   switch (curToken.getKind()) {
2181   case Token::kw_erase:
2182     stmt = parseEraseStmt();
2183     break;
2184   case Token::kw_let:
2185     stmt = parseLetStmt();
2186     break;
2187   case Token::kw_replace:
2188     stmt = parseReplaceStmt();
2189     break;
2190   case Token::kw_return:
2191     stmt = parseReturnStmt();
2192     break;
2193   case Token::kw_rewrite:
2194     stmt = parseRewriteStmt();
2195     break;
2196   default:
2197     stmt = parseExpr();
2198     break;
2199   }
2200   if (failed(stmt) ||
2201       (expectTerminalSemicolon &&
2202        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2203     return failure();
2204   return stmt;
2205 }
2206 
2207 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2208   SMLoc startLoc = curToken.getStartLoc();
2209   consumeToken(Token::l_brace);
2210 
2211   // Push a new block scope and parse any nested statements.
2212   pushDeclScope();
2213   SmallVector<ast::Stmt *> statements;
2214   while (curToken.isNot(Token::r_brace)) {
2215     FailureOr<ast::Stmt *> statement = parseStmt();
2216     if (failed(statement))
2217       return popDeclScope(), failure();
2218     statements.push_back(*statement);
2219   }
2220   popDeclScope();
2221 
2222   // Consume the end brace.
2223   SMRange location(startLoc, curToken.getEndLoc());
2224   consumeToken(Token::r_brace);
2225 
2226   return ast::CompoundStmt::create(ctx, location, statements);
2227 }
2228 
2229 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2230   if (parserContext == ParserContext::Constraint)
2231     return emitError("`erase` cannot be used within a Constraint");
2232   SMRange loc = curToken.getLoc();
2233   consumeToken(Token::kw_erase);
2234 
2235   // Parse the root operation expression.
2236   FailureOr<ast::Expr *> rootOp = parseExpr();
2237   if (failed(rootOp))
2238     return failure();
2239 
2240   return createEraseStmt(loc, *rootOp);
2241 }
2242 
2243 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2244   SMRange loc = curToken.getLoc();
2245   consumeToken(Token::kw_let);
2246 
2247   // Parse the name of the new variable.
2248   SMRange varLoc = curToken.getLoc();
2249   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2250     // `_` is a reserved variable name.
2251     if (curToken.is(Token::underscore)) {
2252       return emitError(varLoc,
2253                        "`_` may only be used to define \"inline\" variables");
2254     }
2255     return emitError(varLoc,
2256                      "expected identifier after `let` to name a new variable");
2257   }
2258   StringRef varName = curToken.getSpelling();
2259   consumeToken();
2260 
2261   // Parse the optional set of constraints.
2262   SmallVector<ast::ConstraintRef> constraints;
2263   if (consumeIf(Token::colon) &&
2264       failed(parseVariableDeclConstraintList(constraints)))
2265     return failure();
2266 
2267   // Parse the optional initializer expression.
2268   ast::Expr *initializer = nullptr;
2269   if (consumeIf(Token::equal)) {
2270     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2271     if (failed(initOrFailure))
2272       return failure();
2273     initializer = *initOrFailure;
2274 
2275     // Check that the constraints are compatible with having an initializer,
2276     // e.g. type constraints cannot be used with initializers.
2277     for (ast::ConstraintRef constraint : constraints) {
2278       LogicalResult result =
2279           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2280               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2281                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2282                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2283                   return this->emitError(
2284                       constraint.referenceLoc,
2285                       "type constraints are not permitted on variables with "
2286                       "initializers");
2287                 }
2288                 return success();
2289               })
2290               .Default(success());
2291       if (failed(result))
2292         return failure();
2293     }
2294   }
2295 
2296   FailureOr<ast::VariableDecl *> varDecl =
2297       createVariableDecl(varName, varLoc, initializer, constraints);
2298   if (failed(varDecl))
2299     return failure();
2300   return ast::LetStmt::create(ctx, loc, *varDecl);
2301 }
2302 
2303 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2304   if (parserContext == ParserContext::Constraint)
2305     return emitError("`replace` cannot be used within a Constraint");
2306   SMRange loc = curToken.getLoc();
2307   consumeToken(Token::kw_replace);
2308 
2309   // Parse the root operation expression.
2310   FailureOr<ast::Expr *> rootOp = parseExpr();
2311   if (failed(rootOp))
2312     return failure();
2313 
2314   if (failed(
2315           parseToken(Token::kw_with, "expected `with` after root operation")))
2316     return failure();
2317 
2318   // The replacement portion of this statement is within a rewrite context.
2319   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2320                                               ParserContext::Rewrite);
2321 
2322   // Parse the replacement values.
2323   SmallVector<ast::Expr *> replValues;
2324   if (consumeIf(Token::l_paren)) {
2325     if (consumeIf(Token::r_paren)) {
2326       return emitError(
2327           loc, "expected at least one replacement value, consider using "
2328                "`erase` if no replacement values are desired");
2329     }
2330 
2331     do {
2332       FailureOr<ast::Expr *> replExpr = parseExpr();
2333       if (failed(replExpr))
2334         return failure();
2335       replValues.emplace_back(*replExpr);
2336     } while (consumeIf(Token::comma));
2337 
2338     if (failed(parseToken(Token::r_paren,
2339                           "expected `)` after replacement values")))
2340       return failure();
2341   } else {
2342     // Handle replacement with an operation uniquely, as the replacement
2343     // operation supports type inferrence from the root operation.
2344     FailureOr<ast::Expr *> replExpr;
2345     if (curToken.is(Token::kw_op))
2346       replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2347     else
2348       replExpr = parseExpr();
2349     if (failed(replExpr))
2350       return failure();
2351     replValues.emplace_back(*replExpr);
2352   }
2353 
2354   return createReplaceStmt(loc, *rootOp, replValues);
2355 }
2356 
2357 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2358   SMRange loc = curToken.getLoc();
2359   consumeToken(Token::kw_return);
2360 
2361   // Parse the result value.
2362   FailureOr<ast::Expr *> resultExpr = parseExpr();
2363   if (failed(resultExpr))
2364     return failure();
2365 
2366   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2367 }
2368 
2369 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2370   if (parserContext == ParserContext::Constraint)
2371     return emitError("`rewrite` cannot be used within a Constraint");
2372   SMRange loc = curToken.getLoc();
2373   consumeToken(Token::kw_rewrite);
2374 
2375   // Parse the root operation.
2376   FailureOr<ast::Expr *> rootOp = parseExpr();
2377   if (failed(rootOp))
2378     return failure();
2379 
2380   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2381     return failure();
2382 
2383   if (curToken.isNot(Token::l_brace))
2384     return emitError("expected `{` to start rewrite body");
2385 
2386   // The rewrite body of this statement is within a rewrite context.
2387   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2388                                               ParserContext::Rewrite);
2389 
2390   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2391   if (failed(rewriteBody))
2392     return failure();
2393 
2394   // Verify the rewrite body.
2395   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2396     if (isa<ast::ReturnStmt>(stmt)) {
2397       return emitError(stmt->getLoc(),
2398                        "`return` statements are only permitted within a "
2399                        "`Constraint` or `Rewrite` body");
2400     }
2401   }
2402 
2403   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2404 }
2405 
2406 //===----------------------------------------------------------------------===//
2407 // Creation+Analysis
2408 //===----------------------------------------------------------------------===//
2409 
2410 //===----------------------------------------------------------------------===//
2411 // Decls
2412 
2413 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2414   // Unwrap reference expressions.
2415   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2416     node = init->getDecl();
2417   return dyn_cast<ast::CallableDecl>(node);
2418 }
2419 
2420 FailureOr<ast::PatternDecl *>
2421 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2422                           const ParsedPatternMetadata &metadata,
2423                           ast::CompoundStmt *body) {
2424   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2425                                   metadata.hasBoundedRecursion, body);
2426 }
2427 
2428 ast::Type Parser::createUserConstraintRewriteResultType(
2429     ArrayRef<ast::VariableDecl *> results) {
2430   // Single result decls use the type of the single result.
2431   if (results.size() == 1)
2432     return results[0]->getType();
2433 
2434   // Multiple results use a tuple type, with the types and names grabbed from
2435   // the result variable decls.
2436   auto resultTypes = llvm::map_range(
2437       results, [&](const auto *result) { return result->getType(); });
2438   auto resultNames = llvm::map_range(
2439       results, [&](const auto *result) { return result->getName().getName(); });
2440   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2441                              llvm::to_vector(resultNames));
2442 }
2443 
2444 template <typename T>
2445 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2446     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2447     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2448     ast::CompoundStmt *body) {
2449   if (!body->getChildren().empty()) {
2450     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2451       ast::Expr *resultExpr = retStmt->getResultExpr();
2452 
2453       // Process the result of the decl. If no explicit signature results
2454       // were provided, check for return type inference. Otherwise, check that
2455       // the return expression can be converted to the expected type.
2456       if (results.empty())
2457         resultType = resultExpr->getType();
2458       else if (failed(convertExpressionTo(resultExpr, resultType)))
2459         return failure();
2460       else
2461         retStmt->setResultExpr(resultExpr);
2462     }
2463   }
2464   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2465 }
2466 
2467 FailureOr<ast::VariableDecl *>
2468 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2469                            ArrayRef<ast::ConstraintRef> constraints) {
2470   // The type of the variable, which is expected to be inferred by either a
2471   // constraint or an initializer expression.
2472   ast::Type type;
2473   if (failed(validateVariableConstraints(constraints, type)))
2474     return failure();
2475 
2476   if (initializer) {
2477     // Update the variable type based on the initializer, or try to convert the
2478     // initializer to the existing type.
2479     if (!type)
2480       type = initializer->getType();
2481     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2482       type = mergedType;
2483     else if (failed(convertExpressionTo(initializer, type)))
2484       return failure();
2485 
2486     // Otherwise, if there is no initializer check that the type has already
2487     // been resolved from the constraint list.
2488   } else if (!type) {
2489     return emitErrorAndNote(
2490         loc, "unable to infer type for variable `" + name + "`", loc,
2491         "the type of a variable must be inferable from the constraint "
2492         "list or the initializer");
2493   }
2494 
2495   // Constraint types cannot be used when defining variables.
2496   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2497     return emitError(
2498         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2499   }
2500 
2501   // Try to define a variable with the given name.
2502   FailureOr<ast::VariableDecl *> varDecl =
2503       defineVariableDecl(name, loc, type, initializer, constraints);
2504   if (failed(varDecl))
2505     return failure();
2506 
2507   return *varDecl;
2508 }
2509 
2510 FailureOr<ast::VariableDecl *>
2511 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2512                                       const ast::ConstraintRef &constraint) {
2513   // Constraint arguments may apply more complex constraints via the arguments.
2514   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2515   ast::Type argType;
2516   if (failed(validateVariableConstraint(constraint, argType,
2517                                         allowNonCoreConstraints)))
2518     return failure();
2519   return defineVariableDecl(name, loc, argType, constraint);
2520 }
2521 
2522 LogicalResult
2523 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2524                                     ast::Type &inferredType,
2525                                     bool allowNonCoreConstraints) {
2526   for (const ast::ConstraintRef &ref : constraints)
2527     if (failed(validateVariableConstraint(ref, inferredType,
2528                                           allowNonCoreConstraints)))
2529       return failure();
2530   return success();
2531 }
2532 
/// Validate the single variable constraint `ref` and fold the type it implies
/// into `inferredType`.
///
/// Type implied by each constraint kind:
///   * `Attr`       -> `Attr` (an optional type expression must be a `Type`)
///   * `Op`         -> the `Op` type for the named operation, carrying ODS
///                     information if it was found
///   * `Type`       -> `Type`
///   * `TypeRange`  -> `TypeRange`
///   * `Value`      -> `Value` (an optional type expression must be a `Type`)
///   * `ValueRange` -> `ValueRange` (an optional type expression must be a
///                     `TypeRange`)
///   * user `Constraint` -> the type of its single input. This is only allowed
///                          when `allowNonCoreConstraints` is true.
///
/// If `inferredType` is null it is set to the implied type. Otherwise the two
/// are merged with `refineWith`, and an error is emitted if they are
/// incompatible.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType,
                                                 bool allowNonCoreConstraints) {
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    constraintType = ast::OperationType::get(
        ctx, cst->getName(), lookupODSOperation(cst->getName()));
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // A ValueRange is constrained by a TypeRange, not a single Type.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    if (!allowNonCoreConstraints) {
      return emitError(ref.referenceLoc,
                       "`Rewrite` arguments and results are only permitted to "
                       "use core constraints, such as `Attr`, `Op`, `Type`, "
                       "`TypeRange`, `Value`, `ValueRange`");
    }

    // A user constraint used in a variable constraint list is applied to that
    // variable alone, so it must take exactly one input.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
2602 
2603 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2604   ast::Type typeExprType = typeExpr->getType();
2605   if (typeExprType != typeTy) {
2606     return emitError(typeExpr->getLoc(),
2607                      "expected expression of `Type` in type constraint");
2608   }
2609   return success();
2610 }
2611 
2612 LogicalResult
2613 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2614   ast::Type typeExprType = typeExpr->getType();
2615   if (typeExprType != typeRangeTy) {
2616     return emitError(typeExpr->getLoc(),
2617                      "expected expression of `TypeRange` in type constraint");
2618   }
2619   return success();
2620 }
2621 
2622 //===----------------------------------------------------------------------===//
2623 // Exprs
2624 
2625 FailureOr<ast::CallExpr *>
2626 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2627                        MutableArrayRef<ast::Expr *> arguments) {
2628   ast::Type parentType = parentExpr->getType();
2629 
2630   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2631   if (!callableDecl) {
2632     return emitError(loc,
2633                      llvm::formatv("expected a reference to a callable "
2634                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2635                                    parentType));
2636   }
2637   if (parserContext == ParserContext::Rewrite) {
2638     if (isa<ast::UserConstraintDecl>(callableDecl))
2639       return emitError(
2640           loc, "unable to invoke `Constraint` within a rewrite section");
2641   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2642     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2643   }
2644 
2645   // Verify the arguments of the call.
2646   /// Handle size mismatch.
2647   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2648   if (callArgs.size() != arguments.size()) {
2649     return emitErrorAndNote(
2650         loc,
2651         llvm::formatv("invalid number of arguments for {0} call; expected "
2652                       "{1}, but got {2}",
2653                       callableDecl->getCallableType(), callArgs.size(),
2654                       arguments.size()),
2655         callableDecl->getLoc(),
2656         llvm::formatv("see the definition of {0} here",
2657                       callableDecl->getName()->getName()));
2658   }
2659 
2660   /// Handle argument type mismatch.
2661   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2662     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2663                                   callableDecl->getName()->getName()),
2664                     callableDecl->getLoc());
2665   };
2666   for (auto it : llvm::zip(callArgs, arguments)) {
2667     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2668                                    attachDiagFn)))
2669       return failure();
2670   }
2671 
2672   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2673                                callableDecl->getResultType());
2674 }
2675 
2676 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2677                                                         ast::Decl *decl) {
2678   // Check the type of decl being referenced.
2679   ast::Type declType;
2680   if (isa<ast::ConstraintDecl>(decl))
2681     declType = ast::ConstraintType::get(ctx);
2682   else if (isa<ast::UserRewriteDecl>(decl))
2683     declType = ast::RewriteType::get(ctx);
2684   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2685     declType = varDecl->getType();
2686   else
2687     return emitError(loc, "invalid reference to `" +
2688                               decl->getName()->getName() + "`");
2689 
2690   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2691 }
2692 
2693 FailureOr<ast::DeclRefExpr *>
2694 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2695                                  ArrayRef<ast::ConstraintRef> constraints) {
2696   FailureOr<ast::VariableDecl *> decl =
2697       defineVariableDecl(name, loc, type, constraints);
2698   if (failed(decl))
2699     return failure();
2700   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2701 }
2702 
2703 FailureOr<ast::MemberAccessExpr *>
2704 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2705                                SMRange loc) {
2706   // Validate the member name for the given parent expression.
2707   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2708   if (failed(memberType))
2709     return failure();
2710 
2711   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2712 }
2713 
2714 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2715                                                   StringRef name, SMRange loc) {
2716   ast::Type parentType = parentExpr->getType();
2717   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2718     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2719       return valueRangeTy;
2720 
2721     // Verify member access based on the operation type.
2722     if (const ods::Operation *odsOp = opType.getODSOperation()) {
2723       auto results = odsOp->getResults();
2724 
2725       // Handle indexed results.
2726       unsigned index = 0;
2727       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2728           index < results.size()) {
2729         return results[index].isVariadic() ? valueRangeTy : valueTy;
2730       }
2731 
2732       // Handle named results.
2733       const auto *it = llvm::find_if(results, [&](const auto &result) {
2734         return result.getName() == name;
2735       });
2736       if (it != results.end())
2737         return it->isVariadic() ? valueRangeTy : valueTy;
2738     } else if (llvm::isDigit(name[0])) {
2739       // Allow unchecked numeric indexing of the results of unregistered
2740       // operations. It returns a single value.
2741       return valueTy;
2742     }
2743   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2744     // Handle indexed results.
2745     unsigned index = 0;
2746     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2747         index < tupleType.size()) {
2748       return tupleType.getElementTypes()[index];
2749     }
2750 
2751     // Handle named results.
2752     auto elementNames = tupleType.getElementNames();
2753     const auto *it = llvm::find(elementNames, name);
2754     if (it != elementNames.end())
2755       return tupleType.getElementTypes()[it - elementNames.begin()];
2756   }
2757   return emitError(
2758       loc,
2759       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2760                     name, parentType));
2761 }
2762 
2763 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2764     SMRange loc, const ast::OpNameDecl *name,
2765     OpResultTypeContext resultTypeContext,
2766     MutableArrayRef<ast::Expr *> operands,
2767     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2768     MutableArrayRef<ast::Expr *> results) {
2769   Optional<StringRef> opNameRef = name->getName();
2770   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2771 
2772   // Verify the inputs operands.
2773   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2774     return failure();
2775 
2776   // Verify the attribute list.
2777   for (ast::NamedAttributeDecl *attr : attributes) {
2778     // Check for an attribute type, or a type awaiting resolution.
2779     ast::Type attrType = attr->getValue()->getType();
2780     if (!attrType.isa<ast::AttributeType>()) {
2781       return emitError(
2782           attr->getValue()->getLoc(),
2783           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2784     }
2785   }
2786 
2787   assert(
2788       (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2789       "unexpected inferrence when results were explicitly specified");
2790 
2791   // If we aren't relying on type inferrence, or explicit results were provided,
2792   // validate them.
2793   if (resultTypeContext == OpResultTypeContext::Explicit) {
2794     if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2795       return failure();
2796 
2797     // Validate the use of interface based type inferrence for this operation.
2798   } else if (resultTypeContext == OpResultTypeContext::Interface) {
2799     assert(opNameRef &&
2800            "expected valid operation name when inferring operation results");
2801     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2802   }
2803 
2804   return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2805                                     attributes);
2806 }
2807 
2808 LogicalResult
2809 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2810                                   const ods::Operation *odsOp,
2811                                   MutableArrayRef<ast::Expr *> operands) {
2812   return validateOperationOperandsOrResults(
2813       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2814       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2815       valueRangeTy);
2816 }
2817 
2818 LogicalResult
2819 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2820                                  const ods::Operation *odsOp,
2821                                  MutableArrayRef<ast::Expr *> results) {
2822   return validateOperationOperandsOrResults(
2823       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2824       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2825 }
2826 
2827 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2828                                                 const ods::Operation *odsOp) {
2829   // If the operation might not have inferrence support, emit a warning to the
2830   // user. We don't emit an error because the interface might be added to the
2831   // operation at runtime. It's rare, but it could still happen. We emit a
2832   // warning here instead.
2833 
2834   // Handle inferrence warnings for unknown operations.
2835   if (!odsOp) {
2836     ctx.getDiagEngine().emitWarning(
2837         loc, llvm::formatv(
2838                  "operation result types are marked to be inferred, but "
2839                  "`{0}` is unknown. Ensure that `{0}` supports zero "
2840                  "results or implements `InferTypeOpInterface`. Include "
2841                  "the ODS definition of this operation to remove this warning.",
2842                  opName));
2843     return;
2844   }
2845 
2846   // Handle inferrence warnings for known operations that expected at least one
2847   // result, but don't have inference support. An elided results list can mean
2848   // "zero-results", and we don't want to warn when that is the expected
2849   // behavior.
2850   bool requiresInferrence =
2851       llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2852         return !result.isVariableLength();
2853       });
2854   if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2855     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2856         loc,
2857         llvm::formatv("operation result types are marked to be inferred, but "
2858                       "`{0}` does not provide an implementation of "
2859                       "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2860                       "`InferTypeOpInterface` at runtime, or add support to "
2861                       "the ODS definition to remove this warning.",
2862                       opName));
2863     diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2864                      odsOp->getLoc());
2865     return;
2866   }
2867 }
2868 
/// Validate the operand or result groups (`values`) of an operation
/// expression, converting elements of `values` in place where needed.
///
/// `groupName` ("operand" or "result" at the visible call sites) is used in
/// diagnostics. `singleTy` and `rangeTy` are the accepted types for a single
/// element and for a range, respectively.
///
/// * A single group is always converted to `rangeTy`.
/// * If ODS information is available (`odsOpLoc` is set), the group count must
///   match `odsValues`. Each group is then converted to `rangeTy` if the
///   matching ODS entry is variadic, and to `singleTy` otherwise.
///   NOTE(review): this path dereferences `name`, so it assumes `name` is set
///   whenever `odsOpLoc` is.
/// * Otherwise each group must already have type `singleTy` or `rangeTy`. When
///   `singleTy` is the `Value` type (operands), `Op` expressions are implicitly
///   converted to their values.
LogicalResult Parser::validateOperationOperandsOrResults(
    StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
    Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
    ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
    ast::Type rangeTy) {
  // All operation types accept a single range parameter.
  if (values.size() == 1) {
    if (failed(convertExpressionTo(values[0], rangeTy)))
      return failure();
    return success();
  }

  // If the operation has ODS information, we can more accurately verify the
  // values.
  if (odsOpLoc) {
    if (odsValues.size() != values.size()) {
      return emitErrorAndNote(
          loc,
          llvm::formatv("invalid number of {0} groups for `{1}`; expected "
                        "{2}, but got {3}",
                        groupName, *name, odsValues.size(), values.size()),
          *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
    }
    auto diagFn = [&](ast::Diagnostic &diag) {
      diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
                      *odsOpLoc);
    };
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      // Variadic ODS entries take a range; all others take a single element.
      ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
      if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
        return failure();
    }
    return success();
  }

  // Otherwise, accept the value groups as they have been defined and just
  // ensure they are one of the expected types.
  for (ast::Expr *&valueExpr : values) {
    ast::Type valueExprType = valueExpr->getType();

    // Check if this is one of the expected types.
    if (valueExprType == rangeTy || valueExprType == singleTy)
      continue;

    // If the operand is an Operation, allow converting to a Value or
    // ValueRange. This situations arises quite often with nested operation
    // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
    if (singleTy == valueTy) {
      if (valueExprType.isa<ast::OperationType>()) {
        // Replaces the element of `values` in place (bound by reference).
        valueExpr = convertOpToValue(valueExpr);
        continue;
      }
    }

    return emitError(
        valueExpr->getLoc(),
        llvm::formatv(
            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
            singleTy, rangeTy, valueExprType));
  }
  return success();
}
2931 
2932 FailureOr<ast::TupleExpr *>
2933 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2934                         ArrayRef<StringRef> elementNames) {
2935   for (const ast::Expr *element : elements) {
2936     ast::Type eleTy = element->getType();
2937     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2938       return emitError(
2939           element->getLoc(),
2940           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2941     }
2942   }
2943   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2944 }
2945 
2946 //===----------------------------------------------------------------------===//
2947 // Stmts
2948 
2949 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2950                                                     ast::Expr *rootOp) {
2951   // Check that root is an Operation.
2952   ast::Type rootType = rootOp->getType();
2953   if (!rootType.isa<ast::OperationType>())
2954     return emitError(rootOp->getLoc(), "expected `Op` expression");
2955 
2956   return ast::EraseStmt::create(ctx, loc, rootOp);
2957 }
2958 
2959 FailureOr<ast::ReplaceStmt *>
2960 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2961                           MutableArrayRef<ast::Expr *> replValues) {
2962   // Check that root is an Operation.
2963   ast::Type rootType = rootOp->getType();
2964   if (!rootType.isa<ast::OperationType>()) {
2965     return emitError(
2966         rootOp->getLoc(),
2967         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2968   }
2969 
2970   // If there are multiple replacement values, we implicitly convert any Op
2971   // expressions to the value form.
2972   bool shouldConvertOpToValues = replValues.size() > 1;
2973   for (ast::Expr *&replExpr : replValues) {
2974     ast::Type replType = replExpr->getType();
2975 
2976     // Check that replExpr is an Operation, Value, or ValueRange.
2977     if (replType.isa<ast::OperationType>()) {
2978       if (shouldConvertOpToValues)
2979         replExpr = convertOpToValue(replExpr);
2980       continue;
2981     }
2982 
2983     if (replType != valueTy && replType != valueRangeTy) {
2984       return emitError(replExpr->getLoc(),
2985                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2986                                      "expression, but got `{0}`",
2987                                      replType));
2988     }
2989   }
2990 
2991   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2992 }
2993 
2994 FailureOr<ast::RewriteStmt *>
2995 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2996                           ast::CompoundStmt *rewriteBody) {
2997   // Check that root is an Operation.
2998   ast::Type rootType = rootOp->getType();
2999   if (!rootType.isa<ast::OperationType>()) {
3000     return emitError(
3001         rootOp->getLoc(),
3002         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3003   }
3004 
3005   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3006 }
3007 
3008 //===----------------------------------------------------------------------===//
3009 // Code Completion
3010 //===----------------------------------------------------------------------===//
3011 
3012 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3013   ast::Type parentType = parentExpr->getType();
3014   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3015     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3016   else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3017     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3018   return failure();
3019 }
3020 
3021 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3022   if (opName)
3023     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3024   return failure();
3025 }
3026 
3027 LogicalResult
3028 Parser::codeCompleteConstraintName(ast::Type inferredType,
3029                                    bool allowNonCoreConstraints,
3030                                    bool allowInlineTypeConstraints) {
3031   codeCompleteContext->codeCompleteConstraintName(
3032       inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3033       curDeclScope);
3034   return failure();
3035 }
3036 
3037 LogicalResult Parser::codeCompleteDialectName() {
3038   codeCompleteContext->codeCompleteDialectName();
3039   return failure();
3040 }
3041 
3042 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3043   codeCompleteContext->codeCompleteOperationName(dialectName);
3044   return failure();
3045 }
3046 
3047 LogicalResult Parser::codeCompletePatternMetadata() {
3048   codeCompleteContext->codeCompletePatternMetadata();
3049   return failure();
3050 }
3051 
3052 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3053   codeCompleteContext->codeCompleteIncludeFilename(curPath);
3054   return failure();
3055 }
3056 
3057 void Parser::codeCompleteCallSignature(ast::Node *parent,
3058                                        unsigned currentNumArgs) {
3059   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3060   if (!callableDecl)
3061     return;
3062 
3063   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3064 }
3065 
3066 void Parser::codeCompleteOperationOperandsSignature(
3067     Optional<StringRef> opName, unsigned currentNumOperands) {
3068   codeCompleteContext->codeCompleteOperationOperandsSignature(
3069       opName, currentNumOperands);
3070 }
3071 
3072 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3073                                                    unsigned currentNumResults) {
3074   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3075                                                              currentNumResults);
3076 }
3077 
3078 //===----------------------------------------------------------------------===//
3079 // Parser
3080 //===----------------------------------------------------------------------===//
3081 
3082 FailureOr<ast::Module *>
3083 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3084                         CodeCompleteContext *codeCompleteContext) {
3085   Parser parser(ctx, sourceMgr, codeCompleteContext);
3086   return parser.parseModule();
3087 }
3088