1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <string>
34 
35 using namespace mlir;
36 using namespace mlir::pdll;
37 
38 //===----------------------------------------------------------------------===//
39 // Parser
40 //===----------------------------------------------------------------------===//
41 
namespace {
/// Parser for PDLL source text that produces an `ast::Module` (see
/// `parseModule`). Besides building the AST, it performs semantic validation
/// while parsing, e.g. of variable constraints, expression conversions, and
/// operation operands/results.
class Parser {
public:
  /// Note: `curToken` is initialized by lexing from `lexer`, so `lexer` must be
  /// declared before `curToken` in the field list below. `codeCompleteContext`
  /// is optional and may be null.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
         CodeCompleteContext *codeCompleteContext)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
        curToken(lexer.lexToken()), valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)),
        attrTy(ast::AttributeType::get(ctx)),
        codeCompleteContext(codeCompleteContext) {}

  /// Try to parse a new module. Returns failure if the module could not be
  /// parsed.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within a Constraint, which disallows all types
    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
    Constraint,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  /// The current specification context of an operation's result types. This
  /// indicates how the result types of an operation may be inferred.
  enum class OpResultTypeContext {
    /// The result types of the operation are not known to be inferred.
    Explicit,
    /// The result types of the operation are inferred from the root input of a
    /// `replace` statement.
    Replacement,
    /// The result types of the operation are inferred by using the
    /// `InferTypeOpInterface` interface provided by the operation.
    Interface,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Push a new decl scope, nested within the current one, and make it the
  /// current scope. Returns the newly created scope.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make `scope` the current decl scope without allocating a new one.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent the current scope.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module.
  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a ValueRange typed
  /// expression that accesses all of the operation's results.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  /// Lookup ODS information for the given operation name. Returns nullptr if
  /// no name is provided or no information is found.
  const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
    return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
  }

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token. Only `#include` is currently
  /// supported.
  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse an `#include` directive. Included `.pdll` files are parsed into the
  /// current module.
  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse an included tablegen file `filename`, referenced at `fileLoc`.
  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Process the records of a parsed tablegen include file.
  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                               SmallVectorImpl<ast::Decl *> &decls);

  /// Create a user defined native constraint for a constraint imported from
  /// ODS.
  template <typename ConstraintT>
  ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
                                               StringRef codeBlock, SMRange loc,
                                               ast::Type type);
  template <typename ConstraintT>
  ast::Decl *
  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
                                    SMRange loc, ast::Type type);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    Optional<uint16_t> benefit;
    bool hasBoundedRecursion = false;
  };

  FailureOr<ast::Decl *> parseTopLevelDecl();
  FailureOr<ast::NamedAttributeDecl *>
  parseNamedAttributeDecl(Optional<StringRef> parentOpName);

  /// Parse an argument variable as part of the signature of a
  /// UserConstraintDecl or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseArgumentDecl();

  /// Parse a result variable as part of the signature of a UserConstraintDecl
  /// or UserRewriteDecl.
  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);

  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
  /// defined in a non-global context.
  FailureOr<ast::UserConstraintDecl *>
  parseUserConstraintDecl(bool isInline = false);

  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Constraint/etc.
  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();

  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
  /// using PDLL constructs.
  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
  /// defined in a non-global context.
  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);

  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
  /// non-global context, such as within a Pattern/Rewrite/etc.
  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();

  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
  /// PDLL constructs.
  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
  /// effectively the same syntax, and only differ on slight semantics (given
  /// the different parsing contexts).
  template <typename T, typename ParseUserPDLLDeclFnT>
  FailureOr<T *> parseUserConstraintOrRewriteDecl(
      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
      StringRef anonymousNamePrefix, bool isInline);

  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
  /// These decls have effectively the same syntax.
  template <typename T>
  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
      const ast::Name &name, bool isInline,
      ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);

  /// Parse the functional signature (i.e. the arguments and results) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult parseUserConstraintOrRewriteSignature(
      SmallVectorImpl<ast::VariableDecl *> &arguments,
      SmallVectorImpl<ast::VariableDecl *> &results,
      ast::DeclScope *&argumentScope, ast::Type &resultType);

  /// Validate the return (which, if present, is specified by `bodyIt`) of a
  /// UserConstraintDecl or UserRewriteDecl.
  LogicalResult validateUserConstraintOrRewriteReturn(
      StringRef declType, ast::CompoundStmt *body,
      ArrayRef<ast::Stmt *>::iterator bodyIt,
      ArrayRef<ast::Stmt *>::iterator bodyE,
      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);

  FailureOr<ast::CompoundStmt *>
  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                  bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
  FailureOr<ast::Decl *> parsePatternDecl();
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name. If
  /// one has, emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
  /// constraints) may be used with the variable.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints,
                  bool allowInlineTypeConstraints,
                  bool allowNonCoreConstraints);

  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
  /// argument or result variable. The constraints for these variables do not
  /// allow inline type constraints, and only permit a single constraint.
  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *>
  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
                         OpResultTypeContext::Explicit);
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::ReturnStmt *> parseReturnStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to extract a callable from the given AST node. Returns nullptr on
  /// failure.
  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
  /// of results, defined as part of the signature.
  ast::Type
  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);

  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
  template <typename T>
  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
      ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Create a variable for an argument or result defined as part of the
  /// signature of a UserConstraintDecl/UserRewriteDecl.
  FailureOr<ast::VariableDecl *>
  createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                const ast::ConstraintRef &constraint);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType,
                              bool allowNonCoreConstraints = true);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
  /// defined constraints) may be used with the variable.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType,
                                           bool allowNonCoreConstraints = true);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::CallExpr *>
  createCallExpr(SMRange loc, ast::Expr *parentExpr,
                 MutableArrayRef<ast::Expr *> arguments);
  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                      OpResultTypeContext resultTypeContext,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(SMRange loc, Optional<StringRef> name,
                            const ods::Operation *odsOp,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                         const ods::Operation *odsOp,
                                         MutableArrayRef<ast::Expr *> results);
  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
                                          const ods::Operation *odsOp);
  LogicalResult validateOperationOperandsOrResults(
      StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
      Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
      ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
      ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// The set of various code completion methods. Every completion method
  /// returns `failure` to stop the parsing process after providing completion
  /// results.

  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
  LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
                                           bool allowNonCoreConstraints,
                                           bool allowInlineTypeConstraints);
  LogicalResult codeCompleteDialectName();
  LogicalResult codeCompleteOperationName(StringRef dialectName);
  LogicalResult codeCompletePatternMetadata();
  LogicalResult codeCompleteIncludeFilename(StringRef curPath);

  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
  void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
                                              unsigned currentNumOperands);
  void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
                                             unsigned currentNumResults);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of `tokLoc` and re-lex the current token
  /// from there.
  void resetToken(SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error at `loc` through the lexer. Always returns failure so that
  /// callers can write `return emitError(...)`.
  LogicalResult emitError(SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error at the location of the current token. Always returns
  /// failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error at `loc` with an attached note at `noteLoc`. Always returns
  /// failure.
  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The current (innermost) decl scope.
  ast::DeclScope *curDeclScope = nullptr;
  /// Allocator providing storage for the scopes created by pushDeclScope().
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;
  ast::Type attrTy;

  /// A counter used when naming anonymous constraints and rewrites.
  unsigned anonymousDeclNameCounter = 0;

  /// The optional code completion context; may be null.
  CodeCompleteContext *codeCompleteContext;
};
} // namespace
540 
541 FailureOr<ast::Module *> Parser::parseModule() {
542   SMLoc moduleLoc = curToken.getStartLoc();
543   pushDeclScope();
544 
545   // Parse the top-level decls of the module.
546   SmallVector<ast::Decl *> decls;
547   if (failed(parseModuleBody(decls)))
548     return popDeclScope(), failure();
549 
550   popDeclScope();
551   return ast::Module::create(ctx, moduleLoc, decls);
552 }
553 
554 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
555   while (curToken.isNot(Token::eof)) {
556     if (curToken.is(Token::directive)) {
557       if (failed(parseDirective(decls)))
558         return failure();
559       continue;
560     }
561 
562     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
563     if (failed(decl))
564       return failure();
565     decls.push_back(*decl);
566   }
567   return success();
568 }
569 
570 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
571   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
572                                                  valueRangeTy);
573 }
574 
/// Try to convert `expr` to `type`, possibly replacing `expr` with a new
/// expression that performs the conversion. The accepted conversions are:
///  * identical types (no change);
///  * Operation -> Operation, only when the target type carries no op name;
///  * Operation -> ValueRange, via an all-results member access;
///  * Operation -> Value, via an all-results member access typed as Value,
///    rejected when ODS shows the op has zero results or more than one
///    single (non-variable-length) result;
///  * Value <-> ValueRange and Type <-> TypeRange (accepted as-is);
///  * Tuple -> Tuple of the same arity, converting each element recursively.
/// Any other combination emits an "unable to convert" error. `noteAttachFn`,
/// if provided, is invoked on that error to attach caller-specific notes.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type exprType = expr->getType();
  if (exprType == type)
    return success();

  // Emit the generic conversion error, letting the caller attach extra notes.
  // The diagnostic is returned so call sites can attach further notes.
  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      exprType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
    // Two operation types are compatible if they have the same name, or if the
    // expected type is more general.
    // NOTE(review): equal types already returned success above, so a named
    // target type reaching this point is treated as a mismatch. This relies on
    // operation types with the same name comparing equal — confirm.
    if (auto opType = type.dyn_cast<ast::OperationType>()) {
      if (opType.getName())
        return emitConvertError();
      return success();
    }

    // An operation can always convert to a ValueRange.
    if (type == valueRangeTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueRangeTy);
      return success();
    }

    // Allow conversion to a single value by constraining the result range.
    if (type == valueTy) {
      // If the operation is registered, we can verify if it can ever have a
      // single result.
      Optional<StringRef> opName = exprOpType.getName();
      if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
        if (odsOp->getResults().empty()) {
          return emitConvertError()->attachNote(
              llvm::formatv("see the definition of `{0}`, which was defined "
                            "with zero results",
                            odsOp->getName()),
              odsOp->getLoc());
        }

        // Variable-length results may be empty, but every `Single` result is
        // always present, so more than one of them rules out a single value.
        unsigned numSingleResults = llvm::count_if(
            odsOp->getResults(), [](const ods::OperandOrResult &result) {
              return result.getVariableLengthKind() ==
                     ods::VariableLengthKind::Single;
            });
        if (numSingleResults > 1) {
          return emitConvertError()->attachNote(
              llvm::formatv("see the definition of `{0}`, which was defined "
                            "with at least {1} results",
                            odsOp->getName(), numSingleResults),
              odsOp->getLoc());
        }
      }

      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueTy);
      return success();
    }
    return emitConvertError();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  // Note that `expr` is left unchanged in these cases.
  if ((exprType == valueTy || exprType == valueRangeTy) &&
      (type == valueTy || type == valueRangeTy))
    return success();
  if ((exprType == typeTy || exprType == typeRangeTy) &&
      (type == typeTy || type == typeRangeTy))
    return success();

  // Handle tuple types.
  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
    auto tupleType = type.dyn_cast<ast::TupleType>();
    if (!tupleType || tupleType.size() != exprTupleType.size())
      return emitConvertError();

    // Build a new tuple expression using each of the elements of the current
    // tuple.
    SmallVector<ast::Expr *> newExprs;
    for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
      // Access element `i` of the source tuple by its positional name.
      newExprs.push_back(ast::MemberAccessExpr::create(
          ctx, expr->getLoc(), expr, llvm::to_string(i),
          exprTupleType.getElementTypes()[i]));

      // Capturing `i` by reference is safe: `diagFn` is only invoked during
      // the synchronous recursive call below.
      auto diagFn = [&](ast::Diagnostic &diag) {
        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                      i, exprTupleType));
        if (noteAttachFn)
          noteAttachFn(diag);
      };
      // The recursive call may replace the element expression in place.
      if (failed(convertExpressionTo(newExprs.back(),
                                     tupleType.getElementTypes()[i], diagFn)))
        return failure();
    }
    // The rebuilt tuple takes its element names from the target tuple type.
    expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
                                  tupleType.getElementNames());
    return success();
  }

  return emitConvertError();
}
686 
687 //===----------------------------------------------------------------------===//
688 // Directives
689 
690 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
691   StringRef directive = curToken.getSpelling();
692   if (directive == "#include")
693     return parseInclude(decls);
694 
695   return emitError("unknown directive `" + directive + "`");
696 }
697 
698 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
699   SMRange loc = curToken.getLoc();
700   consumeToken(Token::directive);
701 
702   // Handle code completion of the include file path.
703   if (curToken.is(Token::code_complete_string))
704     return codeCompleteIncludeFilename(curToken.getStringValue());
705 
706   // Parse the file being included.
707   if (!curToken.isString())
708     return emitError(loc,
709                      "expected string file name after `include` directive");
710   SMRange fileLoc = curToken.getLoc();
711   std::string filenameStr = curToken.getStringValue();
712   StringRef filename = filenameStr;
713   consumeToken();
714 
715   // Check the type of include. If ending with `.pdll`, this is another pdl file
716   // to be parsed along with the current module.
717   if (filename.endswith(".pdll")) {
718     if (failed(lexer.pushInclude(filename, fileLoc)))
719       return emitError(fileLoc,
720                        "unable to open include file `" + filename + "`");
721 
722     // If we added the include successfully, parse it into the current module.
723     // Make sure to update to the next token after we finish parsing the nested
724     // file.
725     curToken = lexer.lexToken();
726     LogicalResult result = parseModuleBody(decls);
727     curToken = lexer.lexToken();
728     return result;
729   }
730 
731   // Otherwise, this must be a `.td` include.
732   if (filename.endswith(".td"))
733     return parseTdInclude(filename, fileLoc, decls);
734 
735   return emitError(fileLoc,
736                    "expected include filename to end with `.pdll` or `.td`");
737 }
738 
739 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
740                                      SmallVectorImpl<ast::Decl *> &decls) {
741   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
742 
743   // Use the source manager to open the file, but don't yet add it.
744   std::string includedFile;
745   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
746       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
747   if (!includeBuffer)
748     return emitError(fileLoc, "unable to open include file `" + filename + "`");
749 
750   // Setup the source manager for parsing the tablegen file.
751   llvm::SourceMgr tdSrcMgr;
752   tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
753   tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
754 
755   // This class provides a context argument for the llvm::SourceMgr diagnostic
756   // handler.
757   struct DiagHandlerContext {
758     Parser &parser;
759     StringRef filename;
760     llvm::SMRange loc;
761   } handlerContext{*this, filename, fileLoc};
762 
763   // Set the diagnostic handler for the tablegen source manager.
764   tdSrcMgr.setDiagHandler(
765       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
766         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
767         (void)ctx->parser.emitError(
768             ctx->loc,
769             llvm::formatv("error while processing include file `{0}`: {1}",
770                           ctx->filename, diag.getMessage()));
771       },
772       &handlerContext);
773 
774   // Parse the tablegen file.
775   llvm::RecordKeeper tdRecords;
776   if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
777     return failure();
778 
779   // Process the parsed records.
780   processTdIncludeRecords(tdRecords, decls);
781 
782   // After we are done processing, move all of the tablegen source buffers to
783   // the main parser source mgr. This allows for directly using source locations
784   // from the .td files without needing to remap them.
785   parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
786   return success();
787 }
788 
/// Convert the records of a parsed TableGen file into ODS operation
/// information (stored in the ODS context) and native PDLL constraint decls
/// (added to the current scope and appended to `decls`).
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                                     SmallVectorImpl<ast::Decl *> &decls) {
  // Return the length kind of the given value: `Optional` takes precedence,
  // then `Variadic`, otherwise `Single`.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
      return ods::VariableLengthKind::Optional;
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
                              : ods::VariableLengthKind::Single;
  };

  // Insert a type constraint into the ODS context, keyed by the unique def
  // name of the constraint.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
                                           cst.constraint.getSummary(),
                                           cst.constraint.getCPPClassName());
  };
  // Expand a single TableGen record location into a one-character range.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  // Operations: register every `Op` def with the ODS context.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support result type
    // inference, i.e. whether it carries the InferTypeOpInterface trait.
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    bool inserted = false;
    ods::Operation *odsOp = nullptr;
    std::tie(odsOp, inserted) = odsContext.insertOperation(
        op.getOperationName(), op.getSummary(), op.getDescription(),
        supportsResultTypeInferrence, op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    // Record the attributes, operands, and results of the newly inserted
    // operation, in the order tblgen::Operator reports them.
    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(
          attr.name, attr.attr.isOptional(),
          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
                                               attr.attr.getSummary(),
                                               attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }
  // Attr constraints: expose each named `Attr` def as a native attribute
  // constraint, unless its name is already visible in the current scope.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              tblgen::AttrConstraint(def),
              convertLocToRange(def->getLoc().front()), attrTy));
    }
  }
  // Type constraints: same as above, for named `Type` defs.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              tblgen::TypeConstraint(def),
              convertLocToRange(def->getLoc().front()), typeTy));
    }
  }
  // Interfaces: expose each op/attr/type interface as a native constraint
  // whose body checks `isa<Namespace::Class>(self)`.
  ast::Type opTy = ast::OperationType::get(ctx);
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
    StringRef name = def->getName();
    // Skip anonymous defs, names already visible in scope, and
    // `DeclareInterfaceMethods` records.
    if (def->isAnonymous() || curDeclScope->lookup(name) ||
        def->isSubClassOf("DeclareInterfaceMethods"))
      continue;
    SMRange loc = convertLocToRange(def->getLoc().front());

    StringRef className = def->getValueAsString("cppClassName");
    StringRef cppNamespace = def->getValueAsString("cppNamespace");
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));",
                      cppNamespace, className)
            .str();

    // The interface's TableGen class selects the kind of constraint (and the
    // type of `self`); interfaces of any other kind produce no decl.
    if (def->isSubClassOf("OpInterface")) {
      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
          name, codeBlock, loc, opTy));
    } else if (def->isSubClassOf("AttrInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
              name, codeBlock, loc, attrTy));
    } else if (def->isSubClassOf("TypeInterface")) {
      decls.push_back(
          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
              name, codeBlock, loc, typeTy));
    }
  }
}
894 
895 template <typename ConstraintT>
896 ast::Decl *
897 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
898                                           SMRange loc, ast::Type type) {
899   // Build the single input parameter.
900   ast::DeclScope *argScope = pushDeclScope();
901   auto *paramVar = ast::VariableDecl::create(
902       ctx, ast::Name::create(ctx, "self", loc), type,
903       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
904   argScope->add(paramVar);
905   popDeclScope();
906 
907   // Build the native constraint.
908   auto *constraintDecl = ast::UserConstraintDecl::createNative(
909       ctx, ast::Name::create(ctx, name, loc), paramVar,
910       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
911   curDeclScope->add(constraintDecl);
912   return constraintDecl;
913 }
914 
915 template <typename ConstraintT>
916 ast::Decl *
917 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
918                                           SMRange loc, ast::Type type) {
919   // Format the condition template.
920   tblgen::FmtContext fmtContext;
921   fmtContext.withSelf("self");
922   std::string codeBlock = tblgen::tgfmt(
923       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
924       &fmtContext);
925 
926   return createODSNativePDLLConstraintDecl<ConstraintT>(
927       constraint.getUniqueDefName(), codeBlock, loc, type);
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // Decls
932 
933 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
934   FailureOr<ast::Decl *> decl;
935   switch (curToken.getKind()) {
936   case Token::kw_Constraint:
937     decl = parseUserConstraintDecl();
938     break;
939   case Token::kw_Pattern:
940     decl = parsePatternDecl();
941     break;
942   case Token::kw_Rewrite:
943     decl = parseUserRewriteDecl();
944     break;
945   default:
946     return emitError("expected top-level declaration, such as a `Pattern`");
947   }
948   if (failed(decl))
949     return failure();
950 
951   // If the decl has a name, add it to the current scope.
952   if (const ast::Name *name = (*decl)->getName()) {
953     if (failed(checkDefineNamedDecl(*name)))
954       return failure();
955     curDeclScope->add(*decl);
956   }
957   return decl;
958 }
959 
960 FailureOr<ast::NamedAttributeDecl *>
961 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
962   // Check for name code completion.
963   if (curToken.is(Token::code_complete))
964     return codeCompleteAttributeName(parentOpName);
965 
966   std::string attrNameStr;
967   if (curToken.isString())
968     attrNameStr = curToken.getStringValue();
969   else if (curToken.is(Token::identifier) || curToken.isKeyword())
970     attrNameStr = curToken.getSpelling().str();
971   else
972     return emitError("expected identifier or string attribute name");
973   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
974   consumeToken();
975 
976   // Check for a value of the attribute.
977   ast::Expr *attrValue = nullptr;
978   if (consumeIf(Token::equal)) {
979     FailureOr<ast::Expr *> attrExpr = parseExpr();
980     if (failed(attrExpr))
981       return failure();
982     attrValue = *attrExpr;
983   } else {
984     // If there isn't a concrete value, create an expression representing a
985     // UnitAttr.
986     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
987   }
988 
989   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
990 }
991 
992 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
993     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
994     bool expectTerminalSemicolon) {
995   consumeToken(Token::equal_arrow);
996 
997   // Parse the single statement of the lambda body.
998   SMLoc bodyStartLoc = curToken.getStartLoc();
999   pushDeclScope();
1000   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1001   bool failedToParse =
1002       failed(singleStatement) || failed(processStatementFn(*singleStatement));
1003   popDeclScope();
1004   if (failedToParse)
1005     return failure();
1006 
1007   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1008   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1009 }
1010 
1011 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1012   // Ensure that the argument is named.
1013   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1014     return emitError("expected identifier argument name");
1015 
1016   // Parse the argument similarly to a normal variable.
1017   StringRef name = curToken.getSpelling();
1018   SMRange nameLoc = curToken.getLoc();
1019   consumeToken();
1020 
1021   if (failed(
1022           parseToken(Token::colon, "expected `:` before argument constraint")))
1023     return failure();
1024 
1025   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1026   if (failed(cst))
1027     return failure();
1028 
1029   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1030 }
1031 
1032 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1033   // Check to see if this result is named.
1034   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1035     // Check to see if this name actually refers to a Constraint.
1036     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1037     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1038       // If yes, and this is a Rewrite, give a nice error message as non-Core
1039       // constraints are not supported on Rewrite results.
1040       if (parserContext == ParserContext::Rewrite) {
1041         return emitError(
1042             "`Rewrite` results are only permitted to use core constraints, "
1043             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1044       }
1045 
1046       // Otherwise, parse this as an unnamed result variable.
1047     } else {
1048       // If it wasn't a constraint, parse the result similarly to a variable. If
1049       // there is already an existing decl, we will emit an error when defining
1050       // this variable later.
1051       StringRef name = curToken.getSpelling();
1052       SMRange nameLoc = curToken.getLoc();
1053       consumeToken();
1054 
1055       if (failed(parseToken(Token::colon,
1056                             "expected `:` before result constraint")))
1057         return failure();
1058 
1059       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1060       if (failed(cst))
1061         return failure();
1062 
1063       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1064     }
1065   }
1066 
1067   // If it isn't named, we parse the constraint directly and create an unnamed
1068   // result variable.
1069   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1070   if (failed(cst))
1071     return failure();
1072 
1073   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1074 }
1075 
1076 FailureOr<ast::UserConstraintDecl *>
1077 Parser::parseUserConstraintDecl(bool isInline) {
1078   // Constraints and rewrites have very similar formats, dispatch to a shared
1079   // interface for parsing.
1080   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1081       [&](auto &&...args) {
1082         return this->parseUserPDLLConstraintDecl(args...);
1083       },
1084       ParserContext::Constraint, "constraint", isInline);
1085 }
1086 
1087 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1088   FailureOr<ast::UserConstraintDecl *> decl =
1089       parseUserConstraintDecl(/*isInline=*/true);
1090   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1091     return failure();
1092 
1093   curDeclScope->add(*decl);
1094   return decl;
1095 }
1096 
1097 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1098     const ast::Name &name, bool isInline,
1099     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1100     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1101   // Push the argument scope back onto the list, so that the body can
1102   // reference arguments.
1103   pushDeclScope(argumentScope);
1104 
1105   // Parse the body of the constraint. The body is either defined as a compound
1106   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1107   ast::CompoundStmt *body;
1108   if (curToken.is(Token::equal_arrow)) {
1109     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1110         [&](ast::Stmt *&stmt) -> LogicalResult {
1111           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1112           if (!stmtExpr) {
1113             return emitError(stmt->getLoc(),
1114                              "expected `Constraint` lambda body to contain a "
1115                              "single expression");
1116           }
1117           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1118           return success();
1119         },
1120         /*expectTerminalSemicolon=*/!isInline);
1121     if (failed(bodyResult))
1122       return failure();
1123     body = *bodyResult;
1124   } else {
1125     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1126     if (failed(bodyResult))
1127       return failure();
1128     body = *bodyResult;
1129 
1130     // Verify the structure of the body.
1131     auto bodyIt = body->begin(), bodyE = body->end();
1132     for (; bodyIt != bodyE; ++bodyIt)
1133       if (isa<ast::ReturnStmt>(*bodyIt))
1134         break;
1135     if (failed(validateUserConstraintOrRewriteReturn(
1136             "Constraint", body, bodyIt, bodyE, results, resultType)))
1137       return failure();
1138   }
1139   popDeclScope();
1140 
1141   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1142       name, arguments, results, resultType, body);
1143 }
1144 
1145 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1146   // Constraints and rewrites have very similar formats, dispatch to a shared
1147   // interface for parsing.
1148   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1149       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1150       ParserContext::Rewrite, "rewrite", isInline);
1151 }
1152 
1153 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1154   FailureOr<ast::UserRewriteDecl *> decl =
1155       parseUserRewriteDecl(/*isInline=*/true);
1156   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1157     return failure();
1158 
1159   curDeclScope->add(*decl);
1160   return decl;
1161 }
1162 
1163 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1164     const ast::Name &name, bool isInline,
1165     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1166     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1167   // Push the argument scope back onto the list, so that the body can
1168   // reference arguments.
1169   curDeclScope = argumentScope;
1170   ast::CompoundStmt *body;
1171   if (curToken.is(Token::equal_arrow)) {
1172     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1173         [&](ast::Stmt *&statement) -> LogicalResult {
1174           if (isa<ast::OpRewriteStmt>(statement))
1175             return success();
1176 
1177           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1178           if (!statementExpr) {
1179             return emitError(
1180                 statement->getLoc(),
1181                 "expected `Rewrite` lambda body to contain a single expression "
1182                 "or an operation rewrite statement; such as `erase`, "
1183                 "`replace`, or `rewrite`");
1184           }
1185           statement =
1186               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1187           return success();
1188         },
1189         /*expectTerminalSemicolon=*/!isInline);
1190     if (failed(bodyResult))
1191       return failure();
1192     body = *bodyResult;
1193   } else {
1194     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1195     if (failed(bodyResult))
1196       return failure();
1197     body = *bodyResult;
1198   }
1199   popDeclScope();
1200 
1201   // Verify the structure of the body.
1202   auto bodyIt = body->begin(), bodyE = body->end();
1203   for (; bodyIt != bodyE; ++bodyIt)
1204     if (isa<ast::ReturnStmt>(*bodyIt))
1205       break;
1206   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1207                                                    bodyE, results, resultType)))
1208     return failure();
1209   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1210       name, arguments, results, resultType, body);
1211 }
1212 
1213 template <typename T, typename ParseUserPDLLDeclFnT>
1214 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1215     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1216     StringRef anonymousNamePrefix, bool isInline) {
1217   SMRange loc = curToken.getLoc();
1218   consumeToken();
1219   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1220 
1221   // Parse the name of the decl.
1222   const ast::Name *name = nullptr;
1223   if (curToken.isNot(Token::identifier)) {
1224     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1225     // in C++, so being unnamed is fine.
1226     if (!isInline)
1227       return emitError("expected identifier name");
1228 
1229     // Create a unique anonymous name to use, as the name for this decl is not
1230     // important.
1231     std::string anonName =
1232         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1233                       anonymousDeclNameCounter++)
1234             .str();
1235     name = &ast::Name::create(ctx, anonName, loc);
1236   } else {
1237     // If a name was provided, we can use it directly.
1238     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1239     consumeToken(Token::identifier);
1240   }
1241 
1242   // Parse the functional signature of the decl.
1243   SmallVector<ast::VariableDecl *> arguments, results;
1244   ast::DeclScope *argumentScope;
1245   ast::Type resultType;
1246   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1247                                                    argumentScope, resultType)))
1248     return failure();
1249 
1250   // Check to see which type of constraint this is. If the constraint contains a
1251   // compound body, this is a PDLL decl.
1252   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1253     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1254                            resultType);
1255 
1256   // Otherwise, this is a native decl.
1257   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1258                                                    results, resultType);
1259 }
1260 
1261 template <typename T>
1262 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1263     const ast::Name &name, bool isInline,
1264     ArrayRef<ast::VariableDecl *> arguments,
1265     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1266   // If followed by a string, the native code body has also been specified.
1267   std::string codeStrStorage;
1268   Optional<StringRef> optCodeStr;
1269   if (curToken.isString()) {
1270     codeStrStorage = curToken.getStringValue();
1271     optCodeStr = codeStrStorage;
1272     consumeToken();
1273   } else if (isInline) {
1274     return emitError(name.getLoc(),
1275                      "external declarations must be declared in global scope");
1276   } else if (curToken.is(Token::error)) {
1277     return failure();
1278   }
1279   if (failed(parseToken(Token::semicolon,
1280                         "expected `;` after native declaration")))
1281     return failure();
1282   // TODO: PDL should be able to support constraint results in certain
1283   // situations, we should revise this.
1284   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1285     return emitError(
1286         "native Constraints currently do not support returning results");
1287   }
1288   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1289 }
1290 
1291 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1292     SmallVectorImpl<ast::VariableDecl *> &arguments,
1293     SmallVectorImpl<ast::VariableDecl *> &results,
1294     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1295   // Parse the argument list of the decl.
1296   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1297     return failure();
1298 
1299   argumentScope = pushDeclScope();
1300   if (curToken.isNot(Token::r_paren)) {
1301     do {
1302       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1303       if (failed(argument))
1304         return failure();
1305       arguments.emplace_back(*argument);
1306     } while (consumeIf(Token::comma));
1307   }
1308   popDeclScope();
1309   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1310     return failure();
1311 
1312   // Parse the results of the decl.
1313   pushDeclScope();
1314   if (consumeIf(Token::arrow)) {
1315     auto parseResultFn = [&]() -> LogicalResult {
1316       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1317       if (failed(result))
1318         return failure();
1319       results.emplace_back(*result);
1320       return success();
1321     };
1322 
1323     // Check for a list of results.
1324     if (consumeIf(Token::l_paren)) {
1325       do {
1326         if (failed(parseResultFn()))
1327           return failure();
1328       } while (consumeIf(Token::comma));
1329       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1330         return failure();
1331 
1332       // Otherwise, there is only one result.
1333     } else if (failed(parseResultFn())) {
1334       return failure();
1335     }
1336   }
1337   popDeclScope();
1338 
1339   // Compute the result type of the decl.
1340   resultType = createUserConstraintRewriteResultType(results);
1341 
1342   // Verify that results are only named if there are more than one.
1343   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1344     return emitError(
1345         results.front()->getLoc(),
1346         "cannot create a single-element tuple with an element label");
1347   }
1348   return success();
1349 }
1350 
1351 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1352     StringRef declType, ast::CompoundStmt *body,
1353     ArrayRef<ast::Stmt *>::iterator bodyIt,
1354     ArrayRef<ast::Stmt *>::iterator bodyE,
1355     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1356   // Handle if a `return` was provided.
1357   if (bodyIt != bodyE) {
1358     // Emit an error if we have trailing statements after the return.
1359     if (std::next(bodyIt) != bodyE) {
1360       return emitError(
1361           (*std::next(bodyIt))->getLoc(),
1362           llvm::formatv("`return` terminated the `{0}` body, but found "
1363                         "trailing statements afterwards",
1364                         declType));
1365     }
1366 
1367     // Otherwise if a return wasn't provided, check that no results are
1368     // expected.
1369   } else if (!results.empty()) {
1370     return emitError(
1371         {body->getLoc().End, body->getLoc().End},
1372         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1373                       declType, resultType));
1374   }
1375   return success();
1376 }
1377 
1378 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1379   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1380     if (isa<ast::OpRewriteStmt>(statement))
1381       return success();
1382     return emitError(
1383         statement->getLoc(),
1384         "expected Pattern lambda body to contain a single operation "
1385         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1386   });
1387 }
1388 
1389 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1390   SMRange loc = curToken.getLoc();
1391   consumeToken(Token::kw_Pattern);
1392   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1393                                               ParserContext::PatternMatch);
1394 
1395   // Check for an optional identifier for the pattern name.
1396   const ast::Name *name = nullptr;
1397   if (curToken.is(Token::identifier)) {
1398     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1399     consumeToken(Token::identifier);
1400   }
1401 
1402   // Parse any pattern metadata.
1403   ParsedPatternMetadata metadata;
1404   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1405     return failure();
1406 
1407   // Parse the pattern body.
1408   ast::CompoundStmt *body;
1409 
1410   // Handle a lambda body.
1411   if (curToken.is(Token::equal_arrow)) {
1412     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1413     if (failed(bodyResult))
1414       return failure();
1415     body = *bodyResult;
1416   } else {
1417     if (curToken.isNot(Token::l_brace))
1418       return emitError("expected `{` or `=>` to start pattern body");
1419     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1420     if (failed(bodyResult))
1421       return failure();
1422     body = *bodyResult;
1423 
1424     // Verify the body of the pattern.
1425     auto bodyIt = body->begin(), bodyE = body->end();
1426     for (; bodyIt != bodyE; ++bodyIt) {
1427       if (isa<ast::ReturnStmt>(*bodyIt)) {
1428         return emitError((*bodyIt)->getLoc(),
1429                          "`return` statements are only permitted within a "
1430                          "`Constraint` or `Rewrite` body");
1431       }
1432       // Break when we've found the rewrite statement.
1433       if (isa<ast::OpRewriteStmt>(*bodyIt))
1434         break;
1435     }
1436     if (bodyIt == bodyE) {
1437       return emitError(loc,
1438                        "expected Pattern body to terminate with an operation "
1439                        "rewrite statement, such as `erase`");
1440     }
1441     if (std::next(bodyIt) != bodyE) {
1442       return emitError((*std::next(bodyIt))->getLoc(),
1443                        "Pattern body was terminated by an operation "
1444                        "rewrite statement, but found trailing statements");
1445     }
1446   }
1447 
1448   return createPatternDecl(loc, name, metadata, body);
1449 }
1450 
1451 LogicalResult
1452 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1453   Optional<SMRange> benefitLoc;
1454   Optional<SMRange> hasBoundedRecursionLoc;
1455 
1456   do {
1457     // Handle metadata code completion.
1458     if (curToken.is(Token::code_complete))
1459       return codeCompletePatternMetadata();
1460 
1461     if (curToken.isNot(Token::identifier))
1462       return emitError("expected pattern metadata identifier");
1463     StringRef metadataStr = curToken.getSpelling();
1464     SMRange metadataLoc = curToken.getLoc();
1465     consumeToken(Token::identifier);
1466 
1467     // Parse the benefit metadata: benefit(<integer-value>)
1468     if (metadataStr == "benefit") {
1469       if (benefitLoc) {
1470         return emitErrorAndNote(metadataLoc,
1471                                 "pattern benefit has already been specified",
1472                                 *benefitLoc, "see previous definition here");
1473       }
1474       if (failed(parseToken(Token::l_paren,
1475                             "expected `(` before pattern benefit")))
1476         return failure();
1477 
1478       uint16_t benefitValue = 0;
1479       if (curToken.isNot(Token::integer))
1480         return emitError("expected integral pattern benefit");
1481       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1482         return emitError(
1483             "expected pattern benefit to fit within a 16-bit integer");
1484       consumeToken(Token::integer);
1485 
1486       metadata.benefit = benefitValue;
1487       benefitLoc = metadataLoc;
1488 
1489       if (failed(
1490               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1491         return failure();
1492       continue;
1493     }
1494 
1495     // Parse the bounded recursion metadata: recursion
1496     if (metadataStr == "recursion") {
1497       if (hasBoundedRecursionLoc) {
1498         return emitErrorAndNote(
1499             metadataLoc,
1500             "pattern recursion metadata has already been specified",
1501             *hasBoundedRecursionLoc, "see previous definition here");
1502       }
1503       metadata.hasBoundedRecursion = true;
1504       hasBoundedRecursionLoc = metadataLoc;
1505       continue;
1506     }
1507 
1508     return emitError(metadataLoc, "unknown pattern metadata");
1509   } while (consumeIf(Token::comma));
1510 
1511   return success();
1512 }
1513 
1514 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1515   consumeToken(Token::less);
1516 
1517   FailureOr<ast::Expr *> typeExpr = parseExpr();
1518   if (failed(typeExpr) ||
1519       failed(parseToken(Token::greater,
1520                         "expected `>` after variable type constraint")))
1521     return failure();
1522   return typeExpr;
1523 }
1524 
1525 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1526   assert(curDeclScope && "defining decl outside of a decl scope");
1527   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1528     return emitErrorAndNote(
1529         name.getLoc(), "`" + name.getName() + "` has already been defined",
1530         lastDecl->getName()->getLoc(), "see previous definition here");
1531   }
1532   return success();
1533 }
1534 
1535 FailureOr<ast::VariableDecl *>
1536 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1537                            ast::Expr *initExpr,
1538                            ArrayRef<ast::ConstraintRef> constraints) {
1539   assert(curDeclScope && "defining variable outside of decl scope");
1540   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1541 
1542   // If the name of the variable indicates a special variable, we don't add it
1543   // to the scope. This variable is local to the definition point.
1544   if (name.empty() || name == "_") {
1545     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1546                                      constraints);
1547   }
1548   if (failed(checkDefineNamedDecl(nameDecl)))
1549     return failure();
1550 
1551   auto *varDecl =
1552       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1553   curDeclScope->add(varDecl);
1554   return varDecl;
1555 }
1556 
1557 FailureOr<ast::VariableDecl *>
1558 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1559                            ArrayRef<ast::ConstraintRef> constraints) {
1560   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1561                             constraints);
1562 }
1563 
1564 LogicalResult Parser::parseVariableDeclConstraintList(
1565     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1566   Optional<SMRange> typeConstraint;
1567   auto parseSingleConstraint = [&] {
1568     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1569         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1570         /*allowNonCoreConstraints=*/true);
1571     if (failed(constraint))
1572       return failure();
1573     constraints.push_back(*constraint);
1574     return success();
1575   };
1576 
1577   // Check to see if this is a single constraint, or a list.
1578   if (!consumeIf(Token::l_square))
1579     return parseSingleConstraint();
1580 
1581   do {
1582     if (failed(parseSingleConstraint()))
1583       return failure();
1584   } while (consumeIf(Token::comma));
1585   return parseToken(Token::r_square, "expected `]` after constraint list");
1586 }
1587 
/// Parse a single constraint reference starting at the current token.
/// Accepted forms:
///   * `Attr[<type-expr>]`, `Value[<type-expr>]`, `ValueRange[<type-expr>]`
///   * `Op[<dialect.name>]`
///   * `Type`, `TypeRange`
///   * an inline `Constraint` declaration
///   * an identifier naming a previously defined constraint declaration
///
/// `typeConstraint` records the location of an inline `<...>` type constraint
/// once one is parsed. Callers that parse a list of constraints share it
/// across the list, so a second inline type constraint on the same variable
/// is diagnosed. Inline type constraints are rejected when
/// `allowInlineTypeConstraints` is false. `existingConstraints` and
/// `allowNonCoreConstraints` are only consulted when handling a code
/// completion token.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse a `<type-expr>` suffix into `typeExpr`, enforcing that inline type
  // constraints are allowed here and that at most one has been given.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Remember where the type was constrained for later duplicate checks.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to a declaration that is not a constraint (e.g. a
    // variable); point at its definition to help the user.
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  case Token::code_complete: {
    // Handle single entity constraint code completion.
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    // Any other token cannot start a constraint.
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
1710 
1711 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1712   // Constraint arguments may apply more complex constraints via the arguments.
1713   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1714 
1715   Optional<SMRange> typeConstraint;
1716   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1717                          /*allowInlineTypeConstraints=*/false,
1718                          allowNonCoreConstraints);
1719 }
1720 
1721 //===----------------------------------------------------------------------===//
1722 // Exprs
1723 
1724 FailureOr<ast::Expr *> Parser::parseExpr() {
1725   if (curToken.is(Token::underscore))
1726     return parseUnderscoreExpr();
1727 
1728   // Parse the LHS expression.
1729   FailureOr<ast::Expr *> lhsExpr;
1730   switch (curToken.getKind()) {
1731   case Token::kw_attr:
1732     lhsExpr = parseAttributeExpr();
1733     break;
1734   case Token::kw_Constraint:
1735     lhsExpr = parseInlineConstraintLambdaExpr();
1736     break;
1737   case Token::identifier:
1738     lhsExpr = parseIdentifierExpr();
1739     break;
1740   case Token::kw_op:
1741     lhsExpr = parseOperationExpr();
1742     break;
1743   case Token::kw_Rewrite:
1744     lhsExpr = parseInlineRewriteLambdaExpr();
1745     break;
1746   case Token::kw_type:
1747     lhsExpr = parseTypeExpr();
1748     break;
1749   case Token::l_paren:
1750     lhsExpr = parseTupleExpr();
1751     break;
1752   default:
1753     return emitError("expected expression");
1754   }
1755   if (failed(lhsExpr))
1756     return failure();
1757 
1758   // Check for an operator expression.
1759   while (true) {
1760     switch (curToken.getKind()) {
1761     case Token::dot:
1762       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1763       break;
1764     case Token::l_paren:
1765       lhsExpr = parseCallExpr(*lhsExpr);
1766       break;
1767     default:
1768       return lhsExpr;
1769     }
1770     if (failed(lhsExpr))
1771       return failure();
1772   }
1773 }
1774 
1775 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1776   SMRange loc = curToken.getLoc();
1777   consumeToken(Token::kw_attr);
1778 
1779   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1780   // identifier.
1781   if (!consumeIf(Token::less)) {
1782     resetToken(loc);
1783     return parseIdentifierExpr();
1784   }
1785 
1786   if (!curToken.isString())
1787     return emitError("expected string literal containing MLIR attribute");
1788   std::string attrExpr = curToken.getStringValue();
1789   consumeToken();
1790 
1791   loc.End = curToken.getEndLoc();
1792   if (failed(
1793           parseToken(Token::greater, "expected `>` after attribute literal")))
1794     return failure();
1795   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1796 }
1797 
1798 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1799   consumeToken(Token::l_paren);
1800 
1801   // Parse the arguments of the call.
1802   SmallVector<ast::Expr *> arguments;
1803   if (curToken.isNot(Token::r_paren)) {
1804     do {
1805       // Handle code completion for the call arguments.
1806       if (curToken.is(Token::code_complete)) {
1807         codeCompleteCallSignature(parentExpr, arguments.size());
1808         return failure();
1809       }
1810 
1811       FailureOr<ast::Expr *> argument = parseExpr();
1812       if (failed(argument))
1813         return failure();
1814       arguments.push_back(*argument);
1815     } while (consumeIf(Token::comma));
1816   }
1817 
1818   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1819   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1820     return failure();
1821 
1822   return createCallExpr(loc, parentExpr, arguments);
1823 }
1824 
1825 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1826   ast::Decl *decl = curDeclScope->lookup(name);
1827   if (!decl)
1828     return emitError(loc, "undefined reference to `" + name + "`");
1829 
1830   return createDeclRefExpr(loc, decl);
1831 }
1832 
1833 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1834   StringRef name = curToken.getSpelling();
1835   SMRange nameLoc = curToken.getLoc();
1836   consumeToken();
1837 
1838   // Check to see if this is a decl ref expression that defines a variable
1839   // inline.
1840   if (consumeIf(Token::colon)) {
1841     SmallVector<ast::ConstraintRef> constraints;
1842     if (failed(parseVariableDeclConstraintList(constraints)))
1843       return failure();
1844     ast::Type type;
1845     if (failed(validateVariableConstraints(constraints, type)))
1846       return failure();
1847     return createInlineVariableExpr(type, name, nameLoc, constraints);
1848   }
1849 
1850   return parseDeclRefExpr(name, nameLoc);
1851 }
1852 
1853 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1854   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1855   if (failed(decl))
1856     return failure();
1857 
1858   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1859                                   ast::ConstraintType::get(ctx));
1860 }
1861 
1862 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1863   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1864   if (failed(decl))
1865     return failure();
1866 
1867   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1868                                   ast::RewriteType::get(ctx));
1869 }
1870 
1871 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1872   SMRange dotLoc = curToken.getLoc();
1873   consumeToken(Token::dot);
1874 
1875   // Check for code completion of the member name.
1876   if (curToken.is(Token::code_complete))
1877     return codeCompleteMemberAccess(parentExpr);
1878 
1879   // Parse the member name.
1880   Token memberNameTok = curToken;
1881   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1882       !memberNameTok.isKeyword())
1883     return emitError(dotLoc, "expected identifier or numeric member name");
1884   StringRef memberName = memberNameTok.getSpelling();
1885   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1886   consumeToken();
1887 
1888   return createMemberAccessExpr(parentExpr, memberName, loc);
1889 }
1890 
/// Parse an operation name of the form `dialect.name[.more...]`. The dialect
/// namespace and the name components may each be an identifier or a keyword.
/// If `allowEmptyName` is set and the current token cannot start a name, an
/// unnamed `OpNameDecl` is returned without consuming anything. A code
/// completion token in the dialect position or the operation-name position
/// dispatches to the corresponding completion handler.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of an no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, SMRange());
    return emitError("expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(Token::code_complete))
    return codeCompleteOperationName(name);

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Build the full name by growing `name` (which points into the source buffer
  // at the dialect token) over the following text: first the `.` separator,
  // then the spelling of each identifier/keyword/`.` token consumed below.
  // NOTE(review): only token spellings are counted, so this assumes the
  // components are adjacent in the source with no intervening whitespace.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}
1927 
1928 FailureOr<ast::OpNameDecl *>
1929 Parser::parseWrappedOperationName(bool allowEmptyName) {
1930   if (!consumeIf(Token::less))
1931     return ast::OpNameDecl::create(ctx, SMRange());
1932 
1933   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1934   if (failed(opNameDecl))
1935     return failure();
1936 
1937   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1938     return failure();
1939   return opNameDecl;
1940 }
1941 
1942 FailureOr<ast::Expr *>
1943 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
1944   SMRange loc = curToken.getLoc();
1945   consumeToken(Token::kw_op);
1946 
1947   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1948   // identifier.
1949   if (curToken.isNot(Token::less)) {
1950     resetToken(loc);
1951     return parseIdentifierExpr();
1952   }
1953 
1954   // Parse the operation name. The name may be elided, in which case the
1955   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1956   // `Operation*`). Operation names within a rewrite context must be named.
1957   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1958   FailureOr<ast::OpNameDecl *> opNameDecl =
1959       parseWrappedOperationName(allowEmptyName);
1960   if (failed(opNameDecl))
1961     return failure();
1962   Optional<StringRef> opName = (*opNameDecl)->getName();
1963 
1964   // Functor used to create an implicit range variable, used for implicit "all"
1965   // operand or results variables.
1966   auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
1967     FailureOr<ast::VariableDecl *> rangeVar =
1968         defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
1969     assert(succeeded(rangeVar) && "expected range variable to be valid");
1970     return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
1971   };
1972 
1973   // Check for the optional list of operands.
1974   SmallVector<ast::Expr *> operands;
1975   if (!consumeIf(Token::l_paren)) {
1976     // If the operand list isn't specified and we are in a match context, define
1977     // an inplace unconstrained operand range corresponding to all of the
1978     // operands of the operation. This avoids treating zero operands the same
1979     // way as "unconstrained operands".
1980     if (parserContext != ParserContext::Rewrite) {
1981       operands.push_back(createImplicitRangeVar(
1982           ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
1983     }
1984   } else if (!consumeIf(Token::r_paren)) {
1985     // Check for operand signature code completion.
1986     if (curToken.is(Token::code_complete)) {
1987       codeCompleteOperationOperandsSignature(opName, operands.size());
1988       return failure();
1989     }
1990 
1991     // If the operand list was specified and non-empty, parse the operands.
1992     do {
1993       FailureOr<ast::Expr *> operand = parseExpr();
1994       if (failed(operand))
1995         return failure();
1996       operands.push_back(*operand);
1997     } while (consumeIf(Token::comma));
1998 
1999     if (failed(parseToken(Token::r_paren,
2000                           "expected `)` after operation operand list")))
2001       return failure();
2002   }
2003 
2004   // Check for the optional list of attributes.
2005   SmallVector<ast::NamedAttributeDecl *> attributes;
2006   if (consumeIf(Token::l_brace)) {
2007     do {
2008       FailureOr<ast::NamedAttributeDecl *> decl =
2009           parseNamedAttributeDecl(opName);
2010       if (failed(decl))
2011         return failure();
2012       attributes.emplace_back(*decl);
2013     } while (consumeIf(Token::comma));
2014 
2015     if (failed(parseToken(Token::r_brace,
2016                           "expected `}` after operation attribute list")))
2017       return failure();
2018   }
2019 
2020   // Handle the result types of the operation.
2021   SmallVector<ast::Expr *> resultTypes;
2022   OpResultTypeContext resultTypeContext = inputResultTypeContext;
2023 
2024   // Check for an explicit list of result types.
2025   if (consumeIf(Token::arrow)) {
2026     if (failed(parseToken(Token::l_paren,
2027                           "expected `(` before operation result type list")))
2028       return failure();
2029 
2030     // If result types are provided, initially assume that the operation does
2031     // not rely on type inferrence. We don't assert that it isn't, because we
2032     // may be inferring the value of some type/type range variables, but given
2033     // that these variables may be defined in calls we can't always discern when
2034     // this is the case.
2035     resultTypeContext = OpResultTypeContext::Explicit;
2036 
2037     // Handle the case of an empty result list.
2038     if (!consumeIf(Token::r_paren)) {
2039       do {
2040         // Check for result signature code completion.
2041         if (curToken.is(Token::code_complete)) {
2042           codeCompleteOperationResultsSignature(opName, resultTypes.size());
2043           return failure();
2044         }
2045 
2046         FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2047         if (failed(resultTypeExpr))
2048           return failure();
2049         resultTypes.push_back(*resultTypeExpr);
2050       } while (consumeIf(Token::comma));
2051 
2052       if (failed(parseToken(Token::r_paren,
2053                             "expected `)` after operation result type list")))
2054         return failure();
2055     }
2056   } else if (parserContext != ParserContext::Rewrite) {
2057     // If the result list isn't specified and we are in a match context, define
2058     // an inplace unconstrained result range corresponding to all of the results
2059     // of the operation. This avoids treating zero results the same way as
2060     // "unconstrained results".
2061     resultTypes.push_back(createImplicitRangeVar(
2062         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2063   } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2064     // If the result list isn't specified and we are in a rewrite, try to infer
2065     // them at runtime instead.
2066     resultTypeContext = OpResultTypeContext::Interface;
2067   }
2068 
2069   return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2070                              attributes, resultTypes);
2071 }
2072 
/// Parse a tuple expression: `(` [element (`,` element)*] `)`, where each
/// element is either `expr` or a labeled `label = expr`. Labels must be unique
/// within the tuple; a duplicate is diagnosed with a note at the first use.
/// `elementNames` is kept parallel to `elements`, holding an empty name for
/// each unlabeled element. An empty `()` produces a tuple with no elements.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // Label -> location of its first use, for duplicate detection.
  DenseMap<StringRef, SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        // Speculatively consume the token; whether it is a label is only
        // known once we see whether an `=` follows.
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer.
          resetToken(elementNameTok.getLoc());
        }
      }
      // Record the (possibly empty) label so names stay aligned with values.
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  // The range covers the closing `)`, taken before it is consumed.
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
2124 
2125 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2126   SMRange loc = curToken.getLoc();
2127   consumeToken(Token::kw_type);
2128 
2129   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2130   // identifier.
2131   if (!consumeIf(Token::less)) {
2132     resetToken(loc);
2133     return parseIdentifierExpr();
2134   }
2135 
2136   if (!curToken.isString())
2137     return emitError("expected string literal containing MLIR type");
2138   std::string attrExpr = curToken.getStringValue();
2139   consumeToken();
2140 
2141   loc.End = curToken.getEndLoc();
2142   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2143     return failure();
2144   return ast::TypeExpr::create(ctx, loc, attrExpr);
2145 }
2146 
2147 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2148   StringRef name = curToken.getSpelling();
2149   SMRange nameLoc = curToken.getLoc();
2150   consumeToken(Token::underscore);
2151 
2152   // Underscore expressions require a constraint list.
2153   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2154     return failure();
2155 
2156   // Parse the constraints for the expression.
2157   SmallVector<ast::ConstraintRef> constraints;
2158   if (failed(parseVariableDeclConstraintList(constraints)))
2159     return failure();
2160 
2161   ast::Type type;
2162   if (failed(validateVariableConstraints(constraints, type)))
2163     return failure();
2164   return createInlineVariableExpr(type, name, nameLoc, constraints);
2165 }
2166 
2167 //===----------------------------------------------------------------------===//
2168 // Stmts
2169 
2170 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2171   FailureOr<ast::Stmt *> stmt;
2172   switch (curToken.getKind()) {
2173   case Token::kw_erase:
2174     stmt = parseEraseStmt();
2175     break;
2176   case Token::kw_let:
2177     stmt = parseLetStmt();
2178     break;
2179   case Token::kw_replace:
2180     stmt = parseReplaceStmt();
2181     break;
2182   case Token::kw_return:
2183     stmt = parseReturnStmt();
2184     break;
2185   case Token::kw_rewrite:
2186     stmt = parseRewriteStmt();
2187     break;
2188   default:
2189     stmt = parseExpr();
2190     break;
2191   }
2192   if (failed(stmt) ||
2193       (expectTerminalSemicolon &&
2194        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2195     return failure();
2196   return stmt;
2197 }
2198 
2199 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2200   SMLoc startLoc = curToken.getStartLoc();
2201   consumeToken(Token::l_brace);
2202 
2203   // Push a new block scope and parse any nested statements.
2204   pushDeclScope();
2205   SmallVector<ast::Stmt *> statements;
2206   while (curToken.isNot(Token::r_brace)) {
2207     FailureOr<ast::Stmt *> statement = parseStmt();
2208     if (failed(statement))
2209       return popDeclScope(), failure();
2210     statements.push_back(*statement);
2211   }
2212   popDeclScope();
2213 
2214   // Consume the end brace.
2215   SMRange location(startLoc, curToken.getEndLoc());
2216   consumeToken(Token::r_brace);
2217 
2218   return ast::CompoundStmt::create(ctx, location, statements);
2219 }
2220 
2221 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2222   if (parserContext == ParserContext::Constraint)
2223     return emitError("`erase` cannot be used within a Constraint");
2224   SMRange loc = curToken.getLoc();
2225   consumeToken(Token::kw_erase);
2226 
2227   // Parse the root operation expression.
2228   FailureOr<ast::Expr *> rootOp = parseExpr();
2229   if (failed(rootOp))
2230     return failure();
2231 
2232   return createEraseStmt(loc, *rootOp);
2233 }
2234 
2235 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2236   SMRange loc = curToken.getLoc();
2237   consumeToken(Token::kw_let);
2238 
2239   // Parse the name of the new variable.
2240   SMRange varLoc = curToken.getLoc();
2241   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2242     // `_` is a reserved variable name.
2243     if (curToken.is(Token::underscore)) {
2244       return emitError(varLoc,
2245                        "`_` may only be used to define \"inline\" variables");
2246     }
2247     return emitError(varLoc,
2248                      "expected identifier after `let` to name a new variable");
2249   }
2250   StringRef varName = curToken.getSpelling();
2251   consumeToken();
2252 
2253   // Parse the optional set of constraints.
2254   SmallVector<ast::ConstraintRef> constraints;
2255   if (consumeIf(Token::colon) &&
2256       failed(parseVariableDeclConstraintList(constraints)))
2257     return failure();
2258 
2259   // Parse the optional initializer expression.
2260   ast::Expr *initializer = nullptr;
2261   if (consumeIf(Token::equal)) {
2262     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2263     if (failed(initOrFailure))
2264       return failure();
2265     initializer = *initOrFailure;
2266 
2267     // Check that the constraints are compatible with having an initializer,
2268     // e.g. type constraints cannot be used with initializers.
2269     for (ast::ConstraintRef constraint : constraints) {
2270       LogicalResult result =
2271           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2272               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2273                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2274                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2275                   return this->emitError(
2276                       constraint.referenceLoc,
2277                       "type constraints are not permitted on variables with "
2278                       "initializers");
2279                 }
2280                 return success();
2281               })
2282               .Default(success());
2283       if (failed(result))
2284         return failure();
2285     }
2286   }
2287 
2288   FailureOr<ast::VariableDecl *> varDecl =
2289       createVariableDecl(varName, varLoc, initializer, constraints);
2290   if (failed(varDecl))
2291     return failure();
2292   return ast::LetStmt::create(ctx, loc, *varDecl);
2293 }
2294 
2295 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2296   if (parserContext == ParserContext::Constraint)
2297     return emitError("`replace` cannot be used within a Constraint");
2298   SMRange loc = curToken.getLoc();
2299   consumeToken(Token::kw_replace);
2300 
2301   // Parse the root operation expression.
2302   FailureOr<ast::Expr *> rootOp = parseExpr();
2303   if (failed(rootOp))
2304     return failure();
2305 
2306   if (failed(
2307           parseToken(Token::kw_with, "expected `with` after root operation")))
2308     return failure();
2309 
2310   // The replacement portion of this statement is within a rewrite context.
2311   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2312                                               ParserContext::Rewrite);
2313 
2314   // Parse the replacement values.
2315   SmallVector<ast::Expr *> replValues;
2316   if (consumeIf(Token::l_paren)) {
2317     if (consumeIf(Token::r_paren)) {
2318       return emitError(
2319           loc, "expected at least one replacement value, consider using "
2320                "`erase` if no replacement values are desired");
2321     }
2322 
2323     do {
2324       FailureOr<ast::Expr *> replExpr = parseExpr();
2325       if (failed(replExpr))
2326         return failure();
2327       replValues.emplace_back(*replExpr);
2328     } while (consumeIf(Token::comma));
2329 
2330     if (failed(parseToken(Token::r_paren,
2331                           "expected `)` after replacement values")))
2332       return failure();
2333   } else {
2334     // Handle replacement with an operation uniquely, as the replacement
2335     // operation supports type inferrence from the root operation.
2336     FailureOr<ast::Expr *> replExpr;
2337     if (curToken.is(Token::kw_op))
2338       replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2339     else
2340       replExpr = parseExpr();
2341     if (failed(replExpr))
2342       return failure();
2343     replValues.emplace_back(*replExpr);
2344   }
2345 
2346   return createReplaceStmt(loc, *rootOp, replValues);
2347 }
2348 
2349 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2350   SMRange loc = curToken.getLoc();
2351   consumeToken(Token::kw_return);
2352 
2353   // Parse the result value.
2354   FailureOr<ast::Expr *> resultExpr = parseExpr();
2355   if (failed(resultExpr))
2356     return failure();
2357 
2358   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2359 }
2360 
2361 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2362   if (parserContext == ParserContext::Constraint)
2363     return emitError("`rewrite` cannot be used within a Constraint");
2364   SMRange loc = curToken.getLoc();
2365   consumeToken(Token::kw_rewrite);
2366 
2367   // Parse the root operation.
2368   FailureOr<ast::Expr *> rootOp = parseExpr();
2369   if (failed(rootOp))
2370     return failure();
2371 
2372   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2373     return failure();
2374 
2375   if (curToken.isNot(Token::l_brace))
2376     return emitError("expected `{` to start rewrite body");
2377 
2378   // The rewrite body of this statement is within a rewrite context.
2379   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2380                                               ParserContext::Rewrite);
2381 
2382   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2383   if (failed(rewriteBody))
2384     return failure();
2385 
2386   // Verify the rewrite body.
2387   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2388     if (isa<ast::ReturnStmt>(stmt)) {
2389       return emitError(stmt->getLoc(),
2390                        "`return` statements are only permitted within a "
2391                        "`Constraint` or `Rewrite` body");
2392     }
2393   }
2394 
2395   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2396 }
2397 
2398 //===----------------------------------------------------------------------===//
2399 // Creation+Analysis
2400 //===----------------------------------------------------------------------===//
2401 
2402 //===----------------------------------------------------------------------===//
2403 // Decls
2404 
2405 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2406   // Unwrap reference expressions.
2407   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2408     node = init->getDecl();
2409   return dyn_cast<ast::CallableDecl>(node);
2410 }
2411 
2412 FailureOr<ast::PatternDecl *>
2413 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2414                           const ParsedPatternMetadata &metadata,
2415                           ast::CompoundStmt *body) {
2416   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2417                                   metadata.hasBoundedRecursion, body);
2418 }
2419 
2420 ast::Type Parser::createUserConstraintRewriteResultType(
2421     ArrayRef<ast::VariableDecl *> results) {
2422   // Single result decls use the type of the single result.
2423   if (results.size() == 1)
2424     return results[0]->getType();
2425 
2426   // Multiple results use a tuple type, with the types and names grabbed from
2427   // the result variable decls.
2428   auto resultTypes = llvm::map_range(
2429       results, [&](const auto *result) { return result->getType(); });
2430   auto resultNames = llvm::map_range(
2431       results, [&](const auto *result) { return result->getName().getName(); });
2432   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2433                              llvm::to_vector(resultNames));
2434 }
2435 
2436 template <typename T>
2437 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2438     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2439     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2440     ast::CompoundStmt *body) {
2441   if (!body->getChildren().empty()) {
2442     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2443       ast::Expr *resultExpr = retStmt->getResultExpr();
2444 
2445       // Process the result of the decl. If no explicit signature results
2446       // were provided, check for return type inference. Otherwise, check that
2447       // the return expression can be converted to the expected type.
2448       if (results.empty())
2449         resultType = resultExpr->getType();
2450       else if (failed(convertExpressionTo(resultExpr, resultType)))
2451         return failure();
2452       else
2453         retStmt->setResultExpr(resultExpr);
2454     }
2455   }
2456   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2457 }
2458 
2459 FailureOr<ast::VariableDecl *>
2460 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2461                            ArrayRef<ast::ConstraintRef> constraints) {
2462   // The type of the variable, which is expected to be inferred by either a
2463   // constraint or an initializer expression.
2464   ast::Type type;
2465   if (failed(validateVariableConstraints(constraints, type)))
2466     return failure();
2467 
2468   if (initializer) {
2469     // Update the variable type based on the initializer, or try to convert the
2470     // initializer to the existing type.
2471     if (!type)
2472       type = initializer->getType();
2473     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2474       type = mergedType;
2475     else if (failed(convertExpressionTo(initializer, type)))
2476       return failure();
2477 
2478     // Otherwise, if there is no initializer check that the type has already
2479     // been resolved from the constraint list.
2480   } else if (!type) {
2481     return emitErrorAndNote(
2482         loc, "unable to infer type for variable `" + name + "`", loc,
2483         "the type of a variable must be inferable from the constraint "
2484         "list or the initializer");
2485   }
2486 
2487   // Constraint types cannot be used when defining variables.
2488   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2489     return emitError(
2490         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2491   }
2492 
2493   // Try to define a variable with the given name.
2494   FailureOr<ast::VariableDecl *> varDecl =
2495       defineVariableDecl(name, loc, type, initializer, constraints);
2496   if (failed(varDecl))
2497     return failure();
2498 
2499   return *varDecl;
2500 }
2501 
2502 FailureOr<ast::VariableDecl *>
2503 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2504                                       const ast::ConstraintRef &constraint) {
2505   // Constraint arguments may apply more complex constraints via the arguments.
2506   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2507   ast::Type argType;
2508   if (failed(validateVariableConstraint(constraint, argType,
2509                                         allowNonCoreConstraints)))
2510     return failure();
2511   return defineVariableDecl(name, loc, argType, constraint);
2512 }
2513 
2514 LogicalResult
2515 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2516                                     ast::Type &inferredType,
2517                                     bool allowNonCoreConstraints) {
2518   for (const ast::ConstraintRef &ref : constraints)
2519     if (failed(validateVariableConstraint(ref, inferredType,
2520                                           allowNonCoreConstraints)))
2521       return failure();
2522   return success();
2523 }
2524 
2525 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2526                                                  ast::Type &inferredType,
2527                                                  bool allowNonCoreConstraints) {
2528   ast::Type constraintType;
2529   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2530     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2531       if (failed(validateTypeConstraintExpr(typeExpr)))
2532         return failure();
2533     }
2534     constraintType = ast::AttributeType::get(ctx);
2535   } else if (const auto *cst =
2536                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2537     constraintType = ast::OperationType::get(ctx, cst->getName());
2538   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2539     constraintType = typeTy;
2540   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2541     constraintType = typeRangeTy;
2542   } else if (const auto *cst =
2543                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2544     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2545       if (failed(validateTypeConstraintExpr(typeExpr)))
2546         return failure();
2547     }
2548     constraintType = valueTy;
2549   } else if (const auto *cst =
2550                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2551     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2552       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2553         return failure();
2554     }
2555     constraintType = valueRangeTy;
2556   } else if (const auto *cst =
2557                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2558     if (!allowNonCoreConstraints) {
2559       return emitError(ref.referenceLoc,
2560                        "`Rewrite` arguments and results are only permitted to "
2561                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2562                        "`TypeRange`, `Value`, `ValueRange`");
2563     }
2564 
2565     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2566     if (inputs.size() != 1) {
2567       return emitErrorAndNote(ref.referenceLoc,
2568                               "`Constraint`s applied via a variable constraint "
2569                               "list must take a single input, but got " +
2570                                   Twine(inputs.size()),
2571                               cst->getLoc(),
2572                               "see definition of constraint here");
2573     }
2574     constraintType = inputs.front()->getType();
2575   } else {
2576     llvm_unreachable("unknown constraint type");
2577   }
2578 
2579   // Check that the constraint type is compatible with the current inferred
2580   // type.
2581   if (!inferredType) {
2582     inferredType = constraintType;
2583   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2584     inferredType = mergedTy;
2585   } else {
2586     return emitError(ref.referenceLoc,
2587                      llvm::formatv("constraint type `{0}` is incompatible "
2588                                    "with the previously inferred type `{1}`",
2589                                    constraintType, inferredType));
2590   }
2591   return success();
2592 }
2593 
2594 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2595   ast::Type typeExprType = typeExpr->getType();
2596   if (typeExprType != typeTy) {
2597     return emitError(typeExpr->getLoc(),
2598                      "expected expression of `Type` in type constraint");
2599   }
2600   return success();
2601 }
2602 
2603 LogicalResult
2604 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2605   ast::Type typeExprType = typeExpr->getType();
2606   if (typeExprType != typeRangeTy) {
2607     return emitError(typeExpr->getLoc(),
2608                      "expected expression of `TypeRange` in type constraint");
2609   }
2610   return success();
2611 }
2612 
2613 //===----------------------------------------------------------------------===//
2614 // Exprs
2615 
2616 FailureOr<ast::CallExpr *>
2617 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2618                        MutableArrayRef<ast::Expr *> arguments) {
2619   ast::Type parentType = parentExpr->getType();
2620 
2621   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2622   if (!callableDecl) {
2623     return emitError(loc,
2624                      llvm::formatv("expected a reference to a callable "
2625                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2626                                    parentType));
2627   }
2628   if (parserContext == ParserContext::Rewrite) {
2629     if (isa<ast::UserConstraintDecl>(callableDecl))
2630       return emitError(
2631           loc, "unable to invoke `Constraint` within a rewrite section");
2632   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2633     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2634   }
2635 
2636   // Verify the arguments of the call.
2637   /// Handle size mismatch.
2638   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2639   if (callArgs.size() != arguments.size()) {
2640     return emitErrorAndNote(
2641         loc,
2642         llvm::formatv("invalid number of arguments for {0} call; expected "
2643                       "{1}, but got {2}",
2644                       callableDecl->getCallableType(), callArgs.size(),
2645                       arguments.size()),
2646         callableDecl->getLoc(),
2647         llvm::formatv("see the definition of {0} here",
2648                       callableDecl->getName()->getName()));
2649   }
2650 
2651   /// Handle argument type mismatch.
2652   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2653     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2654                                   callableDecl->getName()->getName()),
2655                     callableDecl->getLoc());
2656   };
2657   for (auto it : llvm::zip(callArgs, arguments)) {
2658     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2659                                    attachDiagFn)))
2660       return failure();
2661   }
2662 
2663   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2664                                callableDecl->getResultType());
2665 }
2666 
2667 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2668                                                         ast::Decl *decl) {
2669   // Check the type of decl being referenced.
2670   ast::Type declType;
2671   if (isa<ast::ConstraintDecl>(decl))
2672     declType = ast::ConstraintType::get(ctx);
2673   else if (isa<ast::UserRewriteDecl>(decl))
2674     declType = ast::RewriteType::get(ctx);
2675   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2676     declType = varDecl->getType();
2677   else
2678     return emitError(loc, "invalid reference to `" +
2679                               decl->getName()->getName() + "`");
2680 
2681   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2682 }
2683 
2684 FailureOr<ast::DeclRefExpr *>
2685 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2686                                  ArrayRef<ast::ConstraintRef> constraints) {
2687   FailureOr<ast::VariableDecl *> decl =
2688       defineVariableDecl(name, loc, type, constraints);
2689   if (failed(decl))
2690     return failure();
2691   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2692 }
2693 
2694 FailureOr<ast::MemberAccessExpr *>
2695 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2696                                SMRange loc) {
2697   // Validate the member name for the given parent expression.
2698   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2699   if (failed(memberType))
2700     return failure();
2701 
2702   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2703 }
2704 
2705 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2706                                                   StringRef name, SMRange loc) {
2707   ast::Type parentType = parentExpr->getType();
2708   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2709     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2710       return valueRangeTy;
2711 
2712     // Verify member access based on the operation type.
2713     if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
2714       auto results = odsOp->getResults();
2715 
2716       // Handle indexed results.
2717       unsigned index = 0;
2718       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2719           index < results.size()) {
2720         return results[index].isVariadic() ? valueRangeTy : valueTy;
2721       }
2722 
2723       // Handle named results.
2724       const auto *it = llvm::find_if(results, [&](const auto &result) {
2725         return result.getName() == name;
2726       });
2727       if (it != results.end())
2728         return it->isVariadic() ? valueRangeTy : valueTy;
2729     } else if (llvm::isDigit(name[0])) {
2730       // Allow unchecked numeric indexing of the results of unregistered
2731       // operations. It returns a single value.
2732       return valueTy;
2733     }
2734   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2735     // Handle indexed results.
2736     unsigned index = 0;
2737     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2738         index < tupleType.size()) {
2739       return tupleType.getElementTypes()[index];
2740     }
2741 
2742     // Handle named results.
2743     auto elementNames = tupleType.getElementNames();
2744     const auto *it = llvm::find(elementNames, name);
2745     if (it != elementNames.end())
2746       return tupleType.getElementTypes()[it - elementNames.begin()];
2747   }
2748   return emitError(
2749       loc,
2750       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2751                     name, parentType));
2752 }
2753 
2754 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2755     SMRange loc, const ast::OpNameDecl *name,
2756     OpResultTypeContext resultTypeContext,
2757     MutableArrayRef<ast::Expr *> operands,
2758     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2759     MutableArrayRef<ast::Expr *> results) {
2760   Optional<StringRef> opNameRef = name->getName();
2761   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2762 
2763   // Verify the inputs operands.
2764   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2765     return failure();
2766 
2767   // Verify the attribute list.
2768   for (ast::NamedAttributeDecl *attr : attributes) {
2769     // Check for an attribute type, or a type awaiting resolution.
2770     ast::Type attrType = attr->getValue()->getType();
2771     if (!attrType.isa<ast::AttributeType>()) {
2772       return emitError(
2773           attr->getValue()->getLoc(),
2774           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2775     }
2776   }
2777 
2778   assert(
2779       (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2780       "unexpected inferrence when results were explicitly specified");
2781 
2782   // If we aren't relying on type inferrence, or explicit results were provided,
2783   // validate them.
2784   if (resultTypeContext == OpResultTypeContext::Explicit) {
2785     if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2786       return failure();
2787 
2788     // Validate the use of interface based type inferrence for this operation.
2789   } else if (resultTypeContext == OpResultTypeContext::Interface) {
2790     assert(opNameRef &&
2791            "expected valid operation name when inferring operation results");
2792     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2793   }
2794 
2795   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2796                                     attributes);
2797 }
2798 
2799 LogicalResult
2800 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2801                                   const ods::Operation *odsOp,
2802                                   MutableArrayRef<ast::Expr *> operands) {
2803   return validateOperationOperandsOrResults(
2804       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2805       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2806       valueRangeTy);
2807 }
2808 
2809 LogicalResult
2810 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2811                                  const ods::Operation *odsOp,
2812                                  MutableArrayRef<ast::Expr *> results) {
2813   return validateOperationOperandsOrResults(
2814       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2815       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2816 }
2817 
2818 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2819                                                 const ods::Operation *odsOp) {
2820   // If the operation might not have inferrence support, emit a warning to the
2821   // user. We don't emit an error because the interface might be added to the
2822   // operation at runtime. It's rare, but it could still happen. We emit a
2823   // warning here instead.
2824 
2825   // Handle inferrence warnings for unknown operations.
2826   if (!odsOp) {
2827     ctx.getDiagEngine().emitWarning(
2828         loc, llvm::formatv(
2829                  "operation result types are marked to be inferred, but "
2830                  "`{0}` is unknown. Ensure that `{0}` supports zero "
2831                  "results or implements `InferTypeOpInterface`. Include "
2832                  "the ODS definition of this operation to remove this warning.",
2833                  opName));
2834     return;
2835   }
2836 
2837   // Handle inferrence warnings for known operations that expected at least one
2838   // result, but don't have inference support. An elided results list can mean
2839   // "zero-results", and we don't want to warn when that is the expected
2840   // behavior.
2841   bool requiresInferrence =
2842       llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2843         return !result.isVariableLength();
2844       });
2845   if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2846     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2847         loc,
2848         llvm::formatv("operation result types are marked to be inferred, but "
2849                       "`{0}` does not provide an implementation of "
2850                       "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2851                       "`InferTypeOpInterface` at runtime, or add support to "
2852                       "the ODS definition to remove this warning.",
2853                       opName));
2854     diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2855                      odsOp->getLoc());
2856     return;
2857   }
2858 }
2859 
2860 LogicalResult Parser::validateOperationOperandsOrResults(
2861     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2862     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2863     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2864     ast::Type rangeTy) {
2865   // All operation types accept a single range parameter.
2866   if (values.size() == 1) {
2867     if (failed(convertExpressionTo(values[0], rangeTy)))
2868       return failure();
2869     return success();
2870   }
2871 
2872   /// If the operation has ODS information, we can more accurately verify the
2873   /// values.
2874   if (odsOpLoc) {
2875     if (odsValues.size() != values.size()) {
2876       return emitErrorAndNote(
2877           loc,
2878           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2879                         "{2}, but got {3}",
2880                         groupName, *name, odsValues.size(), values.size()),
2881           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2882     }
2883     auto diagFn = [&](ast::Diagnostic &diag) {
2884       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2885                       *odsOpLoc);
2886     };
2887     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2888       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2889       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2890         return failure();
2891     }
2892     return success();
2893   }
2894 
2895   // Otherwise, accept the value groups as they have been defined and just
2896   // ensure they are one of the expected types.
2897   for (ast::Expr *&valueExpr : values) {
2898     ast::Type valueExprType = valueExpr->getType();
2899 
2900     // Check if this is one of the expected types.
2901     if (valueExprType == rangeTy || valueExprType == singleTy)
2902       continue;
2903 
2904     // If the operand is an Operation, allow converting to a Value or
2905     // ValueRange. This situations arises quite often with nested operation
2906     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2907     if (singleTy == valueTy) {
2908       if (valueExprType.isa<ast::OperationType>()) {
2909         valueExpr = convertOpToValue(valueExpr);
2910         continue;
2911       }
2912     }
2913 
2914     return emitError(
2915         valueExpr->getLoc(),
2916         llvm::formatv(
2917             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2918             singleTy, rangeTy, valueExprType));
2919   }
2920   return success();
2921 }
2922 
2923 FailureOr<ast::TupleExpr *>
2924 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2925                         ArrayRef<StringRef> elementNames) {
2926   for (const ast::Expr *element : elements) {
2927     ast::Type eleTy = element->getType();
2928     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2929       return emitError(
2930           element->getLoc(),
2931           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2932     }
2933   }
2934   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2935 }
2936 
2937 //===----------------------------------------------------------------------===//
2938 // Stmts
2939 
2940 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2941                                                     ast::Expr *rootOp) {
2942   // Check that root is an Operation.
2943   ast::Type rootType = rootOp->getType();
2944   if (!rootType.isa<ast::OperationType>())
2945     return emitError(rootOp->getLoc(), "expected `Op` expression");
2946 
2947   return ast::EraseStmt::create(ctx, loc, rootOp);
2948 }
2949 
2950 FailureOr<ast::ReplaceStmt *>
2951 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2952                           MutableArrayRef<ast::Expr *> replValues) {
2953   // Check that root is an Operation.
2954   ast::Type rootType = rootOp->getType();
2955   if (!rootType.isa<ast::OperationType>()) {
2956     return emitError(
2957         rootOp->getLoc(),
2958         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2959   }
2960 
2961   // If there are multiple replacement values, we implicitly convert any Op
2962   // expressions to the value form.
2963   bool shouldConvertOpToValues = replValues.size() > 1;
2964   for (ast::Expr *&replExpr : replValues) {
2965     ast::Type replType = replExpr->getType();
2966 
2967     // Check that replExpr is an Operation, Value, or ValueRange.
2968     if (replType.isa<ast::OperationType>()) {
2969       if (shouldConvertOpToValues)
2970         replExpr = convertOpToValue(replExpr);
2971       continue;
2972     }
2973 
2974     if (replType != valueTy && replType != valueRangeTy) {
2975       return emitError(replExpr->getLoc(),
2976                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2977                                      "expression, but got `{0}`",
2978                                      replType));
2979     }
2980   }
2981 
2982   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2983 }
2984 
2985 FailureOr<ast::RewriteStmt *>
2986 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2987                           ast::CompoundStmt *rewriteBody) {
2988   // Check that root is an Operation.
2989   ast::Type rootType = rootOp->getType();
2990   if (!rootType.isa<ast::OperationType>()) {
2991     return emitError(
2992         rootOp->getLoc(),
2993         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2994   }
2995 
2996   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2997 }
2998 
2999 //===----------------------------------------------------------------------===//
3000 // Code Completion
3001 //===----------------------------------------------------------------------===//
3002 
3003 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3004   ast::Type parentType = parentExpr->getType();
3005   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3006     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3007   else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3008     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3009   return failure();
3010 }
3011 
3012 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3013   if (opName)
3014     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3015   return failure();
3016 }
3017 
3018 LogicalResult
3019 Parser::codeCompleteConstraintName(ast::Type inferredType,
3020                                    bool allowNonCoreConstraints,
3021                                    bool allowInlineTypeConstraints) {
3022   codeCompleteContext->codeCompleteConstraintName(
3023       inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3024       curDeclScope);
3025   return failure();
3026 }
3027 
3028 LogicalResult Parser::codeCompleteDialectName() {
3029   codeCompleteContext->codeCompleteDialectName();
3030   return failure();
3031 }
3032 
3033 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3034   codeCompleteContext->codeCompleteOperationName(dialectName);
3035   return failure();
3036 }
3037 
3038 LogicalResult Parser::codeCompletePatternMetadata() {
3039   codeCompleteContext->codeCompletePatternMetadata();
3040   return failure();
3041 }
3042 
3043 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3044   codeCompleteContext->codeCompleteIncludeFilename(curPath);
3045   return failure();
3046 }
3047 
3048 void Parser::codeCompleteCallSignature(ast::Node *parent,
3049                                        unsigned currentNumArgs) {
3050   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3051   if (!callableDecl)
3052     return;
3053 
3054   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3055 }
3056 
3057 void Parser::codeCompleteOperationOperandsSignature(
3058     Optional<StringRef> opName, unsigned currentNumOperands) {
3059   codeCompleteContext->codeCompleteOperationOperandsSignature(
3060       opName, currentNumOperands);
3061 }
3062 
3063 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3064                                                    unsigned currentNumResults) {
3065   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3066                                                              currentNumResults);
3067 }
3068 
3069 //===----------------------------------------------------------------------===//
3070 // Parser
3071 //===----------------------------------------------------------------------===//
3072 
3073 FailureOr<ast::Module *>
3074 mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3075                         CodeCompleteContext *codeCompleteContext) {
3076   Parser parser(ctx, sourceMgr, codeCompleteContext);
3077   return parser.parseModule();
3078 }
3079