1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <string>
34 
35 using namespace mlir;
36 using namespace mlir::pdll;
37 
38 //===----------------------------------------------------------------------===//
39 // Parser
40 //===----------------------------------------------------------------------===//
41 
namespace {
/// Parser for PDLL source files. It pulls tokens from a Lexer and builds the
/// AST (rooted at an ast::Module) within the provided ast::Context.
class Parser {
public:
  /// Construct a parser over the buffers of `sourceMgr`. Note that the first
  /// token is lexed eagerly here (see the initialization of `curToken`).
  /// `codeCompleteContext` may be null when code completion is not requested.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
         CodeCompleteContext *codeCompleteContext)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
        curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)),
        attrTy(ast::AttributeType::get(ctx)),
        codeCompleteContext(codeCompleteContext) {}

  /// Try to parse a new module. Returns failure if the module could not be
  /// parsed.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within a Constraint, which disallows all types
    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
    Constraint,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Create a new decl scope nested within the current one, make it the
  /// current scope, and return it.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make an existing decl scope the current scope. The scope's parent is not
  /// modified.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent scope the current one.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module, i.e. directives and top-level decls up to
  /// the end of the current buffer. Parsed decls are appended to `decls`.
  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a ValueRange typed
  /// expression that accesses all of the operation's results.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  /// Lookup ODS information for the given operation, returns nullptr if no
  /// information is found or if `opName` is not set.
  const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
    return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
  }

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token. Only `#include` is currently
  /// supported; any decls produced by the directive are appended to `decls`.
  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse an `#include` directive. `.pdll` files are parsed into the current
  /// module, `.td` files are handed off to parseTdInclude.
  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse the tablegen file `filename`, which was included at `fileLoc`.
  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Process the records of a parsed tablegen include file.
  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Create a user defined native constraint for a constraint imported from
  /// ODS.
  template <typename ConstraintT>
  ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
                                               StringRef codeBlock, SMRange loc,
                                               ast::Type type);
  template <typename ConstraintT>
  ast::Decl *
  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                    SMRange loc, ast::Type type);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  /// Parse a single top-level (i.e. non-directive) decl of a module.
  FailureOr<ast::Decl *> parseTopLevelDecl();
  FailureOr<ast::NamedAttributeDecl *>
  parseNamedAttributeDecl(Optional<StringRef> parentOpName);

  /// Parse an argument variable as part of the signature of a
  /// UserConstraintDecl or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseArgumentDecl();

  /// Parse a result variable as part of the signature of a UserConstraintDecl
  /// or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);

  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
  /// defined in a non-global context.
  FailureOr<ast::UserConstraintDecl *>
  parseUserConstraintDecl(bool isInline = false);

  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Constraint/etc.
  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();

  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
  /// using PDLL constructs.
  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
  /// defined in a non-global context.
  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);

  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Rewrite/etc.
  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();

  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
  /// PDLL constructs.
  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
  /// effectively the same syntax, and only differ on slight semantics (given
  /// the different parsing contexts).
  template <typename T, typename ParseUserPDLLDeclFnT>
  FailureOr<T *> parseUserConstraintOrRewriteDecl(
      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
      StringRef anonymousNamePrefix, bool isInline);

  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
  /// These decls have effectively the same syntax.
  template <typename T>
  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse the functional signature (i.e. the arguments and results) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult parseUserConstraintOrRewriteSignature(
      SmallVectorImpl<ast::VariableDecl *> &arguments,
      SmallVectorImpl<ast::VariableDecl *> &results,
      ast::DeclScope *&argumentScope, ast::Type &resultType);

  /// Validate the return (which if present is specified by bodyIt) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult validateUserConstraintOrRewriteReturn(
      StringRef declType, ast::CompoundStmt *body,
      ArrayRef<ast::Stmt *>::iterator bodyIt,
      ArrayRef<ast::Stmt *>::iterator bodyE,
      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);

  FailureOr<ast::CompoundStmt *>
  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                  bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name, if
  /// one has emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
  /// constraints) may be used with the variable.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints,
                  bool allowInlineTypeConstraints,
                  bool allowNonCoreConstraints);

  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
  /// argument or result variable. The constraints for these variables do not
  /// allow inline type constraints, and only permit a single constraint.
  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *> parseOperationExpr();
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::ReturnStmt *> parseReturnStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to extract a callable from the given AST node. Returns nullptr on
  /// failure.
  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
  /// of results, defined as part of the signature.
  ast::Type
  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);

  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
  template <typename T>
  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
      ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Create a variable for an argument or result defined as part of the
  /// signature of a UserConstraintDecl/UserRewriteDecl.
  FailureOr<ast::VariableDecl *>
  createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                const ast::ConstraintRef &constraint);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType,
                              bool allowNonCoreConstraints = true);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType,
                                           bool allowNonCoreConstraints = true);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::CallExpr *>
  createCallExpr(SMRange loc, ast::Expr *parentExpr,
                 MutableArrayRef<ast::Expr *> arguments);
  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(SMRange loc, Optional<StringRef> name,
                            const ods::Operation *odsOp,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                         const ods::Operation *odsOp,
                                         MutableArrayRef<ast::Expr *> results);
  LogicalResult validateOperationOperandsOrResults(
      StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
      Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
      ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
      ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// The set of various code completion methods. Every completion method
  /// returns `failure` to stop the parsing process after providing completion
  /// results.

  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
  LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
                                           bool allowNonCoreConstraints,
                                           bool allowInlineTypeConstraints);
  LogicalResult codeCompleteDialectName();
  LogicalResult codeCompleteOperationName(StringRef dialectName);
  LogicalResult codeCompletePatternMetadata();
  LogicalResult codeCompleteIncludeFilename(StringRef curPath);

  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
  void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
                                              unsigned currentNumOperands);
  void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
                                             unsigned currentNumResults);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of `tokLoc`, and re-lex the current token
  /// from that position.
  void resetToken(SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error at `loc` through the lexer. Always returns failure so that
  /// callers can write `return emitError(...)`.
  LogicalResult emitError(SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error at the location of the current token. Always returns
  /// failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error at `loc` together with a note at `noteLoc`, through the
  /// lexer. Always returns failure.
  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  // NOTE: Member declaration order matters here. The constructor initializes
  // `curToken` by calling `lexer.lexToken()`, so `lexer` (which itself uses
  // `ctx`) must be declared before `curToken`.

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The current (innermost) decl scope.
  ast::DeclScope *curDeclScope = nullptr;
  /// Allocator backing the decl scopes created by pushDeclScope().
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;
  ast::Type attrTy;

  /// A counter used when naming anonymous constraints and rewrites.
  unsigned anonymousDeclNameCounter = 0;

  /// The optional code completion context.
  CodeCompleteContext *codeCompleteContext;
};
} // namespace
522 
523 FailureOr<ast::Module *> Parser::parseModule() {
524   SMLoc moduleLoc = curToken.getStartLoc();
525   pushDeclScope();
526 
527   // Parse the top-level decls of the module.
528   SmallVector<ast::Decl *> decls;
529   if (failed(parseModuleBody(decls)))
530     return popDeclScope(), failure();
531 
532   popDeclScope();
533   return ast::Module::create(ctx, moduleLoc, decls);
534 }
535 
536 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
537   while (curToken.isNot(Token::eof)) {
538     if (curToken.is(Token::directive)) {
539       if (failed(parseDirective(decls)))
540         return failure();
541       continue;
542     }
543 
544     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
545     if (failed(decl))
546       return failure();
547     decls.push_back(*decl);
548   }
549   return success();
550 }
551 
552 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
553   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
554                                                  valueRangeTy);
555 }
556 
557 LogicalResult Parser::convertExpressionTo(
558     ast::Expr *&expr, ast::Type type,
559     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
560   ast::Type exprType = expr->getType();
561   if (exprType == type)
562     return success();
563 
564   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
565     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
566         expr->getLoc(), llvm::formatv("unable to convert expression of type "
567                                       "`{0}` to the expected type of "
568                                       "`{1}`",
569                                       exprType, type));
570     if (noteAttachFn)
571       noteAttachFn(*diag);
572     return diag;
573   };
574 
575   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
576     // Two operation types are compatible if they have the same name, or if the
577     // expected type is more general.
578     if (auto opType = type.dyn_cast<ast::OperationType>()) {
579       if (opType.getName())
580         return emitConvertError();
581       return success();
582     }
583 
584     // An operation can always convert to a ValueRange.
585     if (type == valueRangeTy) {
586       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
587                                                      valueRangeTy);
588       return success();
589     }
590 
591     // Allow conversion to a single value by constraining the result range.
592     if (type == valueTy) {
593       // If the operation is registered, we can verify if it can ever have a
594       // single result.
595       Optional<StringRef> opName = exprOpType.getName();
596       if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
597         if (odsOp->getResults().empty()) {
598           return emitConvertError()->attachNote(
599               llvm::formatv("see the definition of `{0}`, which was defined "
600                             "with zero results",
601                             odsOp->getName()),
602               odsOp->getLoc());
603         }
604 
605         unsigned numSingleResults = llvm::count_if(
606             odsOp->getResults(), [](const ods::OperandOrResult &result) {
607               return result.getVariableLengthKind() ==
608                      ods::VariableLengthKind::Single;
609             });
610         if (numSingleResults > 1) {
611           return emitConvertError()->attachNote(
612               llvm::formatv("see the definition of `{0}`, which was defined "
613                             "with at least {1} results",
614                             odsOp->getName(), numSingleResults),
615               odsOp->getLoc());
616         }
617       }
618 
619       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
620                                                      valueTy);
621       return success();
622     }
623     return emitConvertError();
624   }
625 
626   // FIXME: Decide how to allow/support converting a single result to multiple,
627   // and multiple to a single result. For now, we just allow Single->Range,
628   // but this isn't something really supported in the PDL dialect. We should
629   // figure out some way to support both.
630   if ((exprType == valueTy || exprType == valueRangeTy) &&
631       (type == valueTy || type == valueRangeTy))
632     return success();
633   if ((exprType == typeTy || exprType == typeRangeTy) &&
634       (type == typeTy || type == typeRangeTy))
635     return success();
636 
637   // Handle tuple types.
638   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
639     auto tupleType = type.dyn_cast<ast::TupleType>();
640     if (!tupleType || tupleType.size() != exprTupleType.size())
641       return emitConvertError();
642 
643     // Build a new tuple expression using each of the elements of the current
644     // tuple.
645     SmallVector<ast::Expr *> newExprs;
646     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
647       newExprs.push_back(ast::MemberAccessExpr::create(
648           ctx, expr->getLoc(), expr, llvm::to_string(i),
649           exprTupleType.getElementTypes()[i]));
650 
651       auto diagFn = [&](ast::Diagnostic &diag) {
652         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
653                                       i, exprTupleType));
654         if (noteAttachFn)
655           noteAttachFn(diag);
656       };
657       if (failed(convertExpressionTo(newExprs.back(),
658                                      tupleType.getElementTypes()[i], diagFn)))
659         return failure();
660     }
661     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
662                                   tupleType.getElementNames());
663     return success();
664   }
665 
666   return emitConvertError();
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // Directives
671 
672 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
673   StringRef directive = curToken.getSpelling();
674   if (directive == "#include")
675     return parseInclude(decls);
676 
677   return emitError("unknown directive `" + directive + "`");
678 }
679 
680 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
681   SMRange loc = curToken.getLoc();
682   consumeToken(Token::directive);
683 
684   // Handle code completion of the include file path.
685   if (curToken.is(Token::code_complete_string))
686     return codeCompleteIncludeFilename(curToken.getStringValue());
687 
688   // Parse the file being included.
689   if (!curToken.isString())
690     return emitError(loc,
691                      "expected string file name after `include` directive");
692   SMRange fileLoc = curToken.getLoc();
693   std::string filenameStr = curToken.getStringValue();
694   StringRef filename = filenameStr;
695   consumeToken();
696 
697   // Check the type of include. If ending with `.pdll`, this is another pdl file
698   // to be parsed along with the current module.
699   if (filename.endswith(".pdll")) {
700     if (failed(lexer.pushInclude(filename, fileLoc)))
701       return emitError(fileLoc,
702                        "unable to open include file `" + filename + "`");
703 
704     // If we added the include successfully, parse it into the current module.
705     // Make sure to update to the next token after we finish parsing the nested
706     // file.
707     curToken = lexer.lexToken();
708     LogicalResult result = parseModuleBody(decls);
709     curToken = lexer.lexToken();
710     return result;
711   }
712 
713   // Otherwise, this must be a `.td` include.
714   if (filename.endswith(".td"))
715     return parseTdInclude(filename, fileLoc, decls);
716 
717   return emitError(fileLoc,
718                    "expected include filename to end with `.pdll` or `.td`");
719 }
720 
/// Import a TableGen (`.td`) file named by an `#include` directive (called from
/// `parseInclude` for names ending in `.td`).
///
/// The file is located through the parser's source manager, which uses its
/// include directories, and parsed with a separate, temporary source manager.
/// The resulting records are converted into ODS info and PDLL constraint decls
/// (appended to `decls`) by `processTdIncludeRecords`. Finally, the TableGen
/// buffers are handed over to the parser's source manager.
///
/// \param filename  the include path as written in the directive.
/// \param fileLoc   location of the file-name token; every diagnostic raised
///                  while processing the .td file is reported at this range.
/// \param decls     output list that receives the generated constraint decls.
/// \returns failure if the file cannot be opened or TableGen parsing fails.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the source manager to open the file, but don't yet add it.
  // (`includedFile` is an out-param of OpenIncludeFile, presumably the
  // resolved path; it is not used further here.)
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup the source manager for parsing the tablegen file. It inherits the
  // parser's include directories so nested `include`s in the .td resolve the
  // same way.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler. It lives on this stack frame, which outlives `tdSrcMgr` (declared
  // above), so the raw pointer handed to the handler never dangles.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager. Diagnostics
  // are re-emitted against the include directive's file location, because the
  // .td buffers are not yet owned by the parser's source manager.
  // NOTE(review): the diagnostic kind is not inspected, so TableGen warnings
  // and notes are also reported as PDLL errors; confirm this is intended.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file. TableGenParseFile returns true on error; the
  // error itself has already been routed through the handler above.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them. The buffers are recorded
  // as included from the end of the file-name token.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}
