1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Argument.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Constraint.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/Tools/PDLL/AST/Context.h"
19 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
20 #include "mlir/Tools/PDLL/AST/Nodes.h"
21 #include "mlir/Tools/PDLL/AST/Types.h"
22 #include "mlir/Tools/PDLL/ODS/Constraint.h"
23 #include "mlir/Tools/PDLL/ODS/Context.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
46   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47          bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49         curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50         valueTy(ast::ValueType::get(ctx)),
51         valueRangeTy(ast::ValueRangeType::get(ctx)),
52         typeTy(ast::TypeType::get(ctx)),
53         typeRangeTy(ast::TypeRangeType::get(ctx)),
54         attrTy(ast::AttributeType::get(ctx)),
55         codeCompleteContext(codeCompleteContext) {}
56 
57   /// Try to parse a new module. Returns nullptr in the case of failure.
58   FailureOr<ast::Module *> parseModule();
59 
60 private:
61   /// The current context of the parser. It allows for the parser to know a bit
62   /// about the construct it is nested within during parsing. This is used
63   /// specifically to provide additional verification during parsing, e.g. to
64   /// prevent using rewrites within a match context, matcher constraints within
65   /// a rewrite section, etc.
66   enum class ParserContext {
67     /// The parser is in the global context.
68     Global,
69     /// The parser is currently within a Constraint, which disallows all types
70     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71     Constraint,
72     /// The parser is currently within the matcher portion of a Pattern, which
73     /// is allows a terminal operation rewrite statement but no other rewrite
74     /// transformations.
75     PatternMatch,
76     /// The parser is currently within a Rewrite, which disallows calls to
77     /// constraints, requires operation expressions to have names, etc.
78     Rewrite,
79   };
80 
81   /// The current specification context of an operations result type. This
82   /// indicates how the result types of an operation may be inferred.
83   enum class OpResultTypeContext {
84     /// The result types of the operation are not known to be inferred.
85     Explicit,
86     /// The result types of the operation are inferred from the root input of a
87     /// `replace` statement.
88     Replacement,
89     /// The result types of the operation are inferred by using the
90     /// `InferTypeOpInterface` interface provided by the operation.
91     Interface,
92   };
93 
94   //===--------------------------------------------------------------------===//
95   // Parsing
96   //===--------------------------------------------------------------------===//
97 
98   /// Push a new decl scope onto the lexer.
99   ast::DeclScope *pushDeclScope() {
100     ast::DeclScope *newScope =
101         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102     return (curDeclScope = newScope);
103   }
104   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106   /// Pop the last decl scope from the lexer.
107   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109   /// Parse the body of an AST module.
110   LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112   /// Try to convert the given expression to `type`. Returns failure and emits
113   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114   /// invoked to attach notes to the emitted error diagnostic. On success,
115   /// `expr` is updated to the expression used to convert to `type`.
116   LogicalResult convertExpressionTo(
117       ast::Expr *&expr, ast::Type type,
118       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119 
120   /// Given an operation expression, convert it to a Value or ValueRange
121   /// typed expression.
122   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
123 
124   /// Lookup ODS information for the given operation, returns nullptr if no
125   /// information is found.
126   const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
127     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
128   }
129 
130   /// Process the given documentation string, or return an empty string if
131   /// documentation isn't enabled.
132   StringRef processDoc(StringRef doc) {
133     return enableDocumentation ? doc : StringRef();
134   }
135 
136   /// Process the given documentation string and format it, or return an empty
137   /// string if documentation isn't enabled.
138   std::string processAndFormatDoc(const Twine &doc) {
139     if (!enableDocumentation)
140       return "";
141     std::string docStr;
142     {
143       llvm::raw_string_ostream docOS(docStr);
144       std::string tmpDocStr = doc.str();
145       raw_indented_ostream(docOS).printReindented(
146           StringRef(tmpDocStr).rtrim(" \t"));
147     }
148     return docStr;
149   }
150 
151   //===--------------------------------------------------------------------===//
152   // Directives
153 
154   LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
155   LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
156   LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
157                                SmallVectorImpl<ast::Decl *> &decls);
158 
159   /// Process the records of a parsed tablegen include file.
160   void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
161                                SmallVectorImpl<ast::Decl *> &decls);
162 
163   /// Create a user defined native constraint for a constraint imported from
164   /// ODS.
165   template <typename ConstraintT>
166   ast::Decl *
167   createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
168                                     SMRange loc, ast::Type type,
169                                     StringRef nativeType, StringRef docString);
170   template <typename ConstraintT>
171   ast::Decl *
172   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
173                                     SMRange loc, ast::Type type,
174                                     StringRef nativeType);
175 
176   //===--------------------------------------------------------------------===//
177   // Decls
178 
179   /// This structure contains the set of pattern metadata that may be parsed.
180   struct ParsedPatternMetadata {
181     Optional<uint16_t> benefit;
182     bool hasBoundedRecursion = false;
183   };
184 
185   FailureOr<ast::Decl *> parseTopLevelDecl();
186   FailureOr<ast::NamedAttributeDecl *>
187   parseNamedAttributeDecl(Optional<StringRef> parentOpName);
188 
189   /// Parse an argument variable as part of the signature of a
190   /// UserConstraintDecl or UserRewriteDecl.
191   FailureOr<ast::VariableDecl *> parseArgumentDecl();
192 
193   /// Parse a result variable as part of the signature of a UserConstraintDecl
194   /// or UserRewriteDecl.
195   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
196 
197   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
198   /// defined in a non-global context.
199   FailureOr<ast::UserConstraintDecl *>
200   parseUserConstraintDecl(bool isInline = false);
201 
202   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
203   /// non-global context, such as within a Pattern/Constraint/etc.
204   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
205 
206   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
207   /// PDLL constructs.
208   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
209       const ast::Name &name, bool isInline,
210       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
211       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
212 
213   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
214   /// defined in a non-global context.
215   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
216 
217   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
218   /// non-global context, such as within a Pattern/Rewrite/etc.
219   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
220 
221   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
222   /// PDLL constructs.
223   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
224       const ast::Name &name, bool isInline,
225       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
226       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
227 
228   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
229   /// effectively the same syntax, and only differ on slight semantics (given
230   /// the different parsing contexts).
231   template <typename T, typename ParseUserPDLLDeclFnT>
232   FailureOr<T *> parseUserConstraintOrRewriteDecl(
233       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
234       StringRef anonymousNamePrefix, bool isInline);
235 
236   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
237   /// These decls have effectively the same syntax.
238   template <typename T>
239   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
240       const ast::Name &name, bool isInline,
241       ArrayRef<ast::VariableDecl *> arguments,
242       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
243 
244   /// Parse the functional signature (i.e. the arguments and results) of a
245   /// UserConstraintDecl or UserRewriteDecl.
246   LogicalResult parseUserConstraintOrRewriteSignature(
247       SmallVectorImpl<ast::VariableDecl *> &arguments,
248       SmallVectorImpl<ast::VariableDecl *> &results,
249       ast::DeclScope *&argumentScope, ast::Type &resultType);
250 
251   /// Validate the return (which if present is specified by bodyIt) of a
252   /// UserConstraintDecl or UserRewriteDecl.
253   LogicalResult validateUserConstraintOrRewriteReturn(
254       StringRef declType, ast::CompoundStmt *body,
255       ArrayRef<ast::Stmt *>::iterator bodyIt,
256       ArrayRef<ast::Stmt *>::iterator bodyE,
257       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
258 
259   FailureOr<ast::CompoundStmt *>
260   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
261                   bool expectTerminalSemicolon = true);
262   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
263   FailureOr<ast::Decl *> parsePatternDecl();
264   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
265 
266   /// Check to see if a decl has already been defined with the given name, if
267   /// one has emit and error and return failure. Returns success otherwise.
268   LogicalResult checkDefineNamedDecl(const ast::Name &name);
269 
270   /// Try to define a variable decl with the given components, returns the
271   /// variable on success.
272   FailureOr<ast::VariableDecl *>
273   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
274                      ast::Expr *initExpr,
275                      ArrayRef<ast::ConstraintRef> constraints);
276   FailureOr<ast::VariableDecl *>
277   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
278                      ArrayRef<ast::ConstraintRef> constraints);
279 
280   /// Parse the constraint reference list for a variable decl.
281   LogicalResult parseVariableDeclConstraintList(
282       SmallVectorImpl<ast::ConstraintRef> &constraints);
283 
284   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
285   FailureOr<ast::Expr *> parseTypeConstraintExpr();
286 
287   /// Try to parse a single reference to a constraint. `typeConstraint` is the
288   /// location of a previously parsed type constraint for the entity that will
289   /// be constrained by the parsed constraint. `existingConstraints` are any
290   /// existing constraints that have already been parsed for the same entity
291   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
292   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
293   /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
294   /// constraints) may be used with the variable.
295   FailureOr<ast::ConstraintRef>
296   parseConstraint(Optional<SMRange> &typeConstraint,
297                   ArrayRef<ast::ConstraintRef> existingConstraints,
298                   bool allowInlineTypeConstraints,
299                   bool allowNonCoreConstraints);
300 
301   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
302   /// argument or result variable. The constraints for these variables do not
303   /// allow inline type constraints, and only permit a single constraint.
304   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
305 
306   //===--------------------------------------------------------------------===//
307   // Exprs
308 
309   FailureOr<ast::Expr *> parseExpr();
310 
311   /// Identifier expressions.
312   FailureOr<ast::Expr *> parseAttributeExpr();
313   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
314   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
315   FailureOr<ast::Expr *> parseIdentifierExpr();
316   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
317   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
318   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
319   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
320   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
321   FailureOr<ast::Expr *>
322   parseOperationExpr(OpResultTypeContext inputResultTypeContext =
323                          OpResultTypeContext::Explicit);
324   FailureOr<ast::Expr *> parseTupleExpr();
325   FailureOr<ast::Expr *> parseTypeExpr();
326   FailureOr<ast::Expr *> parseUnderscoreExpr();
327 
328   //===--------------------------------------------------------------------===//
329   // Stmts
330 
331   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
332   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
333   FailureOr<ast::EraseStmt *> parseEraseStmt();
334   FailureOr<ast::LetStmt *> parseLetStmt();
335   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
336   FailureOr<ast::ReturnStmt *> parseReturnStmt();
337   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
338 
339   //===--------------------------------------------------------------------===//
340   // Creation+Analysis
341   //===--------------------------------------------------------------------===//
342 
343   //===--------------------------------------------------------------------===//
344   // Decls
345 
346   /// Try to extract a callable from the given AST node. Returns nullptr on
347   /// failure.
348   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
349 
350   /// Try to create a pattern decl with the given components, returning the
351   /// Pattern on success.
352   FailureOr<ast::PatternDecl *>
353   createPatternDecl(SMRange loc, const ast::Name *name,
354                     const ParsedPatternMetadata &metadata,
355                     ast::CompoundStmt *body);
356 
357   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
358   /// of results, defined as part of the signature.
359   ast::Type
360   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
361 
362   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
363   template <typename T>
364   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
365       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
366       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
367       ast::CompoundStmt *body);
368 
369   /// Try to create a variable decl with the given components, returning the
370   /// Variable on success.
371   FailureOr<ast::VariableDecl *>
372   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
373                      ArrayRef<ast::ConstraintRef> constraints);
374 
375   /// Create a variable for an argument or result defined as part of the
376   /// signature of a UserConstraintDecl/UserRewriteDecl.
377   FailureOr<ast::VariableDecl *>
378   createArgOrResultVariableDecl(StringRef name, SMRange loc,
379                                 const ast::ConstraintRef &constraint);
380 
381   /// Validate the constraints used to constraint a variable decl.
382   /// `inferredType` is the type of the variable inferred by the constraints
383   /// within the list, and is updated to the most refined type as determined by
384   /// the constraints. Returns success if the constraint list is valid, failure
385   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
386   /// defined constraints) may be used with the variable.
387   LogicalResult
388   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
389                               ast::Type &inferredType,
390                               bool allowNonCoreConstraints = true);
391   /// Validate a single reference to a constraint. `inferredType` contains the
392   /// currently inferred variabled type and is refined within the type defined
393   /// by the constraint. Returns success if the constraint is valid, failure
394   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
395   /// defined constraints) may be used with the variable.
396   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
397                                            ast::Type &inferredType,
398                                            bool allowNonCoreConstraints = true);
399   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
400   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
401 
402   //===--------------------------------------------------------------------===//
403   // Exprs
404 
405   FailureOr<ast::CallExpr *>
406   createCallExpr(SMRange loc, ast::Expr *parentExpr,
407                  MutableArrayRef<ast::Expr *> arguments);
408   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
409   FailureOr<ast::DeclRefExpr *>
410   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
411                            ArrayRef<ast::ConstraintRef> constraints);
412   FailureOr<ast::MemberAccessExpr *>
413   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
414 
415   /// Validate the member access `name` into the given parent expression. On
416   /// success, this also returns the type of the member accessed.
417   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
418                                             StringRef name, SMRange loc);
419   FailureOr<ast::OperationExpr *>
420   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
421                       OpResultTypeContext resultTypeContext,
422                       MutableArrayRef<ast::Expr *> operands,
423                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
424                       MutableArrayRef<ast::Expr *> results);
425   LogicalResult
426   validateOperationOperands(SMRange loc, Optional<StringRef> name,
427                             const ods::Operation *odsOp,
428                             MutableArrayRef<ast::Expr *> operands);
429   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
430                                          const ods::Operation *odsOp,
431                                          MutableArrayRef<ast::Expr *> results);
432   void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
433                                           const ods::Operation *odsOp);
434   LogicalResult validateOperationOperandsOrResults(
435       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
436       Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
437       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
438       ast::Type rangeTy);
439   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
440                                               ArrayRef<ast::Expr *> elements,
441                                               ArrayRef<StringRef> elementNames);
442 
443   //===--------------------------------------------------------------------===//
444   // Stmts
445 
446   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
447   FailureOr<ast::ReplaceStmt *>
448   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
449                     MutableArrayRef<ast::Expr *> replValues);
450   FailureOr<ast::RewriteStmt *>
451   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
452                     ast::CompoundStmt *rewriteBody);
453 
454   //===--------------------------------------------------------------------===//
455   // Code Completion
456   //===--------------------------------------------------------------------===//
457 
458   /// The set of various code completion methods. Every completion method
459   /// returns `failure` to stop the parsing process after providing completion
460   /// results.
461 
462   LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
463   LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
464   LogicalResult codeCompleteConstraintName(ast::Type inferredType,
465                                            bool allowNonCoreConstraints,
466                                            bool allowInlineTypeConstraints);
467   LogicalResult codeCompleteDialectName();
468   LogicalResult codeCompleteOperationName(StringRef dialectName);
469   LogicalResult codeCompletePatternMetadata();
470   LogicalResult codeCompleteIncludeFilename(StringRef curPath);
471 
472   void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
473   void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
474                                               unsigned currentNumOperands);
475   void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
476                                              unsigned currentNumResults);
477 
478   //===--------------------------------------------------------------------===//
479   // Lexer Utilities
480   //===--------------------------------------------------------------------===//
481 
482   /// If the current token has the specified kind, consume it and return true.
483   /// If not, return false.
484   bool consumeIf(Token::Kind kind) {
485     if (curToken.isNot(kind))
486       return false;
487     consumeToken(kind);
488     return true;
489   }
490 
491   /// Advance the current lexer onto the next token.
492   void consumeToken() {
493     assert(curToken.isNot(Token::eof, Token::error) &&
494            "shouldn't advance past EOF or errors");
495     curToken = lexer.lexToken();
496   }
497 
498   /// Advance the current lexer onto the next token, asserting what the expected
499   /// current token is. This is preferred to the above method because it leads
500   /// to more self-documenting code with better checking.
501   void consumeToken(Token::Kind kind) {
502     assert(curToken.is(kind) && "consumed an unexpected token");
503     consumeToken();
504   }
505 
506   /// Reset the lexer to the location at the given position.
507   void resetToken(SMRange tokLoc) {
508     lexer.resetPointer(tokLoc.Start.getPointer());
509     curToken = lexer.lexToken();
510   }
511 
512   /// Consume the specified token if present and return success. On failure,
513   /// output a diagnostic and return failure.
514   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
515     if (curToken.getKind() != kind)
516       return emitError(curToken.getLoc(), msg);
517     consumeToken();
518     return success();
519   }
520   LogicalResult emitError(SMRange loc, const Twine &msg) {
521     lexer.emitError(loc, msg);
522     return failure();
523   }
524   LogicalResult emitError(const Twine &msg) {
525     return emitError(curToken.getLoc(), msg);
526   }
527   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
528                                  const Twine &note) {
529     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
530     return failure();
531   }
532 
533   //===--------------------------------------------------------------------===//
534   // Fields
535   //===--------------------------------------------------------------------===//
536 
537   /// The owning AST context.
538   ast::Context &ctx;
539 
540   /// The lexer of this parser.
541   Lexer lexer;
542 
543   /// The current token within the lexer.
544   Token curToken;
545 
546   /// A flag indicating if the parser should add documentation to AST nodes when
547   /// viable.
548   bool enableDocumentation;
549 
550   /// The most recently defined decl scope.
551   ast::DeclScope *curDeclScope = nullptr;
552   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
553 
554   /// The current context of the parser.
555   ParserContext parserContext = ParserContext::Global;
556 
557   /// Cached types to simplify verification and expression creation.
558   ast::Type valueTy, valueRangeTy;
559   ast::Type typeTy, typeRangeTy;
560   ast::Type attrTy;
561 
562   /// A counter used when naming anonymous constraints and rewrites.
563   unsigned anonymousDeclNameCounter = 0;
564 
565   /// The optional code completion context.
566   CodeCompleteContext *codeCompleteContext;
567 };
568 } // namespace
569 
570 FailureOr<ast::Module *> Parser::parseModule() {
571   SMLoc moduleLoc = curToken.getStartLoc();
572   pushDeclScope();
573 
574   // Parse the top-level decls of the module.
575   SmallVector<ast::Decl *> decls;
576   if (failed(parseModuleBody(decls)))
577     return popDeclScope(), failure();
578 
579   popDeclScope();
580   return ast::Module::create(ctx, moduleLoc, decls);
581 }
582 
583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
584   while (curToken.isNot(Token::eof)) {
585     if (curToken.is(Token::directive)) {
586       if (failed(parseDirective(decls)))
587         return failure();
588       continue;
589     }
590 
591     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
592     if (failed(decl))
593       return failure();
594     decls.push_back(*decl);
595   }
596   return success();
597 }
598 
599 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
600   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
601                                                  valueRangeTy);
602 }
603 
604 LogicalResult Parser::convertExpressionTo(
605     ast::Expr *&expr, ast::Type type,
606     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
607   ast::Type exprType = expr->getType();
608   if (exprType == type)
609     return success();
610 
611   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
612     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
613         expr->getLoc(), llvm::formatv("unable to convert expression of type "
614                                       "`{0}` to the expected type of "
615                                       "`{1}`",
616                                       exprType, type));
617     if (noteAttachFn)
618       noteAttachFn(*diag);
619     return diag;
620   };
621 
622   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
623     // Two operation types are compatible if they have the same name, or if the
624     // expected type is more general.
625     if (auto opType = type.dyn_cast<ast::OperationType>()) {
626       if (opType.getName())
627         return emitConvertError();
628       return success();
629     }
630 
631     // An operation can always convert to a ValueRange.
632     if (type == valueRangeTy) {
633       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
634                                                      valueRangeTy);
635       return success();
636     }
637 
638     // Allow conversion to a single value by constraining the result range.
639     if (type == valueTy) {
640       // If the operation is registered, we can verify if it can ever have a
641       // single result.
642       if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
643         if (odsOp->getResults().empty()) {
644           return emitConvertError()->attachNote(
645               llvm::formatv("see the definition of `{0}`, which was defined "
646                             "with zero results",
647                             odsOp->getName()),
648               odsOp->getLoc());
649         }
650 
651         unsigned numSingleResults = llvm::count_if(
652             odsOp->getResults(), [](const ods::OperandOrResult &result) {
653               return result.getVariableLengthKind() ==
654                      ods::VariableLengthKind::Single;
655             });
656         if (numSingleResults > 1) {
657           return emitConvertError()->attachNote(
658               llvm::formatv("see the definition of `{0}`, which was defined "
659                             "with at least {1} results",
660                             odsOp->getName(), numSingleResults),
661               odsOp->getLoc());
662         }
663       }
664 
665       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
666                                                      valueTy);
667       return success();
668     }
669     return emitConvertError();
670   }
671 
672   // FIXME: Decide how to allow/support converting a single result to multiple,
673   // and multiple to a single result. For now, we just allow Single->Range,
674   // but this isn't something really supported in the PDL dialect. We should
675   // figure out some way to support both.
676   if ((exprType == valueTy || exprType == valueRangeTy) &&
677       (type == valueTy || type == valueRangeTy))
678     return success();
679   if ((exprType == typeTy || exprType == typeRangeTy) &&
680       (type == typeTy || type == typeRangeTy))
681     return success();
682 
683   // Handle tuple types.
684   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
685     auto tupleType = type.dyn_cast<ast::TupleType>();
686     if (!tupleType || tupleType.size() != exprTupleType.size())
687       return emitConvertError();
688 
689     // Build a new tuple expression using each of the elements of the current
690     // tuple.
691     SmallVector<ast::Expr *> newExprs;
692     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
693       newExprs.push_back(ast::MemberAccessExpr::create(
694           ctx, expr->getLoc(), expr, llvm::to_string(i),
695           exprTupleType.getElementTypes()[i]));
696 
697       auto diagFn = [&](ast::Diagnostic &diag) {
698         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
699                                       i, exprTupleType));
700         if (noteAttachFn)
701           noteAttachFn(diag);
702       };
703       if (failed(convertExpressionTo(newExprs.back(),
704                                      tupleType.getElementTypes()[i], diagFn)))
705         return failure();
706     }
707     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
708                                   tupleType.getElementNames());
709     return success();
710   }
711 
712   return emitConvertError();
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Directives
717 
718 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
719   StringRef directive = curToken.getSpelling();
720   if (directive == "#include")
721     return parseInclude(decls);
722 
723   return emitError("unknown directive `" + directive + "`");
724 }
725 
726 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
727   SMRange loc = curToken.getLoc();
728   consumeToken(Token::directive);
729 
730   // Handle code completion of the include file path.
731   if (curToken.is(Token::code_complete_string))
732     return codeCompleteIncludeFilename(curToken.getStringValue());
733 
734   // Parse the file being included.
735   if (!curToken.isString())
736     return emitError(loc,
737                      "expected string file name after `include` directive");
738   SMRange fileLoc = curToken.getLoc();
739   std::string filenameStr = curToken.getStringValue();
740   StringRef filename = filenameStr;
741   consumeToken();
742 
743   // Check the type of include. If ending with `.pdll`, this is another pdl file
744   // to be parsed along with the current module.
745   if (filename.endswith(".pdll")) {
746     if (failed(lexer.pushInclude(filename, fileLoc)))
747       return emitError(fileLoc,
748                        "unable to open include file `" + filename + "`");
749 
750     // If we added the include successfully, parse it into the current module.
751     // Make sure to update to the next token after we finish parsing the nested
752     // file.
753     curToken = lexer.lexToken();
754     LogicalResult result = parseModuleBody(decls);
755     curToken = lexer.lexToken();
756     return result;
757   }
758 
759   // Otherwise, this must be a `.td` include.
760   if (filename.endswith(".td"))
761     return parseTdInclude(filename, fileLoc, decls);
762 
763   return emitError(fileLoc,
764                    "expected include filename to end with `.pdll` or `.td`");
765 }
766 
767 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
768                                      SmallVectorImpl<ast::Decl *> &decls) {
769   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
770 
771   // Use the source manager to open the file, but don't yet add it.
772   std::string includedFile;
773   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
774       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
775   if (!includeBuffer)
776     return emitError(fileLoc, "unable to open include file `" + filename + "`");
777 
778   // Setup the source manager for parsing the tablegen file.
779   llvm::SourceMgr tdSrcMgr;
780   tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
781   tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
782 
783   // This class provides a context argument for the llvm::SourceMgr diagnostic
784   // handler.
785   struct DiagHandlerContext {
786     Parser &parser;
787     StringRef filename;
788     llvm::SMRange loc;
789   } handlerContext{*this, filename, fileLoc};
790 
791   // Set the diagnostic handler for the tablegen source manager.
792   tdSrcMgr.setDiagHandler(
793       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
794         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
795         (void)ctx->parser.emitError(
796             ctx->loc,
797             llvm::formatv("error while processing include file `{0}`: {1}",
798                           ctx->filename, diag.getMessage()));
799       },
800       &handlerContext);
801 
802   // Parse the tablegen file.
803   llvm::RecordKeeper tdRecords;
804   if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
805     return failure();
806 
807   // Process the parsed records.
808   processTdIncludeRecords(tdRecords, decls);
809 
810   // After we are done processing, move all of the tablegen source buffers to
811   // the main parser source mgr. This allows for directly using source locations
812   // from the .td files without needing to remap them.
813   parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
814   return success();
815 }
816 
817 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
818                                      SmallVectorImpl<ast::Decl *> &decls) {
819   // Return the length kind of the given value.
820   auto getLengthKind = [](const auto &value) {
821     if (value.isOptional())
822       return ods::VariableLengthKind::Optional;
823     return value.isVariadic() ? ods::VariableLengthKind::Variadic
824                               : ods::VariableLengthKind::Single;
825   };
826 
827   // Insert a type constraint into the ODS context.
828   ods::Context &odsContext = ctx.getODSContext();
829   auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
830       -> const ods::TypeConstraint & {
831     return odsContext.insertTypeConstraint(
832         cst.constraint.getUniqueDefName(),
833         processDoc(cst.constraint.getSummary()),
834         cst.constraint.getCPPClassName());
835   };
836   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
837     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
838   };
839 
840   // Process the parsed tablegen records to build ODS information.
841   /// Operations.
842   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
843     tblgen::Operator op(def);
844 
845     // Check to see if this operation is known to support type inferrence.
846     bool supportsResultTypeInferrence =
847         op.getTrait("::mlir::InferTypeOpInterface::Trait");
848 
849     bool inserted = false;
850     ods::Operation *odsOp = nullptr;
851     std::tie(odsOp, inserted) = odsContext.insertOperation(
852         op.getOperationName(), processDoc(op.getSummary()),
853         processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
854         supportsResultTypeInferrence, op.getLoc().front());
855 
856     // Ignore operations that have already been added.
857     if (!inserted)
858       continue;
859 
860     for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
861       odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
862                              odsContext.insertAttributeConstraint(
863                                  attr.attr.getUniqueDefName(),
864                                  processDoc(attr.attr.getSummary()),
865                                  attr.attr.getStorageType()));
866     }
867     for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
868       odsOp->appendOperand(operand.name, getLengthKind(operand),
869                            addTypeConstraint(operand));
870     }
871     for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
872       odsOp->appendResult(result.name, getLengthKind(result),
873                           addTypeConstraint(result));
874     }
875   }
876   /// Attr constraints.
877   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
878     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
879       tblgen::Attribute constraint(def);
880       decls.push_back(
881           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
882               constraint, convertLocToRange(def->getLoc().front()), attrTy,
883               constraint.getStorageType()));
884     }
885   }
886   /// Type constraints.
887   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
888     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
889       tblgen::TypeConstraint constraint(def);
890       decls.push_back(
891           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
892               constraint, convertLocToRange(def->getLoc().front()), typeTy,
893               constraint.getCPPClassName()));
894     }
895   }
896   /// Interfaces.
897   ast::Type opTy = ast::OperationType::get(ctx);
898   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
899     StringRef name = def->getName();
900     if (def->isAnonymous() || curDeclScope->lookup(name) ||
901         def->isSubClassOf("DeclareInterfaceMethods"))
902       continue;
903     SMRange loc = convertLocToRange(def->getLoc().front());
904 
905     std::string cppClassName =
906         llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
907                       def->getValueAsString("cppClassName"))
908             .str();
909     std::string codeBlock =
910         llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
911                       cppClassName)
912             .str();
913 
914     std::string desc =
915         processAndFormatDoc(def->getValueAsString("description"));
916     if (def->isSubClassOf("OpInterface")) {
917       decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
918           name, codeBlock, loc, opTy, cppClassName, desc));
919     } else if (def->isSubClassOf("AttrInterface")) {
920       decls.push_back(
921           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
922               name, codeBlock, loc, attrTy, cppClassName, desc));
923     } else if (def->isSubClassOf("TypeInterface")) {
924       decls.push_back(
925           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
926               name, codeBlock, loc, typeTy, cppClassName, desc));
927     }
928   }
929 }
930 
931 template <typename ConstraintT>
932 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
933     StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
934     StringRef nativeType, StringRef docString) {
935   // Build the single input parameter.
936   ast::DeclScope *argScope = pushDeclScope();
937   auto *paramVar = ast::VariableDecl::create(
938       ctx, ast::Name::create(ctx, "self", loc), type,
939       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
940   argScope->add(paramVar);
941   popDeclScope();
942 
943   // Build the native constraint.
944   auto *constraintDecl = ast::UserConstraintDecl::createNative(
945       ctx, ast::Name::create(ctx, name, loc), paramVar,
946       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
947   constraintDecl->setDocComment(ctx, docString);
948   curDeclScope->add(constraintDecl);
949   return constraintDecl;
950 }
951 
952 template <typename ConstraintT>
953 ast::Decl *
954 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
955                                           SMRange loc, ast::Type type,
956                                           StringRef nativeType) {
957   // Format the condition template.
958   tblgen::FmtContext fmtContext;
959   fmtContext.withSelf("self");
960   std::string codeBlock = tblgen::tgfmt(
961       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
962       &fmtContext);
963 
964   // If documentation was enabled, build the doc string for the generated
965   // constraint. It would be nice to do this lazily, but TableGen information is
966   // destroyed after we finish parsing the file.
967   std::string docString;
968   if (enableDocumentation) {
969     StringRef desc = constraint.getDescription();
970     docString = processAndFormatDoc(
971         constraint.getSummary() +
972         (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
973   }
974 
975   return createODSNativePDLLConstraintDecl<ConstraintT>(
976       constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
977       docString);
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // Decls
982 
983 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
984   FailureOr<ast::Decl *> decl;
985   switch (curToken.getKind()) {
986   case Token::kw_Constraint:
987     decl = parseUserConstraintDecl();
988     break;
989   case Token::kw_Pattern:
990     decl = parsePatternDecl();
991     break;
992   case Token::kw_Rewrite:
993     decl = parseUserRewriteDecl();
994     break;
995   default:
996     return emitError("expected top-level declaration, such as a `Pattern`");
997   }
998   if (failed(decl))
999     return failure();
1000 
1001   // If the decl has a name, add it to the current scope.
1002   if (const ast::Name *name = (*decl)->getName()) {
1003     if (failed(checkDefineNamedDecl(*name)))
1004       return failure();
1005     curDeclScope->add(*decl);
1006   }
1007   return decl;
1008 }
1009 
1010 FailureOr<ast::NamedAttributeDecl *>
1011 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
1012   // Check for name code completion.
1013   if (curToken.is(Token::code_complete))
1014     return codeCompleteAttributeName(parentOpName);
1015 
1016   std::string attrNameStr;
1017   if (curToken.isString())
1018     attrNameStr = curToken.getStringValue();
1019   else if (curToken.is(Token::identifier) || curToken.isKeyword())
1020     attrNameStr = curToken.getSpelling().str();
1021   else
1022     return emitError("expected identifier or string attribute name");
1023   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1024   consumeToken();
1025 
1026   // Check for a value of the attribute.
1027   ast::Expr *attrValue = nullptr;
1028   if (consumeIf(Token::equal)) {
1029     FailureOr<ast::Expr *> attrExpr = parseExpr();
1030     if (failed(attrExpr))
1031       return failure();
1032     attrValue = *attrExpr;
1033   } else {
1034     // If there isn't a concrete value, create an expression representing a
1035     // UnitAttr.
1036     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1037   }
1038 
1039   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1040 }
1041 
1042 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1043     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1044     bool expectTerminalSemicolon) {
1045   consumeToken(Token::equal_arrow);
1046 
1047   // Parse the single statement of the lambda body.
1048   SMLoc bodyStartLoc = curToken.getStartLoc();
1049   pushDeclScope();
1050   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1051   bool failedToParse =
1052       failed(singleStatement) || failed(processStatementFn(*singleStatement));
1053   popDeclScope();
1054   if (failedToParse)
1055     return failure();
1056 
1057   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1058   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1059 }
1060 
1061 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1062   // Ensure that the argument is named.
1063   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1064     return emitError("expected identifier argument name");
1065 
1066   // Parse the argument similarly to a normal variable.
1067   StringRef name = curToken.getSpelling();
1068   SMRange nameLoc = curToken.getLoc();
1069   consumeToken();
1070 
1071   if (failed(
1072           parseToken(Token::colon, "expected `:` before argument constraint")))
1073     return failure();
1074 
1075   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1076   if (failed(cst))
1077     return failure();
1078 
1079   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1080 }
1081 
1082 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1083   // Check to see if this result is named.
1084   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1085     // Check to see if this name actually refers to a Constraint.
1086     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1087     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1088       // If yes, and this is a Rewrite, give a nice error message as non-Core
1089       // constraints are not supported on Rewrite results.
1090       if (parserContext == ParserContext::Rewrite) {
1091         return emitError(
1092             "`Rewrite` results are only permitted to use core constraints, "
1093             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1094       }
1095 
1096       // Otherwise, parse this as an unnamed result variable.
1097     } else {
1098       // If it wasn't a constraint, parse the result similarly to a variable. If
1099       // there is already an existing decl, we will emit an error when defining
1100       // this variable later.
1101       StringRef name = curToken.getSpelling();
1102       SMRange nameLoc = curToken.getLoc();
1103       consumeToken();
1104 
1105       if (failed(parseToken(Token::colon,
1106                             "expected `:` before result constraint")))
1107         return failure();
1108 
1109       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1110       if (failed(cst))
1111         return failure();
1112 
1113       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1114     }
1115   }
1116 
1117   // If it isn't named, we parse the constraint directly and create an unnamed
1118   // result variable.
1119   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120   if (failed(cst))
1121     return failure();
1122 
1123   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1124 }
1125 
1126 FailureOr<ast::UserConstraintDecl *>
1127 Parser::parseUserConstraintDecl(bool isInline) {
1128   // Constraints and rewrites have very similar formats, dispatch to a shared
1129   // interface for parsing.
1130   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1131       [&](auto &&...args) {
1132         return this->parseUserPDLLConstraintDecl(args...);
1133       },
1134       ParserContext::Constraint, "constraint", isInline);
1135 }
1136 
1137 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1138   FailureOr<ast::UserConstraintDecl *> decl =
1139       parseUserConstraintDecl(/*isInline=*/true);
1140   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1141     return failure();
1142 
1143   curDeclScope->add(*decl);
1144   return decl;
1145 }
1146 
1147 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1148     const ast::Name &name, bool isInline,
1149     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1150     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1151   // Push the argument scope back onto the list, so that the body can
1152   // reference arguments.
1153   pushDeclScope(argumentScope);
1154 
1155   // Parse the body of the constraint. The body is either defined as a compound
1156   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1157   ast::CompoundStmt *body;
1158   if (curToken.is(Token::equal_arrow)) {
1159     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1160         [&](ast::Stmt *&stmt) -> LogicalResult {
1161           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1162           if (!stmtExpr) {
1163             return emitError(stmt->getLoc(),
1164                              "expected `Constraint` lambda body to contain a "
1165                              "single expression");
1166           }
1167           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1168           return success();
1169         },
1170         /*expectTerminalSemicolon=*/!isInline);
1171     if (failed(bodyResult))
1172       return failure();
1173     body = *bodyResult;
1174   } else {
1175     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1176     if (failed(bodyResult))
1177       return failure();
1178     body = *bodyResult;
1179 
1180     // Verify the structure of the body.
1181     auto bodyIt = body->begin(), bodyE = body->end();
1182     for (; bodyIt != bodyE; ++bodyIt)
1183       if (isa<ast::ReturnStmt>(*bodyIt))
1184         break;
1185     if (failed(validateUserConstraintOrRewriteReturn(
1186             "Constraint", body, bodyIt, bodyE, results, resultType)))
1187       return failure();
1188   }
1189   popDeclScope();
1190 
1191   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1192       name, arguments, results, resultType, body);
1193 }
1194 
1195 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1196   // Constraints and rewrites have very similar formats, dispatch to a shared
1197   // interface for parsing.
1198   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1199       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1200       ParserContext::Rewrite, "rewrite", isInline);
1201 }
1202 
1203 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1204   FailureOr<ast::UserRewriteDecl *> decl =
1205       parseUserRewriteDecl(/*isInline=*/true);
1206   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1207     return failure();
1208 
1209   curDeclScope->add(*decl);
1210   return decl;
1211 }
1212 
1213 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1214     const ast::Name &name, bool isInline,
1215     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1216     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1217   // Push the argument scope back onto the list, so that the body can
1218   // reference arguments.
1219   curDeclScope = argumentScope;
1220   ast::CompoundStmt *body;
1221   if (curToken.is(Token::equal_arrow)) {
1222     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1223         [&](ast::Stmt *&statement) -> LogicalResult {
1224           if (isa<ast::OpRewriteStmt>(statement))
1225             return success();
1226 
1227           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1228           if (!statementExpr) {
1229             return emitError(
1230                 statement->getLoc(),
1231                 "expected `Rewrite` lambda body to contain a single expression "
1232                 "or an operation rewrite statement; such as `erase`, "
1233                 "`replace`, or `rewrite`");
1234           }
1235           statement =
1236               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1237           return success();
1238         },
1239         /*expectTerminalSemicolon=*/!isInline);
1240     if (failed(bodyResult))
1241       return failure();
1242     body = *bodyResult;
1243   } else {
1244     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1245     if (failed(bodyResult))
1246       return failure();
1247     body = *bodyResult;
1248   }
1249   popDeclScope();
1250 
1251   // Verify the structure of the body.
1252   auto bodyIt = body->begin(), bodyE = body->end();
1253   for (; bodyIt != bodyE; ++bodyIt)
1254     if (isa<ast::ReturnStmt>(*bodyIt))
1255       break;
1256   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1257                                                    bodyE, results, resultType)))
1258     return failure();
1259   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1260       name, arguments, results, resultType, body);
1261 }
1262 
1263 template <typename T, typename ParseUserPDLLDeclFnT>
1264 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1265     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1266     StringRef anonymousNamePrefix, bool isInline) {
1267   SMRange loc = curToken.getLoc();
1268   consumeToken();
1269   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1270 
1271   // Parse the name of the decl.
1272   const ast::Name *name = nullptr;
1273   if (curToken.isNot(Token::identifier)) {
1274     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1275     // in C++, so being unnamed is fine.
1276     if (!isInline)
1277       return emitError("expected identifier name");
1278 
1279     // Create a unique anonymous name to use, as the name for this decl is not
1280     // important.
1281     std::string anonName =
1282         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1283                       anonymousDeclNameCounter++)
1284             .str();
1285     name = &ast::Name::create(ctx, anonName, loc);
1286   } else {
1287     // If a name was provided, we can use it directly.
1288     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1289     consumeToken(Token::identifier);
1290   }
1291 
1292   // Parse the functional signature of the decl.
1293   SmallVector<ast::VariableDecl *> arguments, results;
1294   ast::DeclScope *argumentScope;
1295   ast::Type resultType;
1296   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1297                                                    argumentScope, resultType)))
1298     return failure();
1299 
1300   // Check to see which type of constraint this is. If the constraint contains a
1301   // compound body, this is a PDLL decl.
1302   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1303     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1304                            resultType);
1305 
1306   // Otherwise, this is a native decl.
1307   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1308                                                    results, resultType);
1309 }
1310 
1311 template <typename T>
1312 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1313     const ast::Name &name, bool isInline,
1314     ArrayRef<ast::VariableDecl *> arguments,
1315     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1316   // If followed by a string, the native code body has also been specified.
1317   std::string codeStrStorage;
1318   Optional<StringRef> optCodeStr;
1319   if (curToken.isString()) {
1320     codeStrStorage = curToken.getStringValue();
1321     optCodeStr = codeStrStorage;
1322     consumeToken();
1323   } else if (isInline) {
1324     return emitError(name.getLoc(),
1325                      "external declarations must be declared in global scope");
1326   } else if (curToken.is(Token::error)) {
1327     return failure();
1328   }
1329   if (failed(parseToken(Token::semicolon,
1330                         "expected `;` after native declaration")))
1331     return failure();
1332   // TODO: PDL should be able to support constraint results in certain
1333   // situations, we should revise this.
1334   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1335     return emitError(
1336         "native Constraints currently do not support returning results");
1337   }
1338   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1339 }
1340 
1341 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1342     SmallVectorImpl<ast::VariableDecl *> &arguments,
1343     SmallVectorImpl<ast::VariableDecl *> &results,
1344     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1345   // Parse the argument list of the decl.
1346   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1347     return failure();
1348 
1349   argumentScope = pushDeclScope();
1350   if (curToken.isNot(Token::r_paren)) {
1351     do {
1352       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1353       if (failed(argument))
1354         return failure();
1355       arguments.emplace_back(*argument);
1356     } while (consumeIf(Token::comma));
1357   }
1358   popDeclScope();
1359   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1360     return failure();
1361 
1362   // Parse the results of the decl.
1363   pushDeclScope();
1364   if (consumeIf(Token::arrow)) {
1365     auto parseResultFn = [&]() -> LogicalResult {
1366       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1367       if (failed(result))
1368         return failure();
1369       results.emplace_back(*result);
1370       return success();
1371     };
1372 
1373     // Check for a list of results.
1374     if (consumeIf(Token::l_paren)) {
1375       do {
1376         if (failed(parseResultFn()))
1377           return failure();
1378       } while (consumeIf(Token::comma));
1379       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1380         return failure();
1381 
1382       // Otherwise, there is only one result.
1383     } else if (failed(parseResultFn())) {
1384       return failure();
1385     }
1386   }
1387   popDeclScope();
1388 
1389   // Compute the result type of the decl.
1390   resultType = createUserConstraintRewriteResultType(results);
1391 
1392   // Verify that results are only named if there are more than one.
1393   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1394     return emitError(
1395         results.front()->getLoc(),
1396         "cannot create a single-element tuple with an element label");
1397   }
1398   return success();
1399 }
1400 
1401 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1402     StringRef declType, ast::CompoundStmt *body,
1403     ArrayRef<ast::Stmt *>::iterator bodyIt,
1404     ArrayRef<ast::Stmt *>::iterator bodyE,
1405     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1406   // Handle if a `return` was provided.
1407   if (bodyIt != bodyE) {
1408     // Emit an error if we have trailing statements after the return.
1409     if (std::next(bodyIt) != bodyE) {
1410       return emitError(
1411           (*std::next(bodyIt))->getLoc(),
1412           llvm::formatv("`return` terminated the `{0}` body, but found "
1413                         "trailing statements afterwards",
1414                         declType));
1415     }
1416 
1417     // Otherwise if a return wasn't provided, check that no results are
1418     // expected.
1419   } else if (!results.empty()) {
1420     return emitError(
1421         {body->getLoc().End, body->getLoc().End},
1422         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1423                       declType, resultType));
1424   }
1425   return success();
1426 }
1427 
1428 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1429   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1430     if (isa<ast::OpRewriteStmt>(statement))
1431       return success();
1432     return emitError(
1433         statement->getLoc(),
1434         "expected Pattern lambda body to contain a single operation "
1435         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1436   });
1437 }
1438 
1439 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1440   SMRange loc = curToken.getLoc();
1441   consumeToken(Token::kw_Pattern);
1442   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1443                                               ParserContext::PatternMatch);
1444 
1445   // Check for an optional identifier for the pattern name.
1446   const ast::Name *name = nullptr;
1447   if (curToken.is(Token::identifier)) {
1448     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1449     consumeToken(Token::identifier);
1450   }
1451 
1452   // Parse any pattern metadata.
1453   ParsedPatternMetadata metadata;
1454   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1455     return failure();
1456 
1457   // Parse the pattern body.
1458   ast::CompoundStmt *body;
1459 
1460   // Handle a lambda body.
1461   if (curToken.is(Token::equal_arrow)) {
1462     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1463     if (failed(bodyResult))
1464       return failure();
1465     body = *bodyResult;
1466   } else {
1467     if (curToken.isNot(Token::l_brace))
1468       return emitError("expected `{` or `=>` to start pattern body");
1469     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1470     if (failed(bodyResult))
1471       return failure();
1472     body = *bodyResult;
1473 
1474     // Verify the body of the pattern.
1475     auto bodyIt = body->begin(), bodyE = body->end();
1476     for (; bodyIt != bodyE; ++bodyIt) {
1477       if (isa<ast::ReturnStmt>(*bodyIt)) {
1478         return emitError((*bodyIt)->getLoc(),
1479                          "`return` statements are only permitted within a "
1480                          "`Constraint` or `Rewrite` body");
1481       }
1482       // Break when we've found the rewrite statement.
1483       if (isa<ast::OpRewriteStmt>(*bodyIt))
1484         break;
1485     }
1486     if (bodyIt == bodyE) {
1487       return emitError(loc,
1488                        "expected Pattern body to terminate with an operation "
1489                        "rewrite statement, such as `erase`");
1490     }
1491     if (std::next(bodyIt) != bodyE) {
1492       return emitError((*std::next(bodyIt))->getLoc(),
1493                        "Pattern body was terminated by an operation "
1494                        "rewrite statement, but found trailing statements");
1495     }
1496   }
1497 
1498   return createPatternDecl(loc, name, metadata, body);
1499 }
1500 
1501 LogicalResult
1502 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1503   Optional<SMRange> benefitLoc;
1504   Optional<SMRange> hasBoundedRecursionLoc;
1505 
1506   do {
1507     // Handle metadata code completion.
1508     if (curToken.is(Token::code_complete))
1509       return codeCompletePatternMetadata();
1510 
1511     if (curToken.isNot(Token::identifier))
1512       return emitError("expected pattern metadata identifier");
1513     StringRef metadataStr = curToken.getSpelling();
1514     SMRange metadataLoc = curToken.getLoc();
1515     consumeToken(Token::identifier);
1516 
1517     // Parse the benefit metadata: benefit(<integer-value>)
1518     if (metadataStr == "benefit") {
1519       if (benefitLoc) {
1520         return emitErrorAndNote(metadataLoc,
1521                                 "pattern benefit has already been specified",
1522                                 *benefitLoc, "see previous definition here");
1523       }
1524       if (failed(parseToken(Token::l_paren,
1525                             "expected `(` before pattern benefit")))
1526         return failure();
1527 
1528       uint16_t benefitValue = 0;
1529       if (curToken.isNot(Token::integer))
1530         return emitError("expected integral pattern benefit");
1531       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1532         return emitError(
1533             "expected pattern benefit to fit within a 16-bit integer");
1534       consumeToken(Token::integer);
1535 
1536       metadata.benefit = benefitValue;
1537       benefitLoc = metadataLoc;
1538 
1539       if (failed(
1540               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1541         return failure();
1542       continue;
1543     }
1544 
1545     // Parse the bounded recursion metadata: recursion
1546     if (metadataStr == "recursion") {
1547       if (hasBoundedRecursionLoc) {
1548         return emitErrorAndNote(
1549             metadataLoc,
1550             "pattern recursion metadata has already been specified",
1551             *hasBoundedRecursionLoc, "see previous definition here");
1552       }
1553       metadata.hasBoundedRecursion = true;
1554       hasBoundedRecursionLoc = metadataLoc;
1555       continue;
1556     }
1557 
1558     return emitError(metadataLoc, "unknown pattern metadata");
1559   } while (consumeIf(Token::comma));
1560 
1561   return success();
1562 }
1563 
1564 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1565   consumeToken(Token::less);
1566 
1567   FailureOr<ast::Expr *> typeExpr = parseExpr();
1568   if (failed(typeExpr) ||
1569       failed(parseToken(Token::greater,
1570                         "expected `>` after variable type constraint")))
1571     return failure();
1572   return typeExpr;
1573 }
1574 
1575 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1576   assert(curDeclScope && "defining decl outside of a decl scope");
1577   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1578     return emitErrorAndNote(
1579         name.getLoc(), "`" + name.getName() + "` has already been defined",
1580         lastDecl->getName()->getLoc(), "see previous definition here");
1581   }
1582   return success();
1583 }
1584 
1585 FailureOr<ast::VariableDecl *>
1586 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1587                            ast::Expr *initExpr,
1588                            ArrayRef<ast::ConstraintRef> constraints) {
1589   assert(curDeclScope && "defining variable outside of decl scope");
1590   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1591 
1592   // If the name of the variable indicates a special variable, we don't add it
1593   // to the scope. This variable is local to the definition point.
1594   if (name.empty() || name == "_") {
1595     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1596                                      constraints);
1597   }
1598   if (failed(checkDefineNamedDecl(nameDecl)))
1599     return failure();
1600 
1601   auto *varDecl =
1602       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1603   curDeclScope->add(varDecl);
1604   return varDecl;
1605 }
1606 
1607 FailureOr<ast::VariableDecl *>
1608 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1609                            ArrayRef<ast::ConstraintRef> constraints) {
1610   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1611                             constraints);
1612 }
1613 
1614 LogicalResult Parser::parseVariableDeclConstraintList(
1615     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1616   Optional<SMRange> typeConstraint;
1617   auto parseSingleConstraint = [&] {
1618     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1619         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1620         /*allowNonCoreConstraints=*/true);
1621     if (failed(constraint))
1622       return failure();
1623     constraints.push_back(*constraint);
1624     return success();
1625   };
1626 
1627   // Check to see if this is a single constraint, or a list.
1628   if (!consumeIf(Token::l_square))
1629     return parseSingleConstraint();
1630 
1631   do {
1632     if (failed(parseSingleConstraint()))
1633       return failure();
1634   } while (consumeIf(Token::comma));
1635   return parseToken(Token::r_square, "expected `]` after constraint list");
1636 }
1637 
/// Parse a single constraint at the current token. Accepted forms:
///   * core constraints: `Attr[<type-expr>]`, `Op[<dialect.name>]`, `Type`,
///     `TypeRange`, `Value[<type-expr>]`, `ValueRange[<type-expr>]`;
///   * an inline `Constraint` declaration;
///   * an identifier naming a previously defined constraint decl.
/// A `code_complete` token triggers constraint-name completion instead.
///
/// `typeConstraint` carries the location of an inline `<...>` type constraint
/// already applied to the same variable: a second one is rejected, and a
/// newly parsed one records its location here. Inline type constraints are
/// only accepted when `allowInlineTypeConstraints` is set.
/// Within this function, `existingConstraints` and `allowNonCoreConstraints`
/// are only consulted on the code completion path (to infer the variable's
/// current type and to pass along to completion).
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints,
                        bool allowNonCoreConstraints) {
  // Parse an optional `<type-expr>` suffix of `Attr`/`Value`/`ValueRange`
  // into `typeExpr`, enforcing the inline-type-constraint rules above.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          curToken.getLoc(),
          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  // Every ConstraintRef produced below is located at the leading token.
  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl) {
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved to a decl that is not a constraint (e.g. a variable).
    return emitErrorAndNote(
        loc, "invalid reference to non-constraint", cstDecl->getLoc(),
        "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(validateVariableConstraints(existingConstraints, inferredType,
                                           allowNonCoreConstraints)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
                                      allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
1760 
1761 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1762   // Constraint arguments may apply more complex constraints via the arguments.
1763   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1764 
1765   Optional<SMRange> typeConstraint;
1766   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1767                          /*allowInlineTypeConstraints=*/false,
1768                          allowNonCoreConstraints);
1769 }
1770 
1771 //===----------------------------------------------------------------------===//
1772 // Exprs
1773 
1774 FailureOr<ast::Expr *> Parser::parseExpr() {
1775   if (curToken.is(Token::underscore))
1776     return parseUnderscoreExpr();
1777 
1778   // Parse the LHS expression.
1779   FailureOr<ast::Expr *> lhsExpr;
1780   switch (curToken.getKind()) {
1781   case Token::kw_attr:
1782     lhsExpr = parseAttributeExpr();
1783     break;
1784   case Token::kw_Constraint:
1785     lhsExpr = parseInlineConstraintLambdaExpr();
1786     break;
1787   case Token::identifier:
1788     lhsExpr = parseIdentifierExpr();
1789     break;
1790   case Token::kw_op:
1791     lhsExpr = parseOperationExpr();
1792     break;
1793   case Token::kw_Rewrite:
1794     lhsExpr = parseInlineRewriteLambdaExpr();
1795     break;
1796   case Token::kw_type:
1797     lhsExpr = parseTypeExpr();
1798     break;
1799   case Token::l_paren:
1800     lhsExpr = parseTupleExpr();
1801     break;
1802   default:
1803     return emitError("expected expression");
1804   }
1805   if (failed(lhsExpr))
1806     return failure();
1807 
1808   // Check for an operator expression.
1809   while (true) {
1810     switch (curToken.getKind()) {
1811     case Token::dot:
1812       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1813       break;
1814     case Token::l_paren:
1815       lhsExpr = parseCallExpr(*lhsExpr);
1816       break;
1817     default:
1818       return lhsExpr;
1819     }
1820     if (failed(lhsExpr))
1821       return failure();
1822   }
1823 }
1824 
1825 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1826   SMRange loc = curToken.getLoc();
1827   consumeToken(Token::kw_attr);
1828 
1829   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1830   // identifier.
1831   if (!consumeIf(Token::less)) {
1832     resetToken(loc);
1833     return parseIdentifierExpr();
1834   }
1835 
1836   if (!curToken.isString())
1837     return emitError("expected string literal containing MLIR attribute");
1838   std::string attrExpr = curToken.getStringValue();
1839   consumeToken();
1840 
1841   loc.End = curToken.getEndLoc();
1842   if (failed(
1843           parseToken(Token::greater, "expected `>` after attribute literal")))
1844     return failure();
1845   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1846 }
1847 
1848 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1849   consumeToken(Token::l_paren);
1850 
1851   // Parse the arguments of the call.
1852   SmallVector<ast::Expr *> arguments;
1853   if (curToken.isNot(Token::r_paren)) {
1854     do {
1855       // Handle code completion for the call arguments.
1856       if (curToken.is(Token::code_complete)) {
1857         codeCompleteCallSignature(parentExpr, arguments.size());
1858         return failure();
1859       }
1860 
1861       FailureOr<ast::Expr *> argument = parseExpr();
1862       if (failed(argument))
1863         return failure();
1864       arguments.push_back(*argument);
1865     } while (consumeIf(Token::comma));
1866   }
1867 
1868   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1869   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1870     return failure();
1871 
1872   return createCallExpr(loc, parentExpr, arguments);
1873 }
1874 
1875 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1876   ast::Decl *decl = curDeclScope->lookup(name);
1877   if (!decl)
1878     return emitError(loc, "undefined reference to `" + name + "`");
1879 
1880   return createDeclRefExpr(loc, decl);
1881 }
1882 
1883 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1884   StringRef name = curToken.getSpelling();
1885   SMRange nameLoc = curToken.getLoc();
1886   consumeToken();
1887 
1888   // Check to see if this is a decl ref expression that defines a variable
1889   // inline.
1890   if (consumeIf(Token::colon)) {
1891     SmallVector<ast::ConstraintRef> constraints;
1892     if (failed(parseVariableDeclConstraintList(constraints)))
1893       return failure();
1894     ast::Type type;
1895     if (failed(validateVariableConstraints(constraints, type)))
1896       return failure();
1897     return createInlineVariableExpr(type, name, nameLoc, constraints);
1898   }
1899 
1900   return parseDeclRefExpr(name, nameLoc);
1901 }
1902 
1903 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1904   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1905   if (failed(decl))
1906     return failure();
1907 
1908   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1909                                   ast::ConstraintType::get(ctx));
1910 }
1911 
1912 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1913   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1914   if (failed(decl))
1915     return failure();
1916 
1917   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1918                                   ast::RewriteType::get(ctx));
1919 }
1920 
1921 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1922   SMRange dotLoc = curToken.getLoc();
1923   consumeToken(Token::dot);
1924 
1925   // Check for code completion of the member name.
1926   if (curToken.is(Token::code_complete))
1927     return codeCompleteMemberAccess(parentExpr);
1928 
1929   // Parse the member name.
1930   Token memberNameTok = curToken;
1931   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1932       !memberNameTok.isKeyword())
1933     return emitError(dotLoc, "expected identifier or numeric member name");
1934   StringRef memberName = memberNameTok.getSpelling();
1935   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1936   consumeToken();
1937 
1938   return createMemberAccessExpr(parentExpr, memberName, loc);
1939 }
1940 
1941 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1942   SMRange loc = curToken.getLoc();
1943 
1944   // Check for code completion for the dialect name.
1945   if (curToken.is(Token::code_complete))
1946     return codeCompleteDialectName();
1947 
1948   // Handle the case of an no operation name.
1949   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1950     if (allowEmptyName)
1951       return ast::OpNameDecl::create(ctx, SMRange());
1952     return emitError("expected dialect namespace");
1953   }
1954   StringRef name = curToken.getSpelling();
1955   consumeToken();
1956 
1957   // Otherwise, this is a literal operation name.
1958   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1959     return failure();
1960 
1961   // Check for code completion for the operation name.
1962   if (curToken.is(Token::code_complete))
1963     return codeCompleteOperationName(name);
1964 
1965   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1966     return emitError("expected operation name after dialect namespace");
1967 
1968   name = StringRef(name.data(), name.size() + 1);
1969   do {
1970     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1971     loc.End = curToken.getEndLoc();
1972     consumeToken();
1973   } while (curToken.isAny(Token::identifier, Token::dot) ||
1974            curToken.isKeyword());
1975   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1976 }
1977 
1978 FailureOr<ast::OpNameDecl *>
1979 Parser::parseWrappedOperationName(bool allowEmptyName) {
1980   if (!consumeIf(Token::less))
1981     return ast::OpNameDecl::create(ctx, SMRange());
1982 
1983   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1984   if (failed(opNameDecl))
1985     return failure();
1986 
1987   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1988     return failure();
1989   return opNameDecl;
1990 }
1991 
/// Parse an operation expression:
///   op<dialect.name>(operands) {attributes} -> (result types)
/// Every component after the `op` keyword is optional. If `op` is not
/// followed by `<`, it is reparsed as a plain identifier expression.
///
/// `inputResultTypeContext` is the initial result-type context:
///  * It becomes `Explicit` when a `-> (...)` list is present.
///  * It changes from `Explicit` to `Interface` when the list is omitted
///    inside a Rewrite context.
///
/// Outside a Rewrite context, an omitted operand or result list is replaced
/// by an implicit unconstrained `_` range variable.
/// NOTE(review): `loc` is the location of the `op` keyword only; it is not
/// extended over the rest of the expression.
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(opNameDecl))
    return failure();
  // Used below to drive operand/attribute/result code completion.
  Optional<StringRef> opName = (*opNameDecl)->getName();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables. The variable is named `_`, so
  // defineVariableDecl does not add it to the decl scope.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an in-place unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, operands.size());
        return failure();
      }

      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(operand))
        return failure();
      operands.push_back(*operand);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes.
  // NOTE(review): unlike `()` for operands and results, an empty `{}` is not
  // special-cased here; the `}` is handed straight to parseNamedAttributeDecl.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(opName);
      if (failed(decl))
        return failure();
      attributes.emplace_back(*decl);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;

  // Check for an explicit list of result types.
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")))
      return failure();

    // If result types are provided, initially assume that the operation does
    // not rely on type inference. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;

    // Handle the case of an empty result list.
    if (!consumeIf(Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, resultTypes.size());
          return failure();
        }

        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(resultTypeExpr))
          return failure();
        resultTypes.push_back(*resultTypeExpr);
      } while (consumeIf(Token::comma));

      if (failed(parseToken(Token::r_paren,
                            "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an in-place unconstrained result range corresponding to all of the
    // results of the operation. This avoids treating zero results the same way
    // as "unconstrained results".
    resultTypes.push_back(createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead.
    resultTypeContext = OpResultTypeContext::Interface;
  }

  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
                             attributes, resultTypes);
}
2122 
2123 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2124   SMRange loc = curToken.getLoc();
2125   consumeToken(Token::l_paren);
2126 
2127   DenseMap<StringRef, SMRange> usedNames;
2128   SmallVector<StringRef> elementNames;
2129   SmallVector<ast::Expr *> elements;
2130   if (curToken.isNot(Token::r_paren)) {
2131     do {
2132       // Check for the optional element name assignment before the value.
2133       StringRef elementName;
2134       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2135         Token elementNameTok = curToken;
2136         consumeToken();
2137 
2138         // The element name is only present if followed by an `=`.
2139         if (consumeIf(Token::equal)) {
2140           elementName = elementNameTok.getSpelling();
2141 
2142           // Check to see if this name is already used.
2143           auto elementNameIt =
2144               usedNames.try_emplace(elementName, elementNameTok.getLoc());
2145           if (!elementNameIt.second) {
2146             return emitErrorAndNote(
2147                 elementNameTok.getLoc(),
2148                 llvm::formatv("duplicate tuple element label `{0}`",
2149                               elementName),
2150                 elementNameIt.first->getSecond(),
2151                 "see previous label use here");
2152           }
2153         } else {
2154           // Otherwise, we treat this as part of an expression so reset the
2155           // lexer.
2156           resetToken(elementNameTok.getLoc());
2157         }
2158       }
2159       elementNames.push_back(elementName);
2160 
2161       // Parse the tuple element value.
2162       FailureOr<ast::Expr *> element = parseExpr();
2163       if (failed(element))
2164         return failure();
2165       elements.push_back(*element);
2166     } while (consumeIf(Token::comma));
2167   }
2168   loc.End = curToken.getEndLoc();
2169   if (failed(
2170           parseToken(Token::r_paren, "expected `)` after tuple element list")))
2171     return failure();
2172   return createTupleExpr(loc, elements, elementNames);
2173 }
2174 
2175 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2176   SMRange loc = curToken.getLoc();
2177   consumeToken(Token::kw_type);
2178 
2179   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2180   // identifier.
2181   if (!consumeIf(Token::less)) {
2182     resetToken(loc);
2183     return parseIdentifierExpr();
2184   }
2185 
2186   if (!curToken.isString())
2187     return emitError("expected string literal containing MLIR type");
2188   std::string attrExpr = curToken.getStringValue();
2189   consumeToken();
2190 
2191   loc.End = curToken.getEndLoc();
2192   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2193     return failure();
2194   return ast::TypeExpr::create(ctx, loc, attrExpr);
2195 }
2196 
2197 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2198   StringRef name = curToken.getSpelling();
2199   SMRange nameLoc = curToken.getLoc();
2200   consumeToken(Token::underscore);
2201 
2202   // Underscore expressions require a constraint list.
2203   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2204     return failure();
2205 
2206   // Parse the constraints for the expression.
2207   SmallVector<ast::ConstraintRef> constraints;
2208   if (failed(parseVariableDeclConstraintList(constraints)))
2209     return failure();
2210 
2211   ast::Type type;
2212   if (failed(validateVariableConstraints(constraints, type)))
2213     return failure();
2214   return createInlineVariableExpr(type, name, nameLoc, constraints);
2215 }
2216 
2217 //===----------------------------------------------------------------------===//
2218 // Stmts
2219 
2220 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2221   FailureOr<ast::Stmt *> stmt;
2222   switch (curToken.getKind()) {
2223   case Token::kw_erase:
2224     stmt = parseEraseStmt();
2225     break;
2226   case Token::kw_let:
2227     stmt = parseLetStmt();
2228     break;
2229   case Token::kw_replace:
2230     stmt = parseReplaceStmt();
2231     break;
2232   case Token::kw_return:
2233     stmt = parseReturnStmt();
2234     break;
2235   case Token::kw_rewrite:
2236     stmt = parseRewriteStmt();
2237     break;
2238   default:
2239     stmt = parseExpr();
2240     break;
2241   }
2242   if (failed(stmt) ||
2243       (expectTerminalSemicolon &&
2244        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2245     return failure();
2246   return stmt;
2247 }
2248 
2249 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2250   SMLoc startLoc = curToken.getStartLoc();
2251   consumeToken(Token::l_brace);
2252 
2253   // Push a new block scope and parse any nested statements.
2254   pushDeclScope();
2255   SmallVector<ast::Stmt *> statements;
2256   while (curToken.isNot(Token::r_brace)) {
2257     FailureOr<ast::Stmt *> statement = parseStmt();
2258     if (failed(statement))
2259       return popDeclScope(), failure();
2260     statements.push_back(*statement);
2261   }
2262   popDeclScope();
2263 
2264   // Consume the end brace.
2265   SMRange location(startLoc, curToken.getEndLoc());
2266   consumeToken(Token::r_brace);
2267 
2268   return ast::CompoundStmt::create(ctx, location, statements);
2269 }
2270 
2271 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2272   if (parserContext == ParserContext::Constraint)
2273     return emitError("`erase` cannot be used within a Constraint");
2274   SMRange loc = curToken.getLoc();
2275   consumeToken(Token::kw_erase);
2276 
2277   // Parse the root operation expression.
2278   FailureOr<ast::Expr *> rootOp = parseExpr();
2279   if (failed(rootOp))
2280     return failure();
2281 
2282   return createEraseStmt(loc, *rootOp);
2283 }
2284 
2285 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2286   SMRange loc = curToken.getLoc();
2287   consumeToken(Token::kw_let);
2288 
2289   // Parse the name of the new variable.
2290   SMRange varLoc = curToken.getLoc();
2291   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2292     // `_` is a reserved variable name.
2293     if (curToken.is(Token::underscore)) {
2294       return emitError(varLoc,
2295                        "`_` may only be used to define \"inline\" variables");
2296     }
2297     return emitError(varLoc,
2298                      "expected identifier after `let` to name a new variable");
2299   }
2300   StringRef varName = curToken.getSpelling();
2301   consumeToken();
2302 
2303   // Parse the optional set of constraints.
2304   SmallVector<ast::ConstraintRef> constraints;
2305   if (consumeIf(Token::colon) &&
2306       failed(parseVariableDeclConstraintList(constraints)))
2307     return failure();
2308 
2309   // Parse the optional initializer expression.
2310   ast::Expr *initializer = nullptr;
2311   if (consumeIf(Token::equal)) {
2312     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2313     if (failed(initOrFailure))
2314       return failure();
2315     initializer = *initOrFailure;
2316 
2317     // Check that the constraints are compatible with having an initializer,
2318     // e.g. type constraints cannot be used with initializers.
2319     for (ast::ConstraintRef constraint : constraints) {
2320       LogicalResult result =
2321           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2322               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2323                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2324                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2325                   return this->emitError(
2326                       constraint.referenceLoc,
2327                       "type constraints are not permitted on variables with "
2328                       "initializers");
2329                 }
2330                 return success();
2331               })
2332               .Default(success());
2333       if (failed(result))
2334         return failure();
2335     }
2336   }
2337 
2338   FailureOr<ast::VariableDecl *> varDecl =
2339       createVariableDecl(varName, varLoc, initializer, constraints);
2340   if (failed(varDecl))
2341     return failure();
2342   return ast::LetStmt::create(ctx, loc, *varDecl);
2343 }
2344 
2345 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2346   if (parserContext == ParserContext::Constraint)
2347     return emitError("`replace` cannot be used within a Constraint");
2348   SMRange loc = curToken.getLoc();
2349   consumeToken(Token::kw_replace);
2350 
2351   // Parse the root operation expression.
2352   FailureOr<ast::Expr *> rootOp = parseExpr();
2353   if (failed(rootOp))
2354     return failure();
2355 
2356   if (failed(
2357           parseToken(Token::kw_with, "expected `with` after root operation")))
2358     return failure();
2359 
2360   // The replacement portion of this statement is within a rewrite context.
2361   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2362                                               ParserContext::Rewrite);
2363 
2364   // Parse the replacement values.
2365   SmallVector<ast::Expr *> replValues;
2366   if (consumeIf(Token::l_paren)) {
2367     if (consumeIf(Token::r_paren)) {
2368       return emitError(
2369           loc, "expected at least one replacement value, consider using "
2370                "`erase` if no replacement values are desired");
2371     }
2372 
2373     do {
2374       FailureOr<ast::Expr *> replExpr = parseExpr();
2375       if (failed(replExpr))
2376         return failure();
2377       replValues.emplace_back(*replExpr);
2378     } while (consumeIf(Token::comma));
2379 
2380     if (failed(parseToken(Token::r_paren,
2381                           "expected `)` after replacement values")))
2382       return failure();
2383   } else {
2384     // Handle replacement with an operation uniquely, as the replacement
2385     // operation supports type inferrence from the root operation.
2386     FailureOr<ast::Expr *> replExpr;
2387     if (curToken.is(Token::kw_op))
2388       replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2389     else
2390       replExpr = parseExpr();
2391     if (failed(replExpr))
2392       return failure();
2393     replValues.emplace_back(*replExpr);
2394   }
2395 
2396   return createReplaceStmt(loc, *rootOp, replValues);
2397 }
2398 
2399 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2400   SMRange loc = curToken.getLoc();
2401   consumeToken(Token::kw_return);
2402 
2403   // Parse the result value.
2404   FailureOr<ast::Expr *> resultExpr = parseExpr();
2405   if (failed(resultExpr))
2406     return failure();
2407 
2408   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2409 }
2410 
2411 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2412   if (parserContext == ParserContext::Constraint)
2413     return emitError("`rewrite` cannot be used within a Constraint");
2414   SMRange loc = curToken.getLoc();
2415   consumeToken(Token::kw_rewrite);
2416 
2417   // Parse the root operation.
2418   FailureOr<ast::Expr *> rootOp = parseExpr();
2419   if (failed(rootOp))
2420     return failure();
2421 
2422   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2423     return failure();
2424 
2425   if (curToken.isNot(Token::l_brace))
2426     return emitError("expected `{` to start rewrite body");
2427 
2428   // The rewrite body of this statement is within a rewrite context.
2429   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2430                                               ParserContext::Rewrite);
2431 
2432   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2433   if (failed(rewriteBody))
2434     return failure();
2435 
2436   // Verify the rewrite body.
2437   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2438     if (isa<ast::ReturnStmt>(stmt)) {
2439       return emitError(stmt->getLoc(),
2440                        "`return` statements are only permitted within a "
2441                        "`Constraint` or `Rewrite` body");
2442     }
2443   }
2444 
2445   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2446 }
2447 
2448 //===----------------------------------------------------------------------===//
2449 // Creation+Analysis
2450 //===----------------------------------------------------------------------===//
2451 
2452 //===----------------------------------------------------------------------===//
2453 // Decls
2454 
2455 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2456   // Unwrap reference expressions.
2457   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2458     node = init->getDecl();
2459   return dyn_cast<ast::CallableDecl>(node);
2460 }
2461 
2462 FailureOr<ast::PatternDecl *>
2463 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2464                           const ParsedPatternMetadata &metadata,
2465                           ast::CompoundStmt *body) {
2466   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2467                                   metadata.hasBoundedRecursion, body);
2468 }
2469 
2470 ast::Type Parser::createUserConstraintRewriteResultType(
2471     ArrayRef<ast::VariableDecl *> results) {
2472   // Single result decls use the type of the single result.
2473   if (results.size() == 1)
2474     return results[0]->getType();
2475 
2476   // Multiple results use a tuple type, with the types and names grabbed from
2477   // the result variable decls.
2478   auto resultTypes = llvm::map_range(
2479       results, [&](const auto *result) { return result->getType(); });
2480   auto resultNames = llvm::map_range(
2481       results, [&](const auto *result) { return result->getName().getName(); });
2482   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2483                              llvm::to_vector(resultNames));
2484 }
2485 
2486 template <typename T>
2487 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2488     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2489     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2490     ast::CompoundStmt *body) {
2491   if (!body->getChildren().empty()) {
2492     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2493       ast::Expr *resultExpr = retStmt->getResultExpr();
2494 
2495       // Process the result of the decl. If no explicit signature results
2496       // were provided, check for return type inference. Otherwise, check that
2497       // the return expression can be converted to the expected type.
2498       if (results.empty())
2499         resultType = resultExpr->getType();
2500       else if (failed(convertExpressionTo(resultExpr, resultType)))
2501         return failure();
2502       else
2503         retStmt->setResultExpr(resultExpr);
2504     }
2505   }
2506   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2507 }
2508 
2509 FailureOr<ast::VariableDecl *>
2510 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2511                            ArrayRef<ast::ConstraintRef> constraints) {
2512   // The type of the variable, which is expected to be inferred by either a
2513   // constraint or an initializer expression.
2514   ast::Type type;
2515   if (failed(validateVariableConstraints(constraints, type)))
2516     return failure();
2517 
2518   if (initializer) {
2519     // Update the variable type based on the initializer, or try to convert the
2520     // initializer to the existing type.
2521     if (!type)
2522       type = initializer->getType();
2523     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2524       type = mergedType;
2525     else if (failed(convertExpressionTo(initializer, type)))
2526       return failure();
2527 
2528     // Otherwise, if there is no initializer check that the type has already
2529     // been resolved from the constraint list.
2530   } else if (!type) {
2531     return emitErrorAndNote(
2532         loc, "unable to infer type for variable `" + name + "`", loc,
2533         "the type of a variable must be inferable from the constraint "
2534         "list or the initializer");
2535   }
2536 
2537   // Constraint types cannot be used when defining variables.
2538   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2539     return emitError(
2540         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2541   }
2542 
2543   // Try to define a variable with the given name.
2544   FailureOr<ast::VariableDecl *> varDecl =
2545       defineVariableDecl(name, loc, type, initializer, constraints);
2546   if (failed(varDecl))
2547     return failure();
2548 
2549   return *varDecl;
2550 }
2551 
2552 FailureOr<ast::VariableDecl *>
2553 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2554                                       const ast::ConstraintRef &constraint) {
2555   // Constraint arguments may apply more complex constraints via the arguments.
2556   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2557   ast::Type argType;
2558   if (failed(validateVariableConstraint(constraint, argType,
2559                                         allowNonCoreConstraints)))
2560     return failure();
2561   return defineVariableDecl(name, loc, argType, constraint);
2562 }
2563 
2564 LogicalResult
2565 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2566                                     ast::Type &inferredType,
2567                                     bool allowNonCoreConstraints) {
2568   for (const ast::ConstraintRef &ref : constraints)
2569     if (failed(validateVariableConstraint(ref, inferredType,
2570                                           allowNonCoreConstraints)))
2571       return failure();
2572   return success();
2573 }
2574 
/// Validates a single constraint reference and merges the type it implies into
/// `inferredType`.
///
/// If `inferredType` is null, it is set to the constraint's type. Otherwise the
/// two types must refine into a common type, which replaces `inferredType`.
/// A diagnostic is emitted and failure returned if the constraint is malformed
/// or its type is incompatible. When `allowNonCoreConstraints` is false,
/// user-defined `Constraint`s are rejected.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType,
                                                 bool allowNonCoreConstraints) {
  // The type implied by this particular constraint.
  ast::Type constraintType;
  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    // `Attr` may carry an optional type expression, which must be a `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = ast::AttributeType::get(ctx);
  } else if (const auto *cst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    // `Op` yields an operation type for the constraint's op name, carrying
    // whatever ODS definition is found for that name.
    constraintType = ast::OperationType::get(
        ctx, cst->getName(), lookupODSOperation(cst->getName()));
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    constraintType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    constraintType = typeRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    // `Value` may carry an optional type expression, which must be a `Type`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueTy;
  } else if (const auto *cst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    // `ValueRange` may carry an optional type expression, which must be a
    // `TypeRange`.
    if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
      if (failed(validateTypeRangeConstraintExpr(typeExpr)))
        return failure();
    }
    constraintType = valueRangeTy;
  } else if (const auto *cst =
                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
    // User-defined constraints are only accepted when the caller allows them.
    if (!allowNonCoreConstraints) {
      return emitError(ref.referenceLoc,
                       "`Rewrite` arguments and results are only permitted to "
                       "use core constraints, such as `Attr`, `Op`, `Type`, "
                       "`TypeRange`, `Value`, `ValueRange`");
    }

    // A constraint applied to a variable must take exactly one input. The
    // variable's type is that input's type.
    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
    if (inputs.size() != 1) {
      return emitErrorAndNote(ref.referenceLoc,
                              "`Constraint`s applied via a variable constraint "
                              "list must take a single input, but got " +
                                  Twine(inputs.size()),
                              cst->getLoc(),
                              "see definition of constraint here");
    }
    constraintType = inputs.front()->getType();
  } else {
    // Every constraint kind is expected to be handled above.
    llvm_unreachable("unknown constraint type");
  }

  // Check that the constraint type is compatible with the current inferred
  // type.
  if (!inferredType) {
    inferredType = constraintType;
  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
    inferredType = mergedTy;
  } else {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   constraintType, inferredType));
  }
  return success();
}
2644 
2645 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2646   ast::Type typeExprType = typeExpr->getType();
2647   if (typeExprType != typeTy) {
2648     return emitError(typeExpr->getLoc(),
2649                      "expected expression of `Type` in type constraint");
2650   }
2651   return success();
2652 }
2653 
2654 LogicalResult
2655 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2656   ast::Type typeExprType = typeExpr->getType();
2657   if (typeExprType != typeRangeTy) {
2658     return emitError(typeExpr->getLoc(),
2659                      "expected expression of `TypeRange` in type constraint");
2660   }
2661   return success();
2662 }
2663 
2664 //===----------------------------------------------------------------------===//
2665 // Exprs
2666 
2667 FailureOr<ast::CallExpr *>
2668 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2669                        MutableArrayRef<ast::Expr *> arguments) {
2670   ast::Type parentType = parentExpr->getType();
2671 
2672   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2673   if (!callableDecl) {
2674     return emitError(loc,
2675                      llvm::formatv("expected a reference to a callable "
2676                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2677                                    parentType));
2678   }
2679   if (parserContext == ParserContext::Rewrite) {
2680     if (isa<ast::UserConstraintDecl>(callableDecl))
2681       return emitError(
2682           loc, "unable to invoke `Constraint` within a rewrite section");
2683   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2684     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2685   }
2686 
2687   // Verify the arguments of the call.
2688   /// Handle size mismatch.
2689   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2690   if (callArgs.size() != arguments.size()) {
2691     return emitErrorAndNote(
2692         loc,
2693         llvm::formatv("invalid number of arguments for {0} call; expected "
2694                       "{1}, but got {2}",
2695                       callableDecl->getCallableType(), callArgs.size(),
2696                       arguments.size()),
2697         callableDecl->getLoc(),
2698         llvm::formatv("see the definition of {0} here",
2699                       callableDecl->getName()->getName()));
2700   }
2701 
2702   /// Handle argument type mismatch.
2703   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2704     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2705                                   callableDecl->getName()->getName()),
2706                     callableDecl->getLoc());
2707   };
2708   for (auto it : llvm::zip(callArgs, arguments)) {
2709     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2710                                    attachDiagFn)))
2711       return failure();
2712   }
2713 
2714   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2715                                callableDecl->getResultType());
2716 }
2717 
2718 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2719                                                         ast::Decl *decl) {
2720   // Check the type of decl being referenced.
2721   ast::Type declType;
2722   if (isa<ast::ConstraintDecl>(decl))
2723     declType = ast::ConstraintType::get(ctx);
2724   else if (isa<ast::UserRewriteDecl>(decl))
2725     declType = ast::RewriteType::get(ctx);
2726   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2727     declType = varDecl->getType();
2728   else
2729     return emitError(loc, "invalid reference to `" +
2730                               decl->getName()->getName() + "`");
2731 
2732   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2733 }
2734 
2735 FailureOr<ast::DeclRefExpr *>
2736 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2737                                  ArrayRef<ast::ConstraintRef> constraints) {
2738   FailureOr<ast::VariableDecl *> decl =
2739       defineVariableDecl(name, loc, type, constraints);
2740   if (failed(decl))
2741     return failure();
2742   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2743 }
2744 
2745 FailureOr<ast::MemberAccessExpr *>
2746 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2747                                SMRange loc) {
2748   // Validate the member name for the given parent expression.
2749   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2750   if (failed(memberType))
2751     return failure();
2752 
2753   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2754 }
2755 
2756 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2757                                                   StringRef name, SMRange loc) {
2758   ast::Type parentType = parentExpr->getType();
2759   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2760     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2761       return valueRangeTy;
2762 
2763     // Verify member access based on the operation type.
2764     if (const ods::Operation *odsOp = opType.getODSOperation()) {
2765       auto results = odsOp->getResults();
2766 
2767       // Handle indexed results.
2768       unsigned index = 0;
2769       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2770           index < results.size()) {
2771         return results[index].isVariadic() ? valueRangeTy : valueTy;
2772       }
2773 
2774       // Handle named results.
2775       const auto *it = llvm::find_if(results, [&](const auto &result) {
2776         return result.getName() == name;
2777       });
2778       if (it != results.end())
2779         return it->isVariadic() ? valueRangeTy : valueTy;
2780     } else if (llvm::isDigit(name[0])) {
2781       // Allow unchecked numeric indexing of the results of unregistered
2782       // operations. It returns a single value.
2783       return valueTy;
2784     }
2785   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2786     // Handle indexed results.
2787     unsigned index = 0;
2788     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2789         index < tupleType.size()) {
2790       return tupleType.getElementTypes()[index];
2791     }
2792 
2793     // Handle named results.
2794     auto elementNames = tupleType.getElementNames();
2795     const auto *it = llvm::find(elementNames, name);
2796     if (it != elementNames.end())
2797       return tupleType.getElementTypes()[it - elementNames.begin()];
2798   }
2799   return emitError(
2800       loc,
2801       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2802                     name, parentType));
2803 }
2804 
2805 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2806     SMRange loc, const ast::OpNameDecl *name,
2807     OpResultTypeContext resultTypeContext,
2808     MutableArrayRef<ast::Expr *> operands,
2809     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2810     MutableArrayRef<ast::Expr *> results) {
2811   Optional<StringRef> opNameRef = name->getName();
2812   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2813 
2814   // Verify the inputs operands.
2815   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2816     return failure();
2817 
2818   // Verify the attribute list.
2819   for (ast::NamedAttributeDecl *attr : attributes) {
2820     // Check for an attribute type, or a type awaiting resolution.
2821     ast::Type attrType = attr->getValue()->getType();
2822     if (!attrType.isa<ast::AttributeType>()) {
2823       return emitError(
2824           attr->getValue()->getLoc(),
2825           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2826     }
2827   }
2828 
2829   assert(
2830       (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2831       "unexpected inferrence when results were explicitly specified");
2832 
2833   // If we aren't relying on type inferrence, or explicit results were provided,
2834   // validate them.
2835   if (resultTypeContext == OpResultTypeContext::Explicit) {
2836     if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2837       return failure();
2838 
2839     // Validate the use of interface based type inferrence for this operation.
2840   } else if (resultTypeContext == OpResultTypeContext::Interface) {
2841     assert(opNameRef &&
2842            "expected valid operation name when inferring operation results");
2843     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2844   }
2845 
2846   return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2847                                     attributes);
2848 }
2849 
2850 LogicalResult
2851 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2852                                   const ods::Operation *odsOp,
2853                                   MutableArrayRef<ast::Expr *> operands) {
2854   return validateOperationOperandsOrResults(
2855       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2856       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2857       valueRangeTy);
2858 }
2859 
2860 LogicalResult
2861 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2862                                  const ods::Operation *odsOp,
2863                                  MutableArrayRef<ast::Expr *> results) {
2864   return validateOperationOperandsOrResults(
2865       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2866       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2867 }
2868 
2869 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2870                                                 const ods::Operation *odsOp) {
2871   // If the operation might not have inferrence support, emit a warning to the
2872   // user. We don't emit an error because the interface might be added to the
2873   // operation at runtime. It's rare, but it could still happen. We emit a
2874   // warning here instead.
2875 
2876   // Handle inferrence warnings for unknown operations.
2877   if (!odsOp) {
2878     ctx.getDiagEngine().emitWarning(
2879         loc, llvm::formatv(
2880                  "operation result types are marked to be inferred, but "
2881                  "`{0}` is unknown. Ensure that `{0}` supports zero "
2882                  "results or implements `InferTypeOpInterface`. Include "
2883                  "the ODS definition of this operation to remove this warning.",
2884                  opName));
2885     return;
2886   }
2887 
2888   // Handle inferrence warnings for known operations that expected at least one
2889   // result, but don't have inference support. An elided results list can mean
2890   // "zero-results", and we don't want to warn when that is the expected
2891   // behavior.
2892   bool requiresInferrence =
2893       llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2894         return !result.isVariableLength();
2895       });
2896   if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2897     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2898         loc,
2899         llvm::formatv("operation result types are marked to be inferred, but "
2900                       "`{0}` does not provide an implementation of "
2901                       "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2902                       "`InferTypeOpInterface` at runtime, or add support to "
2903                       "the ODS definition to remove this warning.",
2904                       opName));
2905     diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2906                      odsOp->getLoc());
2907     return;
2908   }
2909 }
2910 
2911 LogicalResult Parser::validateOperationOperandsOrResults(
2912     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2913     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2914     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2915     ast::Type rangeTy) {
2916   // All operation types accept a single range parameter.
2917   if (values.size() == 1) {
2918     if (failed(convertExpressionTo(values[0], rangeTy)))
2919       return failure();
2920     return success();
2921   }
2922 
2923   /// If the operation has ODS information, we can more accurately verify the
2924   /// values.
2925   if (odsOpLoc) {
2926     if (odsValues.size() != values.size()) {
2927       return emitErrorAndNote(
2928           loc,
2929           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2930                         "{2}, but got {3}",
2931                         groupName, *name, odsValues.size(), values.size()),
2932           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2933     }
2934     auto diagFn = [&](ast::Diagnostic &diag) {
2935       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2936                       *odsOpLoc);
2937     };
2938     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2939       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2940       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2941         return failure();
2942     }
2943     return success();
2944   }
2945 
2946   // Otherwise, accept the value groups as they have been defined and just
2947   // ensure they are one of the expected types.
2948   for (ast::Expr *&valueExpr : values) {
2949     ast::Type valueExprType = valueExpr->getType();
2950 
2951     // Check if this is one of the expected types.
2952     if (valueExprType == rangeTy || valueExprType == singleTy)
2953       continue;
2954 
2955     // If the operand is an Operation, allow converting to a Value or
2956     // ValueRange. This situations arises quite often with nested operation
2957     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2958     if (singleTy == valueTy) {
2959       if (valueExprType.isa<ast::OperationType>()) {
2960         valueExpr = convertOpToValue(valueExpr);
2961         continue;
2962       }
2963     }
2964 
2965     return emitError(
2966         valueExpr->getLoc(),
2967         llvm::formatv(
2968             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2969             singleTy, rangeTy, valueExprType));
2970   }
2971   return success();
2972 }
2973 
2974 FailureOr<ast::TupleExpr *>
2975 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2976                         ArrayRef<StringRef> elementNames) {
2977   for (const ast::Expr *element : elements) {
2978     ast::Type eleTy = element->getType();
2979     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2980       return emitError(
2981           element->getLoc(),
2982           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2983     }
2984   }
2985   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2986 }
2987 
2988 //===----------------------------------------------------------------------===//
2989 // Stmts
2990 
2991 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2992                                                     ast::Expr *rootOp) {
2993   // Check that root is an Operation.
2994   ast::Type rootType = rootOp->getType();
2995   if (!rootType.isa<ast::OperationType>())
2996     return emitError(rootOp->getLoc(), "expected `Op` expression");
2997 
2998   return ast::EraseStmt::create(ctx, loc, rootOp);
2999 }
3000 
3001 FailureOr<ast::ReplaceStmt *>
3002 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3003                           MutableArrayRef<ast::Expr *> replValues) {
3004   // Check that root is an Operation.
3005   ast::Type rootType = rootOp->getType();
3006   if (!rootType.isa<ast::OperationType>()) {
3007     return emitError(
3008         rootOp->getLoc(),
3009         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3010   }
3011 
3012   // If there are multiple replacement values, we implicitly convert any Op
3013   // expressions to the value form.
3014   bool shouldConvertOpToValues = replValues.size() > 1;
3015   for (ast::Expr *&replExpr : replValues) {
3016     ast::Type replType = replExpr->getType();
3017 
3018     // Check that replExpr is an Operation, Value, or ValueRange.
3019     if (replType.isa<ast::OperationType>()) {
3020       if (shouldConvertOpToValues)
3021         replExpr = convertOpToValue(replExpr);
3022       continue;
3023     }
3024 
3025     if (replType != valueTy && replType != valueRangeTy) {
3026       return emitError(replExpr->getLoc(),
3027                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3028                                      "expression, but got `{0}`",
3029                                      replType));
3030     }
3031   }
3032 
3033   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3034 }
3035 
3036 FailureOr<ast::RewriteStmt *>
3037 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3038                           ast::CompoundStmt *rewriteBody) {
3039   // Check that root is an Operation.
3040   ast::Type rootType = rootOp->getType();
3041   if (!rootType.isa<ast::OperationType>()) {
3042     return emitError(
3043         rootOp->getLoc(),
3044         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3045   }
3046 
3047   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3048 }
3049 
3050 //===----------------------------------------------------------------------===//
3051 // Code Completion
3052 //===----------------------------------------------------------------------===//
3053 
3054 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3055   ast::Type parentType = parentExpr->getType();
3056   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3057     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3058   else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3059     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3060   return failure();
3061 }
3062 
3063 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3064   if (opName)
3065     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3066   return failure();
3067 }
3068 
3069 LogicalResult
3070 Parser::codeCompleteConstraintName(ast::Type inferredType,
3071                                    bool allowNonCoreConstraints,
3072                                    bool allowInlineTypeConstraints) {
3073   codeCompleteContext->codeCompleteConstraintName(
3074       inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3075       curDeclScope);
3076   return failure();
3077 }
3078 
3079 LogicalResult Parser::codeCompleteDialectName() {
3080   codeCompleteContext->codeCompleteDialectName();
3081   return failure();
3082 }
3083 
3084 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3085   codeCompleteContext->codeCompleteOperationName(dialectName);
3086   return failure();
3087 }
3088 
3089 LogicalResult Parser::codeCompletePatternMetadata() {
3090   codeCompleteContext->codeCompletePatternMetadata();
3091   return failure();
3092 }
3093 
3094 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3095   codeCompleteContext->codeCompleteIncludeFilename(curPath);
3096   return failure();
3097 }
3098 
3099 void Parser::codeCompleteCallSignature(ast::Node *parent,
3100                                        unsigned currentNumArgs) {
3101   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3102   if (!callableDecl)
3103     return;
3104 
3105   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3106 }
3107 
3108 void Parser::codeCompleteOperationOperandsSignature(
3109     Optional<StringRef> opName, unsigned currentNumOperands) {
3110   codeCompleteContext->codeCompleteOperationOperandsSignature(
3111       opName, currentNumOperands);
3112 }
3113 
3114 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3115                                                    unsigned currentNumResults) {
3116   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3117                                                              currentNumResults);
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // Parser
3122 //===----------------------------------------------------------------------===//
3123 
3124 FailureOr<ast::Module *>
3125 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3126                          bool enableDocumentation,
3127                          CodeCompleteContext *codeCompleteContext) {
3128   Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3129   return parser.parseModule();
3130 }
3131