770 
/// Convert the records of a parsed TableGen file into PDLL/ODS information:
///  * every `Op` def is registered in the ODS context, together with its
///    attributes, operands, and results (first registration wins);
///  * every named `Attr` / `Type` def becomes a native PDLL constraint decl
///    whose body is the def's condition template;
///  * every named `OpInterface` / `AttrInterface` / `TypeInterface` def becomes
///    a native PDLL constraint decl that checks `llvm::isa<Interface>(self)`.
/// Generated decls are appended to `decls`. They are also added to the current
/// scope by `createODSNativePDLLConstraintDecl`. Defs whose name is already
/// visible in the current scope are skipped, so user decls are not redefined.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value. Optional is checked first, so a
  // value reported as both optional and variadic is treated as Optional.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context, keyed by the constraint's
  // unique def name.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Record locations are single SMLoc points (`getLoc().front()` below); widen
  // each into a one-character range for the PDLL decls.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  /// Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) =
        odsContext.insertOperation(op.getOperationName(), op.getSummary(),
                                   op.getDescription(), op.getLoc().front());

    // Ignore operations that have already been added. This keeps their
    // attribute/operand/result lists from being appended a second time.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  /// Attr constraints. Anonymous defs have no user-referable name, so they
  /// get no decl.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              tblgen::AttrConstraint(def),
              convertLocToRange(def->getLoc().front()), attrTy));
    }
  }
  /// Type constraints (same rules as Attr constraints above).
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              tblgen::TypeConstraint(def),
              convertLocToRange(def->getLoc().front()), typeTy));
    }
  }
  /// Interfaces.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    // `DeclareInterfaceMethods` defs are skipped as well; presumably they wrap
    // an interface that is already defined elsewhere, so no new constraint is
    // needed. TODO(review): confirm.
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    // The generated constraint succeeds when `self` is an instance of the
    // interface's C++ class `<cppNamespace>::<cppClassName>`. With an empty
    // namespace this formats as `::<cppClassName>` (global scope).
    StringRef className = def->getValueAsString("cppClassName");
    StringRef cppNamespace = def->getValueAsString("cppNamespace");
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));",
                      cppNamespace, className)
            .str();

    // Choose the `self` constraint and type by interface kind. An interface
    // that is none of these three kinds produces no decl.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy));
    }
  }
}
872 
873 template <typename ConstraintT>
874 ast::Decl *
875 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
876                                           SMRange loc, ast::Type type) {
877   // Build the single input parameter.
878   ast::DeclScope *argScope = pushDeclScope();
879   auto *paramVar = ast::VariableDecl::create(
880       ctx, ast::Name::create(ctx, "self", loc), type,
881       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
882   argScope->add(paramVar);
883   popDeclScope();
884 
885   // Build the native constraint.
886   auto *constraintDecl = ast::UserConstraintDecl::createNative(
887       ctx, ast::Name::create(ctx, name, loc), paramVar,
888       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
889   curDeclScope->add(constraintDecl);
890   return constraintDecl;
891 }
892 
893 template <typename ConstraintT>
894 ast::Decl *
895 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
896                                           SMRange loc, ast::Type type) {
897   // Format the condition template.
898   tblgen::FmtContext fmtContext;
899   fmtContext.withSelf("self");
900   std::string codeBlock = tblgen::tgfmt(
901       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
902       &fmtContext);
903 
904   return createODSNativePDLLConstraintDecl<ConstraintT>(
905       constraint.getUniqueDefName(), codeBlock, loc, type);
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // Decls
910 
911 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
912   FailureOr<ast::Decl *> decl;
913   switch (curToken.getKind()) {
914   case Token::kw_Constraint:
915     decl = parseUserConstraintDecl();
916     break;
917   case Token::kw_Pattern:
918     decl = parsePatternDecl();
919     break;
920   case Token::kw_Rewrite:
921     decl = parseUserRewriteDecl();
922     break;
923   default:
924     return emitError("expected top-level declaration, such as a `Pattern`");
925   }
926   if (failed(decl))
927     return failure();
928 
929   // If the decl has a name, add it to the current scope.
930   if (const ast::Name *name = (*decl)->getName()) {
931     if (failed(checkDefineNamedDecl(*name)))
932       return failure();
933     curDeclScope->add(*decl);
934   }
935   return decl;
936 }
937 
938 FailureOr<ast::NamedAttributeDecl *>
939 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
940   // Check for name code completion.
941   if (curToken.is(Token::code_complete))
942     return codeCompleteAttributeName(parentOpName);
943 
944   std::string attrNameStr;
945   if (curToken.isString())
946     attrNameStr = curToken.getStringValue();
947   else if (curToken.is(Token::identifier) || curToken.isKeyword())
948     attrNameStr = curToken.getSpelling().str();
949   else
950     return emitError("expected identifier or string attribute name");
951   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
952   consumeToken();
953 
954   // Check for a value of the attribute.
955   ast::Expr *attrValue = nullptr;
956   if (consumeIf(Token::equal)) {
957     FailureOr<ast::Expr *> attrExpr = parseExpr();
958     if (failed(attrExpr))
959       return failure();
960     attrValue = *attrExpr;
961   } else {
962     // If there isn't a concrete value, create an expression representing a
963     // UnitAttr.
964     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
965   }
966 
967   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
968 }
969 
970 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
971     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
972     bool expectTerminalSemicolon) {
973   consumeToken(Token::equal_arrow);
974 
975   // Parse the single statement of the lambda body.
976   SMLoc bodyStartLoc = curToken.getStartLoc();
977   pushDeclScope();
978   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
979   bool failedToParse =
980       failed(singleStatement) || failed(processStatementFn(*singleStatement));
981   popDeclScope();
982   if (failedToParse)
983     return failure();
984 
985   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
986   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
987 }
988 
989 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
990   // Ensure that the argument is named.
991   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
992     return emitError("expected identifier argument name");
993 
994   // Parse the argument similarly to a normal variable.
995   StringRef name = curToken.getSpelling();
996   SMRange nameLoc = curToken.getLoc();
997   consumeToken();
998 
999   if (failed(
1000           parseToken(Token::colon, "expected `:` before argument constraint")))
1001     return failure();
1002 
1003   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1004   if (failed(cst))
1005     return failure();
1006 
1007   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1008 }
1009 
1010 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1011   // Check to see if this result is named.
1012   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1013     // Check to see if this name actually refers to a Constraint.
1014     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1015     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1016       // If yes, and this is a Rewrite, give a nice error message as non-Core
1017       // constraints are not supported on Rewrite results.
1018       if (parserContext == ParserContext::Rewrite) {
1019         return emitError(
1020             "`Rewrite` results are only permitted to use core constraints, "
1021             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1022       }
1023 
1024       // Otherwise, parse this as an unnamed result variable.
1025     } else {
1026       // If it wasn't a constraint, parse the result similarly to a variable. If
1027       // there is already an existing decl, we will emit an error when defining
1028       // this variable later.
1029       StringRef name = curToken.getSpelling();
1030       SMRange nameLoc = curToken.getLoc();
1031       consumeToken();
1032 
1033       if (failed(parseToken(Token::colon,
1034                             "expected `:` before result constraint")))
1035         return failure();
1036 
1037       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1038       if (failed(cst))
1039         return failure();
1040 
1041       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1042     }
1043   }
1044 
1045   // If it isn't named, we parse the constraint directly and create an unnamed
1046   // result variable.
1047   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1048   if (failed(cst))
1049     return failure();
1050 
1051   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1052 }
1053 
1054 FailureOr<ast::UserConstraintDecl *>
1055 Parser::parseUserConstraintDecl(bool isInline) {
1056   // Constraints and rewrites have very similar formats, dispatch to a shared
1057   // interface for parsing.
1058   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1059       [&](auto &&...args) {
1060         return this->parseUserPDLLConstraintDecl(args...);
1061       },
1062       ParserContext::Constraint, "constraint", isInline);
1063 }
1064 
1065 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1066   FailureOr<ast::UserConstraintDecl *> decl =
1067       parseUserConstraintDecl(/*isInline=*/true);
1068   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1069     return failure();
1070 
1071   curDeclScope->add(*decl);
1072   return decl;
1073 }
1074 
1075 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1076     const ast::Name &name, bool isInline,
1077     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1078     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1079   // Push the argument scope back onto the list, so that the body can
1080   // reference arguments.
1081   pushDeclScope(argumentScope);
1082 
1083   // Parse the body of the constraint. The body is either defined as a compound
1084   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1085   ast::CompoundStmt *body;
1086   if (curToken.is(Token::equal_arrow)) {
1087     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1088         [&](ast::Stmt *&stmt) -> LogicalResult {
1089           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1090           if (!stmtExpr) {
1091             return emitError(stmt->getLoc(),
1092                              "expected `Constraint` lambda body to contain a "
1093                              "single expression");
1094           }
1095           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1096           return success();
1097         },
1098         /*expectTerminalSemicolon=*/!isInline);
1099     if (failed(bodyResult))
1100       return failure();
1101     body = *bodyResult;
1102   } else {
1103     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1104     if (failed(bodyResult))
1105       return failure();
1106     body = *bodyResult;
1107 
1108     // Verify the structure of the body.
1109     auto bodyIt = body->begin(), bodyE = body->end();
1110     for (; bodyIt != bodyE; ++bodyIt)
1111       if (isa<ast::ReturnStmt>(*bodyIt))
1112         break;
1113     if (failed(validateUserConstraintOrRewriteReturn(
1114             "Constraint", body, bodyIt, bodyE, results, resultType)))
1115       return failure();
1116   }
1117   popDeclScope();
1118 
1119   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1120       name, arguments, results, resultType, body);
1121 }
1122 
1123 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1124   // Constraints and rewrites have very similar formats, dispatch to a shared
1125   // interface for parsing.
1126   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1127       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1128       ParserContext::Rewrite, "rewrite", isInline);
1129 }
1130 
1131 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1132   FailureOr<ast::UserRewriteDecl *> decl =
1133       parseUserRewriteDecl(/*isInline=*/true);
1134   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1135     return failure();
1136 
1137   curDeclScope->add(*decl);
1138   return decl;
1139 }
1140 
1141 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1142     const ast::Name &name, bool isInline,
1143     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1144     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1145   // Push the argument scope back onto the list, so that the body can
1146   // reference arguments.
1147   curDeclScope = argumentScope;
1148   ast::CompoundStmt *body;
1149   if (curToken.is(Token::equal_arrow)) {
1150     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1151         [&](ast::Stmt *&statement) -> LogicalResult {
1152           if (isa<ast::OpRewriteStmt>(statement))
1153             return success();
1154 
1155           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1156           if (!statementExpr) {
1157             return emitError(
1158                 statement->getLoc(),
1159                 "expected `Rewrite` lambda body to contain a single expression "
1160                 "or an operation rewrite statement; such as `erase`, "
1161                 "`replace`, or `rewrite`");
1162           }
1163           statement =
1164               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1165           return success();
1166         },
1167         /*expectTerminalSemicolon=*/!isInline);
1168     if (failed(bodyResult))
1169       return failure();
1170     body = *bodyResult;
1171   } else {
1172     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1173     if (failed(bodyResult))
1174       return failure();
1175     body = *bodyResult;
1176   }
1177   popDeclScope();
1178 
1179   // Verify the structure of the body.
1180   auto bodyIt = body->begin(), bodyE = body->end();
1181   for (; bodyIt != bodyE; ++bodyIt)
1182     if (isa<ast::ReturnStmt>(*bodyIt))
1183       break;
1184   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1185                                                    bodyE, results, resultType)))
1186     return failure();
1187   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1188       name, arguments, results, resultType, body);
1189 }
1190 
1191 template <typename T, typename ParseUserPDLLDeclFnT>
1192 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1193     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1194     StringRef anonymousNamePrefix, bool isInline) {
1195   SMRange loc = curToken.getLoc();
1196   consumeToken();
1197   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1198 
1199   // Parse the name of the decl.
1200   const ast::Name *name = nullptr;
1201   if (curToken.isNot(Token::identifier)) {
1202     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1203     // in C++, so being unnamed is fine.
1204     if (!isInline)
1205       return emitError("expected identifier name");
1206 
1207     // Create a unique anonymous name to use, as the name for this decl is not
1208     // important.
1209     std::string anonName =
1210         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1211                       anonymousDeclNameCounter++)
1212             .str();
1213     name = &ast::Name::create(ctx, anonName, loc);
1214   } else {
1215     // If a name was provided, we can use it directly.
1216     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1217     consumeToken(Token::identifier);
1218   }
1219 
1220   // Parse the functional signature of the decl.
1221   SmallVector<ast::VariableDecl *> arguments, results;
1222   ast::DeclScope *argumentScope;
1223   ast::Type resultType;
1224   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1225                                                    argumentScope, resultType)))
1226     return failure();
1227 
1228   // Check to see which type of constraint this is. If the constraint contains a
1229   // compound body, this is a PDLL decl.
1230   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1231     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1232                            resultType);
1233 
1234   // Otherwise, this is a native decl.
1235   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1236                                                    results, resultType);
1237 }
1238 
1239 template <typename T>
1240 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1241     const ast::Name &name, bool isInline,
1242     ArrayRef<ast::VariableDecl *> arguments,
1243     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1244   // If followed by a string, the native code body has also been specified.
1245   std::string codeStrStorage;
1246   Optional<StringRef> optCodeStr;
1247   if (curToken.isString()) {
1248     codeStrStorage = curToken.getStringValue();
1249     optCodeStr = codeStrStorage;
1250     consumeToken();
1251   } else if (isInline) {
1252     return emitError(name.getLoc(),
1253                      "external declarations must be declared in global scope");
1254   } else if (curToken.is(Token::error)) {
1255     return failure();
1256   }
1257   if (failed(parseToken(Token::semicolon,
1258                         "expected `;` after native declaration")))
1259     return failure();
1260   // TODO: PDL should be able to support constraint results in certain
1261   // situations, we should revise this.
1262   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1263     return emitError(
1264         "native Constraints currently do not support returning results");
1265   }
1266   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1267 }
1268 
1269 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1270     SmallVectorImpl<ast::VariableDecl *> &arguments,
1271     SmallVectorImpl<ast::VariableDecl *> &results,
1272     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1273   // Parse the argument list of the decl.
1274   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1275     return failure();
1276 
1277   argumentScope = pushDeclScope();
1278   if (curToken.isNot(Token::r_paren)) {
1279     do {
1280       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1281       if (failed(argument))
1282         return failure();
1283       arguments.emplace_back(*argument);
1284     } while (consumeIf(Token::comma));
1285   }
1286   popDeclScope();
1287   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1288     return failure();
1289 
1290   // Parse the results of the decl.
1291   pushDeclScope();
1292   if (consumeIf(Token::arrow)) {
1293     auto parseResultFn = [&]() -> LogicalResult {
1294       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1295       if (failed(result))
1296         return failure();
1297       results.emplace_back(*result);
1298       return success();
1299     };
1300 
1301     // Check for a list of results.
1302     if (consumeIf(Token::l_paren)) {
1303       do {
1304         if (failed(parseResultFn()))
1305           return failure();
1306       } while (consumeIf(Token::comma));
1307       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1308         return failure();
1309 
1310       // Otherwise, there is only one result.
1311     } else if (failed(parseResultFn())) {
1312       return failure();
1313     }
1314   }
1315   popDeclScope();
1316 
1317   // Compute the result type of the decl.
1318   resultType = createUserConstraintRewriteResultType(results);
1319 
1320   // Verify that results are only named if there are more than one.
1321   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1322     return emitError(
1323         results.front()->getLoc(),
1324         "cannot create a single-element tuple with an element label");
1325   }
1326   return success();
1327 }
1328 
1329 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1330     StringRef declType, ast::CompoundStmt *body,
1331     ArrayRef<ast::Stmt *>::iterator bodyIt,
1332     ArrayRef<ast::Stmt *>::iterator bodyE,
1333     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1334   // Handle if a `return` was provided.
1335   if (bodyIt != bodyE) {
1336     // Emit an error if we have trailing statements after the return.
1337     if (std::next(bodyIt) != bodyE) {
1338       return emitError(
1339           (*std::next(bodyIt))->getLoc(),
1340           llvm::formatv("`return` terminated the `{0}` body, but found "
1341                         "trailing statements afterwards",
1342                         declType));
1343     }
1344 
1345     // Otherwise if a return wasn't provided, check that no results are
1346     // expected.
1347   } else if (!results.empty()) {
1348     return emitError(
1349         {body->getLoc().End, body->getLoc().End},
1350         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1351                       declType, resultType));
1352   }
1353   return success();
1354 }
1355 
1356 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1357   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1358     if (isa<ast::OpRewriteStmt>(statement))
1359       return success();
1360     return emitError(
1361         statement->getLoc(),
1362         "expected Pattern lambda body to contain a single operation "
1363         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1364   });
1365 }
1366 
1367 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1368   SMRange loc = curToken.getLoc();
1369   consumeToken(Token::kw_Pattern);
1370   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1371                                               ParserContext::PatternMatch);
1372 
1373   // Check for an optional identifier for the pattern name.
1374   const ast::Name *name = nullptr;
1375   if (curToken.is(Token::identifier)) {
1376     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1377     consumeToken(Token::identifier);
1378   }
1379 
1380   // Parse any pattern metadata.
1381   ParsedPatternMetadata metadata;
1382   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1383     return failure();
1384 
1385   // Parse the pattern body.
1386   ast::CompoundStmt *body;
1387 
1388   // Handle a lambda body.
1389   if (curToken.is(Token::equal_arrow)) {
1390     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1391     if (failed(bodyResult))
1392       return failure();
1393     body = *bodyResult;
1394   } else {
1395     if (curToken.isNot(Token::l_brace))
1396       return emitError("expected `{` or `=>` to start pattern body");
1397     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1398     if (failed(bodyResult))
1399       return failure();
1400     body = *bodyResult;
1401 
1402     // Verify the body of the pattern.
1403     auto bodyIt = body->begin(), bodyE = body->end();
1404     for (; bodyIt != bodyE; ++bodyIt) {
1405       if (isa<ast::ReturnStmt>(*bodyIt)) {
1406         return emitError((*bodyIt)->getLoc(),
1407                          "`return` statements are only permitted within a "
1408                          "`Constraint` or `Rewrite` body");
1409       }
1410       // Break when we've found the rewrite statement.
1411       if (isa<ast::OpRewriteStmt>(*bodyIt))
1412         break;
1413     }
1414     if (bodyIt == bodyE) {
1415       return emitError(loc,
1416                        "expected Pattern body to terminate with an operation "
1417                        "rewrite statement, such as `erase`");
1418     }
1419     if (std::next(bodyIt) != bodyE) {
1420       return emitError((*std::next(bodyIt))->getLoc(),
1421                        "Pattern body was terminated by an operation "
1422                        "rewrite statement, but found trailing statements");
1423     }
1424   }
1425 
1426   return createPatternDecl(loc, name, metadata, body);
1427 }
1428 
1429 LogicalResult
1430 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1431   Optional<SMRange> benefitLoc;
1432   Optional<SMRange> hasBoundedRecursionLoc;
1433 
1434   do {
1435     // Handle metadata code completion.
1436     if (curToken.is(Token::code_complete))
1437       return codeCompletePatternMetadata();
1438 
1439     if (curToken.isNot(Token::identifier))
1440       return emitError("expected pattern metadata identifier");
1441     StringRef metadataStr = curToken.getSpelling();
1442     SMRange metadataLoc = curToken.getLoc();
1443     consumeToken(Token::identifier);
1444 
1445     // Parse the benefit metadata: benefit(<integer-value>)
1446     if (metadataStr == "benefit") {
1447       if (benefitLoc) {
1448         return emitErrorAndNote(metadataLoc,
1449                                 "pattern benefit has already been specified",
1450                                 *benefitLoc, "see previous definition here");
1451       }
1452       if (failed(parseToken(Token::l_paren,
1453                             "expected `(` before pattern benefit")))
1454         return failure();
1455 
1456       uint16_t benefitValue = 0;
1457       if (curToken.isNot(Token::integer))
1458         return emitError("expected integral pattern benefit");
1459       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1460         return emitError(
1461             "expected pattern benefit to fit within a 16-bit integer");
1462       consumeToken(Token::integer);
1463 
1464       metadata.benefit = benefitValue;
1465       benefitLoc = metadataLoc;
1466 
1467       if (failed(
1468               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1469         return failure();
1470       continue;
1471     }
1472 
1473     // Parse the bounded recursion metadata: recursion
1474     if (metadataStr == "recursion") {
1475       if (hasBoundedRecursionLoc) {
1476         return emitErrorAndNote(
1477             metadataLoc,
1478             "pattern recursion metadata has already been specified",
1479             *hasBoundedRecursionLoc, "see previous definition here");
1480       }
1481       metadata.hasBoundedRecursion = true;
1482       hasBoundedRecursionLoc = metadataLoc;
1483       continue;
1484     }
1485 
1486     return emitError(metadataLoc, "unknown pattern metadata");
1487   } while (consumeIf(Token::comma));
1488 
1489   return success();
1490 }
1491 
1492 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1493   consumeToken(Token::less);
1494 
1495   FailureOr<ast::Expr *> typeExpr = parseExpr();
1496   if (failed(typeExpr) ||
1497       failed(parseToken(Token::greater,
1498                         "expected `>` after variable type constraint")))
1499     return failure();
1500   return typeExpr;
1501 }
1502 
1503 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1504   assert(curDeclScope && "defining decl outside of a decl scope");
1505   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1506     return emitErrorAndNote(
1507         name.getLoc(), "`" + name.getName() + "` has already been defined",
1508         lastDecl->getName()->getLoc(), "see previous definition here");
1509   }
1510   return success();
1511 }
1512 
1513 FailureOr<ast::VariableDecl *>
1514 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1515                            ast::Expr *initExpr,
1516                            ArrayRef<ast::ConstraintRef> constraints) {
1517   assert(curDeclScope && "defining variable outside of decl scope");
1518   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1519 
1520   // If the name of the variable indicates a special variable, we don't add it
1521   // to the scope. This variable is local to the definition point.
1522   if (name.empty() || name == "_") {
1523     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1524                                      constraints);
1525   }
1526   if (failed(checkDefineNamedDecl(nameDecl)))
1527     return failure();
1528 
1529   auto *varDecl =
1530       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1531   curDeclScope->add(varDecl);
1532   return varDecl;
1533 }
1534 
1535 FailureOr<ast::VariableDecl *>
1536 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1537                            ArrayRef<ast::ConstraintRef> constraints) {
1538   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1539                             constraints);
1540 }
1541 
1542 LogicalResult Parser::parseVariableDeclConstraintList(
1543     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1544   Optional<SMRange> typeConstraint;
1545   auto parseSingleConstraint = [&] {
1546     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1547         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1548         /*allowNonCoreConstraints=*/true);
1549     if (failed(constraint))
1550       return failure();
1551     constraints.push_back(*constraint);
1552     return success();
1553   };
1554 
1555   // Check to see if this is a single constraint, or a list.
1556   if (!consumeIf(Token::l_square))
1557     return parseSingleConstraint();
1558 
1559   do {
1560     if (failed(parseSingleConstraint()))
1561       return failure();
1562   } while (consumeIf(Token::comma));
1563   return parseToken(Token::r_square, "expected `]` after constraint list");
1564 }
1565 
/// Parses a single constraint reference. This is one of the following:
/// - a builtin constraint (`Attr`, `Op`, `Type`, `TypeRange`, `Value`,
///   `ValueRange`);
/// - an inline `Constraint` declaration;
/// - an identifier naming a previously defined constraint.
///
/// `typeConstraint` records the location of an inline `<...>` type constraint
/// that was already applied to the variable, so that a second one is
/// diagnosed. It is updated when a new type constraint is parsed.
/// `allowInlineTypeConstraints` controls whether `Attr<..>`, `Value<..>` and
/// `ValueRange<..>` are accepted. Within this function, `existingConstraints`
/// and `allowNonCoreConstraints` are only used by the code completion path.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parses a `<type-expr>` suffix into `typeExpr`, enforcing that inline type
  // constraints are allowed here and that the variable was not already
  // type-constrained. On success, `typeConstraint` is updated.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a declaration that is not a constraint.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  case Token::code_complete: {
    // Handle single entity constraint code completion.
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
1688 
1689 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1690   // Constraint arguments may apply more complex constraints via the arguments.
1691   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1692 
1693   Optional<SMRange> typeConstraint;
1694   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1695                          /*allowInlineTypeConstraints=*/false,
1696                          allowNonCoreConstraints);
1697 }
1698 
1699 //===----------------------------------------------------------------------===//
1700 // Exprs
1701 
1702 FailureOr<ast::Expr *> Parser::parseExpr() {
1703   if (curToken.is(Token::underscore))
1704     return parseUnderscoreExpr();
1705 
1706   // Parse the LHS expression.
1707   FailureOr<ast::Expr *> lhsExpr;
1708   switch (curToken.getKind()) {
1709   case Token::kw_attr:
1710     lhsExpr = parseAttributeExpr();
1711     break;
1712   case Token::kw_Constraint:
1713     lhsExpr = parseInlineConstraintLambdaExpr();
1714     break;
1715   case Token::identifier:
1716     lhsExpr = parseIdentifierExpr();
1717     break;
1718   case Token::kw_op:
1719     lhsExpr = parseOperationExpr();
1720     break;
1721   case Token::kw_Rewrite:
1722     lhsExpr = parseInlineRewriteLambdaExpr();
1723     break;
1724   case Token::kw_type:
1725     lhsExpr = parseTypeExpr();
1726     break;
1727   case Token::l_paren:
1728     lhsExpr = parseTupleExpr();
1729     break;
1730   default:
1731     return emitError("expected expression");
1732   }
1733   if (failed(lhsExpr))
1734     return failure();
1735 
1736   // Check for an operator expression.
1737   while (true) {
1738     switch (curToken.getKind()) {
1739     case Token::dot:
1740       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1741       break;
1742     case Token::l_paren:
1743       lhsExpr = parseCallExpr(*lhsExpr);
1744       break;
1745     default:
1746       return lhsExpr;
1747     }
1748     if (failed(lhsExpr))
1749       return failure();
1750   }
1751 }
1752 
1753 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1754   SMRange loc = curToken.getLoc();
1755   consumeToken(Token::kw_attr);
1756 
1757   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1758   // identifier.
1759   if (!consumeIf(Token::less)) {
1760     resetToken(loc);
1761     return parseIdentifierExpr();
1762   }
1763 
1764   if (!curToken.isString())
1765     return emitError("expected string literal containing MLIR attribute");
1766   std::string attrExpr = curToken.getStringValue();
1767   consumeToken();
1768 
1769   loc.End = curToken.getEndLoc();
1770   if (failed(
1771           parseToken(Token::greater, "expected `>` after attribute literal")))
1772     return failure();
1773   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1774 }
1775 
1776 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1777   consumeToken(Token::l_paren);
1778 
1779   // Parse the arguments of the call.
1780   SmallVector<ast::Expr *> arguments;
1781   if (curToken.isNot(Token::r_paren)) {
1782     do {
1783       // Handle code completion for the call arguments.
1784       if (curToken.is(Token::code_complete)) {
1785         codeCompleteCallSignature(parentExpr, arguments.size());
1786         return failure();
1787       }
1788 
1789       FailureOr<ast::Expr *> argument = parseExpr();
1790       if (failed(argument))
1791         return failure();
1792       arguments.push_back(*argument);
1793     } while (consumeIf(Token::comma));
1794   }
1795 
1796   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1797   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1798     return failure();
1799 
1800   return createCallExpr(loc, parentExpr, arguments);
1801 }
1802 
1803 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1804   ast::Decl *decl = curDeclScope->lookup(name);
1805   if (!decl)
1806     return emitError(loc, "undefined reference to `" + name + "`");
1807 
1808   return createDeclRefExpr(loc, decl);
1809 }
1810 
1811 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1812   StringRef name = curToken.getSpelling();
1813   SMRange nameLoc = curToken.getLoc();
1814   consumeToken();
1815 
1816   // Check to see if this is a decl ref expression that defines a variable
1817   // inline.
1818   if (consumeIf(Token::colon)) {
1819     SmallVector<ast::ConstraintRef> constraints;
1820     if (failed(parseVariableDeclConstraintList(constraints)))
1821       return failure();
1822     ast::Type type;
1823     if (failed(validateVariableConstraints(constraints, type)))
1824       return failure();
1825     return createInlineVariableExpr(type, name, nameLoc, constraints);
1826   }
1827 
1828   return parseDeclRefExpr(name, nameLoc);
1829 }
1830 
1831 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1832   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1833   if (failed(decl))
1834     return failure();
1835 
1836   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1837                                   ast::ConstraintType::get(ctx));
1838 }
1839 
1840 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1841   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1842   if (failed(decl))
1843     return failure();
1844 
1845   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1846                                   ast::RewriteType::get(ctx));
1847 }
1848 
1849 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1850   SMRange dotLoc = curToken.getLoc();
1851   consumeToken(Token::dot);
1852 
1853   // Check for code completion of the member name.
1854   if (curToken.is(Token::code_complete))
1855     return codeCompleteMemberAccess(parentExpr);
1856 
1857   // Parse the member name.
1858   Token memberNameTok = curToken;
1859   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1860       !memberNameTok.isKeyword())
1861     return emitError(dotLoc, "expected identifier or numeric member name");
1862   StringRef memberName = memberNameTok.getSpelling();
1863   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1864   consumeToken();
1865 
1866   return createMemberAccessExpr(parentExpr, memberName, loc);
1867 }
1868 
/// Parses an operation name of the form `dialect.name`, where the part after
/// the first `.` may itself contain further identifiers, keywords, and `.`s.
/// If the current token cannot start a dialect namespace and `allowEmptyName`
/// is set, an unnamed OpNameDecl is returned without consuming anything.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of a missing operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(Token::code_complete))
    return codeCompleteOperationName(name);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Build the full name as a single StringRef over the source buffer: first
  // extend the dialect namespace over the `.` consumed above, then over each
  // following token's spelling.
  // NOTE(review): this extends by spelling length only, so it assumes the
  // name's tokens are directly adjacent in the source (no whitespace or
  // comments between them) — confirm the lexer/grammar guarantees this.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
1905 
1906 FailureOr<ast::OpNameDecl *>
1907 Parser::parseWrappedOperationName(bool allowEmptyName) {
1908   if (!consumeIf(Token::less))
1909     return ast::OpNameDecl::create(ctx, SMRange());
1910 
1911   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1912   if (failed(opNameDecl))
1913     return failure();
1914 
1915   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1916     return failure();
1917   return opNameDecl;
1918 }
1919 
1920 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1921   SMRange loc = curToken.getLoc();
1922   consumeToken(Token::kw_op);
1923 
1924   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1925   // identifier.
1926   if (curToken.isNot(Token::less)) {
1927     resetToken(loc);
1928     return parseIdentifierExpr();
1929   }
1930 
1931   // Parse the operation name. The name may be elided, in which case the
1932   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1933   // `Operation*`). Operation names within a rewrite context must be named.
1934   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1935   FailureOr<ast::OpNameDecl *> opNameDecl =
1936       parseWrappedOperationName(allowEmptyName);
1937   if (failed(opNameDecl))
1938     return failure();
1939   Optional<StringRef> opName = (*opNameDecl)->getName();
1940 
1941   // Functor used to create an implicit range variable, used for implicit "all"
1942   // operand or results variables.
1943   auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
1944     FailureOr<ast::VariableDecl *> rangeVar =
1945         defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
1946     assert(succeeded(rangeVar) && "expected range variable to be valid");
1947     return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
1948   };
1949 
1950   // Check for the optional list of operands.
1951   SmallVector<ast::Expr *> operands;
1952   if (!consumeIf(Token::l_paren)) {
1953     // If the operand list isn't specified and we are in a match context, define
1954     // an inplace unconstrained operand range corresponding to all of the
1955     // operands of the operation. This avoids treating zero operands the same
1956     // way as "unconstrained operands".
1957     if (parserContext != ParserContext::Rewrite) {
1958       operands.push_back(createImplicitRangeVar(
1959           ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
1960     }
1961   } else if (!consumeIf(Token::r_paren)) {
1962     // Check for operand signature code completion.
1963     if (curToken.is(Token::code_complete)) {
1964       codeCompleteOperationOperandsSignature(opName, operands.size());
1965       return failure();
1966     }
1967 
1968     // If the operand list was specified and non-empty, parse the operands.
1969     do {
1970       FailureOr<ast::Expr *> operand = parseExpr();
1971       if (failed(operand))
1972         return failure();
1973       operands.push_back(*operand);
1974     } while (consumeIf(Token::comma));
1975 
1976     if (failed(parseToken(Token::r_paren,
1977                           "expected `)` after operation operand list")))
1978       return failure();
1979   }
1980 
1981   // Check for the optional list of attributes.
1982   SmallVector<ast::NamedAttributeDecl *> attributes;
1983   if (consumeIf(Token::l_brace)) {
1984     do {
1985       FailureOr<ast::NamedAttributeDecl *> decl =
1986           parseNamedAttributeDecl(opName);
1987       if (failed(decl))
1988         return failure();
1989       attributes.emplace_back(*decl);
1990     } while (consumeIf(Token::comma));
1991 
1992     if (failed(parseToken(Token::r_brace,
1993                           "expected `}` after operation attribute list")))
1994       return failure();
1995   }
1996 
1997   // Check for the optional list of result types.
1998   SmallVector<ast::Expr *> resultTypes;
1999   if (consumeIf(Token::arrow)) {
2000     if (failed(parseToken(Token::l_paren,
2001                           "expected `(` before operation result type list")))
2002       return failure();
2003 
2004     // Handle the case of an empty result list.
2005     if (!consumeIf(Token::r_paren)) {
2006       do {
2007         // Check for result signature code completion.
2008         if (curToken.is(Token::code_complete)) {
2009           codeCompleteOperationResultsSignature(opName, resultTypes.size());
2010           return failure();
2011         }
2012 
2013         FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2014         if (failed(resultTypeExpr))
2015           return failure();
2016         resultTypes.push_back(*resultTypeExpr);
2017       } while (consumeIf(Token::comma));
2018 
2019       if (failed(parseToken(Token::r_paren,
2020                             "expected `)` after operation result type list")))
2021         return failure();
2022     }
2023   } else if (parserContext != ParserContext::Rewrite) {
2024     // If the result list isn't specified and we are in a match context, define
2025     // an inplace unconstrained result range corresponding to all of the results
2026     // of the operation. This avoids treating zero results the same way as
2027     // "unconstrained results".
2028     resultTypes.push_back(createImplicitRangeVar(
2029         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2030   }
2031 
2032   return createOperationExpr(loc, *opNameDecl, operands, attributes,
2033                              resultTypes);
2034 }
2035 
2036 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2037   SMRange loc = curToken.getLoc();
2038   consumeToken(Token::l_paren);
2039 
2040   DenseMap<StringRef, SMRange> usedNames;
2041   SmallVector<StringRef> elementNames;
2042   SmallVector<ast::Expr *> elements;
2043   if (curToken.isNot(Token::r_paren)) {
2044     do {
2045       // Check for the optional element name assignment before the value.
2046       StringRef elementName;
2047       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2048         Token elementNameTok = curToken;
2049         consumeToken();
2050 
2051         // The element name is only present if followed by an `=`.
2052         if (consumeIf(Token::equal)) {
2053           elementName = elementNameTok.getSpelling();
2054 
2055           // Check to see if this name is already used.
2056           auto elementNameIt =
2057               usedNames.try_emplace(elementName, elementNameTok.getLoc());
2058           if (!elementNameIt.second) {
2059             return emitErrorAndNote(
2060                 elementNameTok.getLoc(),
2061                 llvm::formatv("duplicate tuple element label `{0}`",
2062                               elementName),
2063                 elementNameIt.first->getSecond(),
2064                 "see previous label use here");
2065           }
2066         } else {
2067           // Otherwise, we treat this as part of an expression so reset the
2068           // lexer.
2069           resetToken(elementNameTok.getLoc());
2070         }
2071       }
2072       elementNames.push_back(elementName);
2073 
2074       // Parse the tuple element value.
2075       FailureOr<ast::Expr *> element = parseExpr();
2076       if (failed(element))
2077         return failure();
2078       elements.push_back(*element);
2079     } while (consumeIf(Token::comma));
2080   }
2081   loc.End = curToken.getEndLoc();
2082   if (failed(
2083           parseToken(Token::r_paren, "expected `)` after tuple element list")))
2084     return failure();
2085   return createTupleExpr(loc, elements, elementNames);
2086 }
2087 
2088 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2089   SMRange loc = curToken.getLoc();
2090   consumeToken(Token::kw_type);
2091 
2092   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2093   // identifier.
2094   if (!consumeIf(Token::less)) {
2095     resetToken(loc);
2096     return parseIdentifierExpr();
2097   }
2098 
2099   if (!curToken.isString())
2100     return emitError("expected string literal containing MLIR type");
2101   std::string attrExpr = curToken.getStringValue();
2102   consumeToken();
2103 
2104   loc.End = curToken.getEndLoc();
2105   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2106     return failure();
2107   return ast::TypeExpr::create(ctx, loc, attrExpr);
2108 }
2109 
2110 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2111   StringRef name = curToken.getSpelling();
2112   SMRange nameLoc = curToken.getLoc();
2113   consumeToken(Token::underscore);
2114 
2115   // Underscore expressions require a constraint list.
2116   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2117     return failure();
2118 
2119   // Parse the constraints for the expression.
2120   SmallVector<ast::ConstraintRef> constraints;
2121   if (failed(parseVariableDeclConstraintList(constraints)))
2122     return failure();
2123 
2124   ast::Type type;
2125   if (failed(validateVariableConstraints(constraints, type)))
2126     return failure();
2127   return createInlineVariableExpr(type, name, nameLoc, constraints);
2128 }
2129 
2130 //===----------------------------------------------------------------------===//
2131 // Stmts
2132 
2133 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2134   FailureOr<ast::Stmt *> stmt;
2135   switch (curToken.getKind()) {
2136   case Token::kw_erase:
2137     stmt = parseEraseStmt();
2138     break;
2139   case Token::kw_let:
2140     stmt = parseLetStmt();
2141     break;
2142   case Token::kw_replace:
2143     stmt = parseReplaceStmt();
2144     break;
2145   case Token::kw_return:
2146     stmt = parseReturnStmt();
2147     break;
2148   case Token::kw_rewrite:
2149     stmt = parseRewriteStmt();
2150     break;
2151   default:
2152     stmt = parseExpr();
2153     break;
2154   }
2155   if (failed(stmt) ||
2156       (expectTerminalSemicolon &&
2157        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2158     return failure();
2159   return stmt;
2160 }
2161 
2162 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2163   SMLoc startLoc = curToken.getStartLoc();
2164   consumeToken(Token::l_brace);
2165 
2166   // Push a new block scope and parse any nested statements.
2167   pushDeclScope();
2168   SmallVector<ast::Stmt *> statements;
2169   while (curToken.isNot(Token::r_brace)) {
2170     FailureOr<ast::Stmt *> statement = parseStmt();
2171     if (failed(statement))
2172       return popDeclScope(), failure();
2173     statements.push_back(*statement);
2174   }
2175   popDeclScope();
2176 
2177   // Consume the end brace.
2178   SMRange location(startLoc, curToken.getEndLoc());
2179   consumeToken(Token::r_brace);
2180 
2181   return ast::CompoundStmt::create(ctx, location, statements);
2182 }
2183 
2184 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2185   if (parserContext == ParserContext::Constraint)
2186     return emitError("`erase` cannot be used within a Constraint");
2187   SMRange loc = curToken.getLoc();
2188   consumeToken(Token::kw_erase);
2189 
2190   // Parse the root operation expression.
2191   FailureOr<ast::Expr *> rootOp = parseExpr();
2192   if (failed(rootOp))
2193     return failure();
2194 
2195   return createEraseStmt(loc, *rootOp);
2196 }
2197 
2198 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2199   SMRange loc = curToken.getLoc();
2200   consumeToken(Token::kw_let);
2201 
2202   // Parse the name of the new variable.
2203   SMRange varLoc = curToken.getLoc();
2204   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2205     // `_` is a reserved variable name.
2206     if (curToken.is(Token::underscore)) {
2207       return emitError(varLoc,
2208                        "`_` may only be used to define \"inline\" variables");
2209     }
2210     return emitError(varLoc,
2211                      "expected identifier after `let` to name a new variable");
2212   }
2213   StringRef varName = curToken.getSpelling();
2214   consumeToken();
2215 
2216   // Parse the optional set of constraints.
2217   SmallVector<ast::ConstraintRef> constraints;
2218   if (consumeIf(Token::colon) &&
2219       failed(parseVariableDeclConstraintList(constraints)))
2220     return failure();
2221 
2222   // Parse the optional initializer expression.
2223   ast::Expr *initializer = nullptr;
2224   if (consumeIf(Token::equal)) {
2225     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2226     if (failed(initOrFailure))
2227       return failure();
2228     initializer = *initOrFailure;
2229 
2230     // Check that the constraints are compatible with having an initializer,
2231     // e.g. type constraints cannot be used with initializers.
2232     for (ast::ConstraintRef constraint : constraints) {
2233       LogicalResult result =
2234           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2235               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2236                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2237                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2238                   return this->emitError(
2239                       constraint.referenceLoc,
2240                       "type constraints are not permitted on variables with "
2241                       "initializers");
2242                 }
2243                 return success();
2244               })
2245               .Default(success());
2246       if (failed(result))
2247         return failure();
2248     }
2249   }
2250 
2251   FailureOr<ast::VariableDecl *> varDecl =
2252       createVariableDecl(varName, varLoc, initializer, constraints);
2253   if (failed(varDecl))
2254     return failure();
2255   return ast::LetStmt::create(ctx, loc, *varDecl);
2256 }
2257 
2258 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2259   if (parserContext == ParserContext::Constraint)
2260     return emitError("`replace` cannot be used within a Constraint");
2261   SMRange loc = curToken.getLoc();
2262   consumeToken(Token::kw_replace);
2263 
2264   // Parse the root operation expression.
2265   FailureOr<ast::Expr *> rootOp = parseExpr();
2266   if (failed(rootOp))
2267     return failure();
2268 
2269   if (failed(
2270           parseToken(Token::kw_with, "expected `with` after root operation")))
2271     return failure();
2272 
2273   // The replacement portion of this statement is within a rewrite context.
2274   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2275                                               ParserContext::Rewrite);
2276 
2277   // Parse the replacement values.
2278   SmallVector<ast::Expr *> replValues;
2279   if (consumeIf(Token::l_paren)) {
2280     if (consumeIf(Token::r_paren)) {
2281       return emitError(
2282           loc, "expected at least one replacement value, consider using "
2283                "`erase` if no replacement values are desired");
2284     }
2285 
2286     do {
2287       FailureOr<ast::Expr *> replExpr = parseExpr();
2288       if (failed(replExpr))
2289         return failure();
2290       replValues.emplace_back(*replExpr);
2291     } while (consumeIf(Token::comma));
2292 
2293     if (failed(parseToken(Token::r_paren,
2294                           "expected `)` after replacement values")))
2295       return failure();
2296   } else {
2297     FailureOr<ast::Expr *> replExpr = parseExpr();
2298     if (failed(replExpr))
2299       return failure();
2300     replValues.emplace_back(*replExpr);
2301   }
2302 
2303   return createReplaceStmt(loc, *rootOp, replValues);
2304 }
2305 
2306 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2307   SMRange loc = curToken.getLoc();
2308   consumeToken(Token::kw_return);
2309 
2310   // Parse the result value.
2311   FailureOr<ast::Expr *> resultExpr = parseExpr();
2312   if (failed(resultExpr))
2313     return failure();
2314 
2315   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2316 }
2317 
2318 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2319   if (parserContext == ParserContext::Constraint)
2320     return emitError("`rewrite` cannot be used within a Constraint");
2321   SMRange loc = curToken.getLoc();
2322   consumeToken(Token::kw_rewrite);
2323 
2324   // Parse the root operation.
2325   FailureOr<ast::Expr *> rootOp = parseExpr();
2326   if (failed(rootOp))
2327     return failure();
2328 
2329   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2330     return failure();
2331 
2332   if (curToken.isNot(Token::l_brace))
2333     return emitError("expected `{` to start rewrite body");
2334 
2335   // The rewrite body of this statement is within a rewrite context.
2336   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2337                                               ParserContext::Rewrite);
2338 
2339   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2340   if (failed(rewriteBody))
2341     return failure();
2342 
2343   // Verify the rewrite body.
2344   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2345     if (isa<ast::ReturnStmt>(stmt)) {
2346       return emitError(stmt->getLoc(),
2347                        "`return` statements are only permitted within a "
2348                        "`Constraint` or `Rewrite` body");
2349     }
2350   }
2351 
2352   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2353 }
2354 
2355 //===----------------------------------------------------------------------===//
2356 // Creation+Analysis
2357 //===----------------------------------------------------------------------===//
2358 
2359 //===----------------------------------------------------------------------===//
2360 // Decls
2361 
2362 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2363   // Unwrap reference expressions.
2364   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2365     node = init->getDecl();
2366   return dyn_cast<ast::CallableDecl>(node);
2367 }
2368 
2369 FailureOr<ast::PatternDecl *>
2370 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2371                           const ParsedPatternMetadata &metadata,
2372                           ast::CompoundStmt *body) {
2373   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2374                                   metadata.hasBoundedRecursion, body);
2375 }
2376 
2377 ast::Type Parser::createUserConstraintRewriteResultType(
2378     ArrayRef<ast::VariableDecl *> results) {
2379   // Single result decls use the type of the single result.
2380   if (results.size() == 1)
2381     return results[0]->getType();
2382 
2383   // Multiple results use a tuple type, with the types and names grabbed from
2384   // the result variable decls.
2385   auto resultTypes = llvm::map_range(
2386       results, [&](const auto *result) { return result->getType(); });
2387   auto resultNames = llvm::map_range(
2388       results, [&](const auto *result) { return result->getName().getName(); });
2389   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2390                              llvm::to_vector(resultNames));
2391 }
2392 
2393 template <typename T>
2394 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2395     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2396     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2397     ast::CompoundStmt *body) {
2398   if (!body->getChildren().empty()) {
2399     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2400       ast::Expr *resultExpr = retStmt->getResultExpr();
2401 
2402       // Process the result of the decl. If no explicit signature results
2403       // were provided, check for return type inference. Otherwise, check that
2404       // the return expression can be converted to the expected type.
2405       if (results.empty())
2406         resultType = resultExpr->getType();
2407       else if (failed(convertExpressionTo(resultExpr, resultType)))
2408         return failure();
2409       else
2410         retStmt->setResultExpr(resultExpr);
2411     }
2412   }
2413   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2414 }
2415 
2416 FailureOr<ast::VariableDecl *>
2417 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2418                            ArrayRef<ast::ConstraintRef> constraints) {
2419   // The type of the variable, which is expected to be inferred by either a
2420   // constraint or an initializer expression.
2421   ast::Type type;
2422   if (failed(validateVariableConstraints(constraints, type)))
2423     return failure();
2424 
2425   if (initializer) {
2426     // Update the variable type based on the initializer, or try to convert the
2427     // initializer to the existing type.
2428     if (!type)
2429       type = initializer->getType();
2430     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2431       type = mergedType;
2432     else if (failed(convertExpressionTo(initializer, type)))
2433       return failure();
2434 
2435     // Otherwise, if there is no initializer check that the type has already
2436     // been resolved from the constraint list.
2437   } else if (!type) {
2438     return emitErrorAndNote(
2439         loc, "unable to infer type for variable `" + name + "`", loc,
2440         "the type of a variable must be inferable from the constraint "
2441         "list or the initializer");
2442   }
2443 
2444   // Constraint types cannot be used when defining variables.
2445   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2446     return emitError(
2447         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2448   }
2449 
2450   // Try to define a variable with the given name.
2451   FailureOr<ast::VariableDecl *> varDecl =
2452       defineVariableDecl(name, loc, type, initializer, constraints);
2453   if (failed(varDecl))
2454     return failure();
2455 
2456   return *varDecl;
2457 }
2458 
2459 FailureOr<ast::VariableDecl *>
2460 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2461                                       const ast::ConstraintRef &constraint) {
2462   // Constraint arguments may apply more complex constraints via the arguments.
2463   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2464   ast::Type argType;
2465   if (failed(validateVariableConstraint(constraint, argType,
2466                                         allowNonCoreConstraints)))
2467     return failure();
2468   return defineVariableDecl(name, loc, argType, constraint);
2469 }
2470 
2471 LogicalResult
2472 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2473                                     ast::Type &inferredType,
2474                                     bool allowNonCoreConstraints) {
2475   for (const ast::ConstraintRef &ref : constraints)
2476     if (failed(validateVariableConstraint(ref, inferredType,
2477                                           allowNonCoreConstraints)))
2478       return failure();
2479   return success();
2480 }
2481 
2482 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2483                                                  ast::Type &inferredType,
2484                                                  bool allowNonCoreConstraints) {
2485   ast::Type constraintType;
2486   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2487     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2488       if (failed(validateTypeConstraintExpr(typeExpr)))
2489         return failure();
2490     }
2491     constraintType = ast::AttributeType::get(ctx);
2492   } else if (const auto *cst =
2493                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2494     constraintType = ast::OperationType::get(ctx, cst->getName());
2495   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2496     constraintType = typeTy;
2497   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2498     constraintType = typeRangeTy;
2499   } else if (const auto *cst =
2500                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2501     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2502       if (failed(validateTypeConstraintExpr(typeExpr)))
2503         return failure();
2504     }
2505     constraintType = valueTy;
2506   } else if (const auto *cst =
2507                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2508     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2509       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2510         return failure();
2511     }
2512     constraintType = valueRangeTy;
2513   } else if (const auto *cst =
2514                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2515     if (!allowNonCoreConstraints) {
2516       return emitError(ref.referenceLoc,
2517                        "`Rewrite` arguments and results are only permitted to "
2518                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2519                        "`TypeRange`, `Value`, `ValueRange`");
2520     }
2521 
2522     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2523     if (inputs.size() != 1) {
2524       return emitErrorAndNote(ref.referenceLoc,
2525                               "`Constraint`s applied via a variable constraint "
2526                               "list must take a single input, but got " +
2527                                   Twine(inputs.size()),
2528                               cst->getLoc(),
2529                               "see definition of constraint here");
2530     }
2531     constraintType = inputs.front()->getType();
2532   } else {
2533     llvm_unreachable("unknown constraint type");
2534   }
2535 
2536   // Check that the constraint type is compatible with the current inferred
2537   // type.
2538   if (!inferredType) {
2539     inferredType = constraintType;
2540   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2541     inferredType = mergedTy;
2542   } else {
2543     return emitError(ref.referenceLoc,
2544                      llvm::formatv("constraint type `{0}` is incompatible "
2545                                    "with the previously inferred type `{1}`",
2546                                    constraintType, inferredType));
2547   }
2548   return success();
2549 }
2550 
2551 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2552   ast::Type typeExprType = typeExpr->getType();
2553   if (typeExprType != typeTy) {
2554     return emitError(typeExpr->getLoc(),
2555                      "expected expression of `Type` in type constraint");
2556   }
2557   return success();
2558 }
2559 
2560 LogicalResult
2561 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2562   ast::Type typeExprType = typeExpr->getType();
2563   if (typeExprType != typeRangeTy) {
2564     return emitError(typeExpr->getLoc(),
2565                      "expected expression of `TypeRange` in type constraint");
2566   }
2567   return success();
2568 }
2569 
2570 //===----------------------------------------------------------------------===//
2571 // Exprs
2572 
2573 FailureOr<ast::CallExpr *>
2574 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2575                        MutableArrayRef<ast::Expr *> arguments) {
2576   ast::Type parentType = parentExpr->getType();
2577 
2578   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2579   if (!callableDecl) {
2580     return emitError(loc,
2581                      llvm::formatv("expected a reference to a callable "
2582                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2583                                    parentType));
2584   }
2585   if (parserContext == ParserContext::Rewrite) {
2586     if (isa<ast::UserConstraintDecl>(callableDecl))
2587       return emitError(
2588           loc, "unable to invoke `Constraint` within a rewrite section");
2589   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2590     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2591   }
2592 
2593   // Verify the arguments of the call.
2594   /// Handle size mismatch.
2595   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2596   if (callArgs.size() != arguments.size()) {
2597     return emitErrorAndNote(
2598         loc,
2599         llvm::formatv("invalid number of arguments for {0} call; expected "
2600                       "{1}, but got {2}",
2601                       callableDecl->getCallableType(), callArgs.size(),
2602                       arguments.size()),
2603         callableDecl->getLoc(),
2604         llvm::formatv("see the definition of {0} here",
2605                       callableDecl->getName()->getName()));
2606   }
2607 
2608   /// Handle argument type mismatch.
2609   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2610     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2611                                   callableDecl->getName()->getName()),
2612                     callableDecl->getLoc());
2613   };
2614   for (auto it : llvm::zip(callArgs, arguments)) {
2615     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2616                                    attachDiagFn)))
2617       return failure();
2618   }
2619 
2620   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2621                                callableDecl->getResultType());
2622 }
2623 
2624 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2625                                                         ast::Decl *decl) {
2626   // Check the type of decl being referenced.
2627   ast::Type declType;
2628   if (isa<ast::ConstraintDecl>(decl))
2629     declType = ast::ConstraintType::get(ctx);
2630   else if (isa<ast::UserRewriteDecl>(decl))
2631     declType = ast::RewriteType::get(ctx);
2632   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2633     declType = varDecl->getType();
2634   else
2635     return emitError(loc, "invalid reference to `" +
2636                               decl->getName()->getName() + "`");
2637 
2638   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2639 }
2640 
2641 FailureOr<ast::DeclRefExpr *>
2642 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2643                                  ArrayRef<ast::ConstraintRef> constraints) {
2644   FailureOr<ast::VariableDecl *> decl =
2645       defineVariableDecl(name, loc, type, constraints);
2646   if (failed(decl))
2647     return failure();
2648   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2649 }
2650 
2651 FailureOr<ast::MemberAccessExpr *>
2652 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2653                                SMRange loc) {
2654   // Validate the member name for the given parent expression.
2655   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2656   if (failed(memberType))
2657     return failure();
2658 
2659   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2660 }
2661 
2662 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2663                                                   StringRef name, SMRange loc) {
2664   ast::Type parentType = parentExpr->getType();
2665   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2666     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2667       return valueRangeTy;
2668 
2669     // Verify member access based on the operation type.
2670     if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
2671       auto results = odsOp->getResults();
2672 
2673       // Handle indexed results.
2674       unsigned index = 0;
2675       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2676           index < results.size()) {
2677         return results[index].isVariadic() ? valueRangeTy : valueTy;
2678       }
2679 
2680       // Handle named results.
2681       const auto *it = llvm::find_if(results, [&](const auto &result) {
2682         return result.getName() == name;
2683       });
2684       if (it != results.end())
2685         return it->isVariadic() ? valueRangeTy : valueTy;
2686     }
2687 
2688   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2689     // Handle indexed results.
2690     unsigned index = 0;
2691     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2692         index < tupleType.size()) {
2693       return tupleType.getElementTypes()[index];
2694     }
2695 
2696     // Handle named results.
2697     auto elementNames = tupleType.getElementNames();
2698     const auto *it = llvm::find(elementNames, name);
2699     if (it != elementNames.end())
2700       return tupleType.getElementTypes()[it - elementNames.begin()];
2701   }
2702   return emitError(
2703       loc,
2704       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2705                     name, parentType));
2706 }
2707 
2708 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2709     SMRange loc, const ast::OpNameDecl *name,
2710     MutableArrayRef<ast::Expr *> operands,
2711     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2712     MutableArrayRef<ast::Expr *> results) {
2713   Optional<StringRef> opNameRef = name->getName();
2714   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2715 
2716   // Verify the inputs operands.
2717   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2718     return failure();
2719 
2720   // Verify the attribute list.
2721   for (ast::NamedAttributeDecl *attr : attributes) {
2722     // Check for an attribute type, or a type awaiting resolution.
2723     ast::Type attrType = attr->getValue()->getType();
2724     if (!attrType.isa<ast::AttributeType>()) {
2725       return emitError(
2726           attr->getValue()->getLoc(),
2727           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2728     }
2729   }
2730 
2731   // Verify the result types.
2732   if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2733     return failure();
2734 
2735   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2736                                     attributes);
2737 }
2738 
2739 LogicalResult
2740 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2741                                   const ods::Operation *odsOp,
2742                                   MutableArrayRef<ast::Expr *> operands) {
2743   return validateOperationOperandsOrResults(
2744       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2745       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2746       valueRangeTy);
2747 }
2748 
2749 LogicalResult
2750 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2751                                  const ods::Operation *odsOp,
2752                                  MutableArrayRef<ast::Expr *> results) {
2753   return validateOperationOperandsOrResults(
2754       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2755       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2756 }
2757 
2758 LogicalResult Parser::validateOperationOperandsOrResults(
2759     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2760     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2761     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2762     ast::Type rangeTy) {
2763   // All operation types accept a single range parameter.
2764   if (values.size() == 1) {
2765     if (failed(convertExpressionTo(values[0], rangeTy)))
2766       return failure();
2767     return success();
2768   }
2769 
2770   /// If the operation has ODS information, we can more accurately verify the
2771   /// values.
2772   if (odsOpLoc) {
2773     if (odsValues.size() != values.size()) {
2774       return emitErrorAndNote(
2775           loc,
2776           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2777                         "{2}, but got {3}",
2778                         groupName, *name, odsValues.size(), values.size()),
2779           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2780     }
2781     auto diagFn = [&](ast::Diagnostic &diag) {
2782       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2783                       *odsOpLoc);
2784     };
2785     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2786       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2787       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2788         return failure();
2789     }
2790     return success();
2791   }
2792 
2793   // Otherwise, accept the value groups as they have been defined and just
2794   // ensure they are one of the expected types.
2795   for (ast::Expr *&valueExpr : values) {
2796     ast::Type valueExprType = valueExpr->getType();
2797 
2798     // Check if this is one of the expected types.
2799     if (valueExprType == rangeTy || valueExprType == singleTy)
2800       continue;
2801 
2802     // If the operand is an Operation, allow converting to a Value or
2803     // ValueRange. This situations arises quite often with nested operation
2804     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2805     if (singleTy == valueTy) {
2806       if (valueExprType.isa<ast::OperationType>()) {
2807         valueExpr = convertOpToValue(valueExpr);
2808         continue;
2809       }
2810     }
2811 
2812     return emitError(
2813         valueExpr->getLoc(),
2814         llvm::formatv(
2815             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2816             singleTy, rangeTy, valueExprType));
2817   }
2818   return success();
2819 }
2820 
2821 FailureOr<ast::TupleExpr *>
2822 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2823                         ArrayRef<StringRef> elementNames) {
2824   for (const ast::Expr *element : elements) {
2825     ast::Type eleTy = element->getType();
2826     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2827       return emitError(
2828           element->getLoc(),
2829           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2830     }
2831   }
2832   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2833 }
2834 
2835 //===----------------------------------------------------------------------===//
2836 // Stmts
2837 
2838 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2839                                                     ast::Expr *rootOp) {
2840   // Check that root is an Operation.
2841   ast::Type rootType = rootOp->getType();
2842   if (!rootType.isa<ast::OperationType>())
2843     return emitError(rootOp->getLoc(), "expected `Op` expression");
2844 
2845   return ast::EraseStmt::create(ctx, loc, rootOp);
2846 }
2847 
2848 FailureOr<ast::ReplaceStmt *>
2849 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2850                           MutableArrayRef<ast::Expr *> replValues) {
2851   // Check that root is an Operation.
2852   ast::Type rootType = rootOp->getType();
2853   if (!rootType.isa<ast::OperationType>()) {
2854     return emitError(
2855         rootOp->getLoc(),
2856         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2857   }
2858 
2859   // If there are multiple replacement values, we implicitly convert any Op
2860   // expressions to the value form.
2861   bool shouldConvertOpToValues = replValues.size() > 1;
2862   for (ast::Expr *&replExpr : replValues) {
2863     ast::Type replType = replExpr->getType();
2864 
2865     // Check that replExpr is an Operation, Value, or ValueRange.
2866     if (replType.isa<ast::OperationType>()) {
2867       if (shouldConvertOpToValues)
2868         replExpr = convertOpToValue(replExpr);
2869       continue;
2870     }
2871 
2872     if (replType != valueTy && replType != valueRangeTy) {
2873       return emitError(replExpr->getLoc(),
2874                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2875                                      "expression, but got `{0}`",
2876                                      replType));
2877     }
2878   }
2879 
2880   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2881 }
2882 
2883 FailureOr<ast::RewriteStmt *>
2884 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2885                           ast::CompoundStmt *rewriteBody) {
2886   // Check that root is an Operation.
2887   ast::Type rootType = rootOp->getType();
2888   if (!rootType.isa<ast::OperationType>()) {
2889     return emitError(
2890         rootOp->getLoc(),
2891         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2892   }
2893 
2894   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2895 }
2896 
2897 //===----------------------------------------------------------------------===//
2898 // Code Completion
2899 //===----------------------------------------------------------------------===//
2900 
2901 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
2902   ast::Type parentType = parentExpr->getType();
2903   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
2904     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
2905   else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
2906     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
2907   return failure();
2908 }
2909 
2910 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
2911   if (opName)
2912     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
2913   return failure();
2914 }
2915 
2916 LogicalResult
2917 Parser::codeCompleteConstraintName(ast::Type inferredType,
2918                                    bool allowNonCoreConstraints,
2919                                    bool allowInlineTypeConstraints) {
2920   codeCompleteContext->codeCompleteConstraintName(
2921       inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
2922       curDeclScope);
2923   return failure();
2924 }
2925 
2926 LogicalResult Parser::codeCompleteDialectName() {
2927   codeCompleteContext->codeCompleteDialectName();
2928   return failure();
2929 }
2930 
2931 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
2932   codeCompleteContext->codeCompleteOperationName(dialectName);
2933   return failure();
2934 }
2935 
2936 LogicalResult Parser::codeCompletePatternMetadata() {
2937   codeCompleteContext->codeCompletePatternMetadata();
2938   return failure();
2939 }
2940 
2941 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
2942   codeCompleteContext->codeCompleteIncludeFilename(curPath);
2943   return failure();
2944 }
2945 
2946 void Parser::codeCompleteCallSignature(ast::Node *parent,
2947                                        unsigned currentNumArgs) {
2948   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
2949   if (!callableDecl)
2950     return;
2951 
2952   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
2953 }
2954 
2955 void Parser::codeCompleteOperationOperandsSignature(
2956     Optional<StringRef> opName, unsigned currentNumOperands) {
2957   codeCompleteContext->codeCompleteOperationOperandsSignature(
2958       opName, currentNumOperands);
2959 }
2960 
2961 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
2962                                                    unsigned currentNumResults) {
2963   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
2964                                                              currentNumResults);
2965 }
2966 
2967 //===----------------------------------------------------------------------===//
2968 // Parser
2969 //===----------------------------------------------------------------------===//
2970 
2971 FailureOr<ast::Module *>
2972 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
2973                         CodeCompleteContext *codeCompleteContext) {
2974   Parser parser(ctx, sourceMgr, codeCompleteContext);
2975   return parser.parseModule();
2976 }
2977