1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/TableGen/Argument.h"
13 #include "mlir/TableGen/Attribute.h"
14 #include "mlir/TableGen/Constraint.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/AST/Types.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Operation.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ManagedStatic.h"
28 #include "llvm/Support/SaveAndRestore.h"
29 #include "llvm/Support/ScopedPrinter.h"
30 #include "llvm/TableGen/Error.h"
31 #include "llvm/TableGen/Parser.h"
32 #include <string>
33 
34 using namespace mlir;
35 using namespace mlir::pdll;
36 
37 //===----------------------------------------------------------------------===//
38 // Parser
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 class Parser {
43 public:
44   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
45       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
46         curToken(lexer.lexToken()), curDeclScope(nullptr),
47         valueTy(ast::ValueType::get(ctx)),
48         valueRangeTy(ast::ValueRangeType::get(ctx)),
49         typeTy(ast::TypeType::get(ctx)),
50         typeRangeTy(ast::TypeRangeType::get(ctx)),
51         attrTy(ast::AttributeType::get(ctx)) {}
52 
53   /// Try to parse a new module. Returns nullptr in the case of failure.
54   FailureOr<ast::Module *> parseModule();
55 
56 private:
57   /// The current context of the parser. It allows for the parser to know a bit
58   /// about the construct it is nested within during parsing. This is used
59   /// specifically to provide additional verification during parsing, e.g. to
60   /// prevent using rewrites within a match context, matcher constraints within
61   /// a rewrite section, etc.
62   enum class ParserContext {
63     /// The parser is in the global context.
64     Global,
65     /// The parser is currently within a Constraint, which disallows all types
66     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
67     Constraint,
68     /// The parser is currently within the matcher portion of a Pattern, which
69     /// is allows a terminal operation rewrite statement but no other rewrite
70     /// transformations.
71     PatternMatch,
72     /// The parser is currently within a Rewrite, which disallows calls to
73     /// constraints, requires operation expressions to have names, etc.
74     Rewrite,
75   };
76 
77   //===--------------------------------------------------------------------===//
78   // Parsing
79   //===--------------------------------------------------------------------===//
80 
81   /// Push a new decl scope onto the lexer.
82   ast::DeclScope *pushDeclScope() {
83     ast::DeclScope *newScope =
84         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
85     return (curDeclScope = newScope);
86   }
87   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
88 
89   /// Pop the last decl scope from the lexer.
90   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
91 
92   /// Parse the body of an AST module.
93   LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
94 
95   /// Try to convert the given expression to `type`. Returns failure and emits
96   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
97   /// invoked to attach notes to the emitted error diagnostic. On success,
98   /// `expr` is updated to the expression used to convert to `type`.
99   LogicalResult convertExpressionTo(
100       ast::Expr *&expr, ast::Type type,
101       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
102 
103   /// Given an operation expression, convert it to a Value or ValueRange
104   /// typed expression.
105   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
106 
107   /// Lookup ODS information for the given operation, returns nullptr if no
108   /// information is found.
109   const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
110     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
111   }
112 
113   //===--------------------------------------------------------------------===//
114   // Directives
115 
116   LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
117   LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
118   LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
119                                SmallVectorImpl<ast::Decl *> &decls);
120 
121   /// Process the records of a parsed tablegen include file.
122   void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
123                                SmallVectorImpl<ast::Decl *> &decls);
124 
125   /// Create a user defined native constraint for a constraint imported from
126   /// ODS.
127   template <typename ConstraintT>
128   ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
129                                                StringRef codeBlock, SMRange loc,
130                                                ast::Type type);
131   template <typename ConstraintT>
132   ast::Decl *
133   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
134                                     SMRange loc, ast::Type type);
135 
136   //===--------------------------------------------------------------------===//
137   // Decls
138 
139   /// This structure contains the set of pattern metadata that may be parsed.
140   struct ParsedPatternMetadata {
141     Optional<uint16_t> benefit;
142     bool hasBoundedRecursion = false;
143   };
144 
145   FailureOr<ast::Decl *> parseTopLevelDecl();
146   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
147 
148   /// Parse an argument variable as part of the signature of a
149   /// UserConstraintDecl or UserRewriteDecl.
150   FailureOr<ast::VariableDecl *> parseArgumentDecl();
151 
152   /// Parse a result variable as part of the signature of a UserConstraintDecl
153   /// or UserRewriteDecl.
154   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
155 
156   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
157   /// defined in a non-global context.
158   FailureOr<ast::UserConstraintDecl *>
159   parseUserConstraintDecl(bool isInline = false);
160 
161   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
162   /// non-global context, such as within a Pattern/Constraint/etc.
163   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
164 
165   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
166   /// PDLL constructs.
167   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
168       const ast::Name &name, bool isInline,
169       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
170       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
171 
172   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
173   /// defined in a non-global context.
174   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
175 
176   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
177   /// non-global context, such as within a Pattern/Rewrite/etc.
178   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
179 
180   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
181   /// PDLL constructs.
182   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
183       const ast::Name &name, bool isInline,
184       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
185       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
186 
187   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
188   /// effectively the same syntax, and only differ on slight semantics (given
189   /// the different parsing contexts).
190   template <typename T, typename ParseUserPDLLDeclFnT>
191   FailureOr<T *> parseUserConstraintOrRewriteDecl(
192       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
193       StringRef anonymousNamePrefix, bool isInline);
194 
195   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
196   /// These decls have effectively the same syntax.
197   template <typename T>
198   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
199       const ast::Name &name, bool isInline,
200       ArrayRef<ast::VariableDecl *> arguments,
201       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
202 
203   /// Parse the functional signature (i.e. the arguments and results) of a
204   /// UserConstraintDecl or UserRewriteDecl.
205   LogicalResult parseUserConstraintOrRewriteSignature(
206       SmallVectorImpl<ast::VariableDecl *> &arguments,
207       SmallVectorImpl<ast::VariableDecl *> &results,
208       ast::DeclScope *&argumentScope, ast::Type &resultType);
209 
210   /// Validate the return (which if present is specified by bodyIt) of a
211   /// UserConstraintDecl or UserRewriteDecl.
212   LogicalResult validateUserConstraintOrRewriteReturn(
213       StringRef declType, ast::CompoundStmt *body,
214       ArrayRef<ast::Stmt *>::iterator bodyIt,
215       ArrayRef<ast::Stmt *>::iterator bodyE,
216       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
217 
218   FailureOr<ast::CompoundStmt *>
219   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
220                   bool expectTerminalSemicolon = true);
221   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
222   FailureOr<ast::Decl *> parsePatternDecl();
223   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
224 
225   /// Check to see if a decl has already been defined with the given name, if
226   /// one has emit and error and return failure. Returns success otherwise.
227   LogicalResult checkDefineNamedDecl(const ast::Name &name);
228 
229   /// Try to define a variable decl with the given components, returns the
230   /// variable on success.
231   FailureOr<ast::VariableDecl *>
232   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
233                      ast::Expr *initExpr,
234                      ArrayRef<ast::ConstraintRef> constraints);
235   FailureOr<ast::VariableDecl *>
236   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
237                      ArrayRef<ast::ConstraintRef> constraints);
238 
239   /// Parse the constraint reference list for a variable decl.
240   LogicalResult parseVariableDeclConstraintList(
241       SmallVectorImpl<ast::ConstraintRef> &constraints);
242 
243   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
244   FailureOr<ast::Expr *> parseTypeConstraintExpr();
245 
246   /// Try to parse a single reference to a constraint. `typeConstraint` is the
247   /// location of a previously parsed type constraint for the entity that will
248   /// be constrained by the parsed constraint. `existingConstraints` are any
249   /// existing constraints that have already been parsed for the same entity
250   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
251   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
252   FailureOr<ast::ConstraintRef>
253   parseConstraint(Optional<SMRange> &typeConstraint,
254                   ArrayRef<ast::ConstraintRef> existingConstraints,
255                   bool allowInlineTypeConstraints);
256 
257   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
258   /// argument or result variable. The constraints for these variables do not
259   /// allow inline type constraints, and only permit a single constraint.
260   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
261 
262   //===--------------------------------------------------------------------===//
263   // Exprs
264 
265   FailureOr<ast::Expr *> parseExpr();
266 
267   /// Identifier expressions.
268   FailureOr<ast::Expr *> parseAttributeExpr();
269   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
270   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
271   FailureOr<ast::Expr *> parseIdentifierExpr();
272   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
273   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
274   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
275   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
276   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
277   FailureOr<ast::Expr *> parseOperationExpr();
278   FailureOr<ast::Expr *> parseTupleExpr();
279   FailureOr<ast::Expr *> parseTypeExpr();
280   FailureOr<ast::Expr *> parseUnderscoreExpr();
281 
282   //===--------------------------------------------------------------------===//
283   // Stmts
284 
285   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
286   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
287   FailureOr<ast::EraseStmt *> parseEraseStmt();
288   FailureOr<ast::LetStmt *> parseLetStmt();
289   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
290   FailureOr<ast::ReturnStmt *> parseReturnStmt();
291   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
292 
293   //===--------------------------------------------------------------------===//
294   // Creation+Analysis
295   //===--------------------------------------------------------------------===//
296 
297   //===--------------------------------------------------------------------===//
298   // Decls
299 
300   /// Try to extract a callable from the given AST node. Returns nullptr on
301   /// failure.
302   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
303 
304   /// Try to create a pattern decl with the given components, returning the
305   /// Pattern on success.
306   FailureOr<ast::PatternDecl *>
307   createPatternDecl(SMRange loc, const ast::Name *name,
308                     const ParsedPatternMetadata &metadata,
309                     ast::CompoundStmt *body);
310 
311   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
312   /// of results, defined as part of the signature.
313   ast::Type
314   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
315 
316   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
317   template <typename T>
318   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
319       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
320       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
321       ast::CompoundStmt *body);
322 
323   /// Try to create a variable decl with the given components, returning the
324   /// Variable on success.
325   FailureOr<ast::VariableDecl *>
326   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
327                      ArrayRef<ast::ConstraintRef> constraints);
328 
329   /// Create a variable for an argument or result defined as part of the
330   /// signature of a UserConstraintDecl/UserRewriteDecl.
331   FailureOr<ast::VariableDecl *>
332   createArgOrResultVariableDecl(StringRef name, SMRange loc,
333                                 const ast::ConstraintRef &constraint);
334 
335   /// Validate the constraints used to constraint a variable decl.
336   /// `inferredType` is the type of the variable inferred by the constraints
337   /// within the list, and is updated to the most refined type as determined by
338   /// the constraints. Returns success if the constraint list is valid, failure
339   /// otherwise.
340   LogicalResult
341   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
342                               ast::Type &inferredType);
343   /// Validate a single reference to a constraint. `inferredType` contains the
344   /// currently inferred variabled type and is refined within the type defined
345   /// by the constraint. Returns success if the constraint is valid, failure
346   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
347   /// defined constraints) may be used with the variable.
348   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
349                                            ast::Type &inferredType,
350                                            bool allowNonCoreConstraints = true);
351   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
352   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
353 
354   //===--------------------------------------------------------------------===//
355   // Exprs
356 
357   FailureOr<ast::CallExpr *>
358   createCallExpr(SMRange loc, ast::Expr *parentExpr,
359                  MutableArrayRef<ast::Expr *> arguments);
360   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
361   FailureOr<ast::DeclRefExpr *>
362   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
363                            ArrayRef<ast::ConstraintRef> constraints);
364   FailureOr<ast::MemberAccessExpr *>
365   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
366 
367   /// Validate the member access `name` into the given parent expression. On
368   /// success, this also returns the type of the member accessed.
369   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
370                                             StringRef name, SMRange loc);
371   FailureOr<ast::OperationExpr *>
372   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
373                       MutableArrayRef<ast::Expr *> operands,
374                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
375                       MutableArrayRef<ast::Expr *> results);
376   LogicalResult
377   validateOperationOperands(SMRange loc, Optional<StringRef> name,
378                             const ods::Operation *odsOp,
379                             MutableArrayRef<ast::Expr *> operands);
380   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
381                                          const ods::Operation *odsOp,
382                                          MutableArrayRef<ast::Expr *> results);
383   LogicalResult validateOperationOperandsOrResults(
384       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
385       Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
386       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
387       ast::Type rangeTy);
388   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
389                                               ArrayRef<ast::Expr *> elements,
390                                               ArrayRef<StringRef> elementNames);
391 
392   //===--------------------------------------------------------------------===//
393   // Stmts
394 
395   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
396   FailureOr<ast::ReplaceStmt *>
397   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
398                     MutableArrayRef<ast::Expr *> replValues);
399   FailureOr<ast::RewriteStmt *>
400   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
401                     ast::CompoundStmt *rewriteBody);
402 
403   //===--------------------------------------------------------------------===//
404   // Lexer Utilities
405   //===--------------------------------------------------------------------===//
406 
407   /// If the current token has the specified kind, consume it and return true.
408   /// If not, return false.
409   bool consumeIf(Token::Kind kind) {
410     if (curToken.isNot(kind))
411       return false;
412     consumeToken(kind);
413     return true;
414   }
415 
416   /// Advance the current lexer onto the next token.
417   void consumeToken() {
418     assert(curToken.isNot(Token::eof, Token::error) &&
419            "shouldn't advance past EOF or errors");
420     curToken = lexer.lexToken();
421   }
422 
423   /// Advance the current lexer onto the next token, asserting what the expected
424   /// current token is. This is preferred to the above method because it leads
425   /// to more self-documenting code with better checking.
426   void consumeToken(Token::Kind kind) {
427     assert(curToken.is(kind) && "consumed an unexpected token");
428     consumeToken();
429   }
430 
431   /// Reset the lexer to the location at the given position.
432   void resetToken(SMRange tokLoc) {
433     lexer.resetPointer(tokLoc.Start.getPointer());
434     curToken = lexer.lexToken();
435   }
436 
437   /// Consume the specified token if present and return success. On failure,
438   /// output a diagnostic and return failure.
439   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
440     if (curToken.getKind() != kind)
441       return emitError(curToken.getLoc(), msg);
442     consumeToken();
443     return success();
444   }
445   LogicalResult emitError(SMRange loc, const Twine &msg) {
446     lexer.emitError(loc, msg);
447     return failure();
448   }
449   LogicalResult emitError(const Twine &msg) {
450     return emitError(curToken.getLoc(), msg);
451   }
452   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
453                                  const Twine &note) {
454     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
455     return failure();
456   }
457 
458   //===--------------------------------------------------------------------===//
459   // Fields
460   //===--------------------------------------------------------------------===//
461 
462   /// The owning AST context.
463   ast::Context &ctx;
464 
465   /// The lexer of this parser.
466   Lexer lexer;
467 
468   /// The current token within the lexer.
469   Token curToken;
470 
471   /// The most recently defined decl scope.
472   ast::DeclScope *curDeclScope;
473   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
474 
475   /// The current context of the parser.
476   ParserContext parserContext = ParserContext::Global;
477 
478   /// Cached types to simplify verification and expression creation.
479   ast::Type valueTy, valueRangeTy;
480   ast::Type typeTy, typeRangeTy;
481   ast::Type attrTy;
482 
483   /// A counter used when naming anonymous constraints and rewrites.
484   unsigned anonymousDeclNameCounter = 0;
485 };
486 } // namespace
487 
488 FailureOr<ast::Module *> Parser::parseModule() {
489   SMLoc moduleLoc = curToken.getStartLoc();
490   pushDeclScope();
491 
492   // Parse the top-level decls of the module.
493   SmallVector<ast::Decl *> decls;
494   if (failed(parseModuleBody(decls)))
495     return popDeclScope(), failure();
496 
497   popDeclScope();
498   return ast::Module::create(ctx, moduleLoc, decls);
499 }
500 
501 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
502   while (curToken.isNot(Token::eof)) {
503     if (curToken.is(Token::directive)) {
504       if (failed(parseDirective(decls)))
505         return failure();
506       continue;
507     }
508 
509     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
510     if (failed(decl))
511       return failure();
512     decls.push_back(*decl);
513   }
514   return success();
515 }
516 
517 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
518   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
519                                                  valueRangeTy);
520 }
521 
522 LogicalResult Parser::convertExpressionTo(
523     ast::Expr *&expr, ast::Type type,
524     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
525   ast::Type exprType = expr->getType();
526   if (exprType == type)
527     return success();
528 
529   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
530     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
531         expr->getLoc(), llvm::formatv("unable to convert expression of type "
532                                       "`{0}` to the expected type of "
533                                       "`{1}`",
534                                       exprType, type));
535     if (noteAttachFn)
536       noteAttachFn(*diag);
537     return diag;
538   };
539 
540   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
541     // Two operation types are compatible if they have the same name, or if the
542     // expected type is more general.
543     if (auto opType = type.dyn_cast<ast::OperationType>()) {
544       if (opType.getName())
545         return emitConvertError();
546       return success();
547     }
548 
549     // An operation can always convert to a ValueRange.
550     if (type == valueRangeTy) {
551       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
552                                                      valueRangeTy);
553       return success();
554     }
555 
556     // Allow conversion to a single value by constraining the result range.
557     if (type == valueTy) {
558       // If the operation is registered, we can verify if it can ever have a
559       // single result.
560       Optional<StringRef> opName = exprOpType.getName();
561       if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
562         if (odsOp->getResults().empty()) {
563           return emitConvertError()->attachNote(
564               llvm::formatv("see the definition of `{0}`, which was defined "
565                             "with zero results",
566                             odsOp->getName()),
567               odsOp->getLoc());
568         }
569 
570         unsigned numSingleResults = llvm::count_if(
571             odsOp->getResults(), [](const ods::OperandOrResult &result) {
572               return result.getVariableLengthKind() ==
573                      ods::VariableLengthKind::Single;
574             });
575         if (numSingleResults > 1) {
576           return emitConvertError()->attachNote(
577               llvm::formatv("see the definition of `{0}`, which was defined "
578                             "with at least {1} results",
579                             odsOp->getName(), numSingleResults),
580               odsOp->getLoc());
581         }
582       }
583 
584       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
585                                                      valueTy);
586       return success();
587     }
588     return emitConvertError();
589   }
590 
591   // FIXME: Decide how to allow/support converting a single result to multiple,
592   // and multiple to a single result. For now, we just allow Single->Range,
593   // but this isn't something really supported in the PDL dialect. We should
594   // figure out some way to support both.
595   if ((exprType == valueTy || exprType == valueRangeTy) &&
596       (type == valueTy || type == valueRangeTy))
597     return success();
598   if ((exprType == typeTy || exprType == typeRangeTy) &&
599       (type == typeTy || type == typeRangeTy))
600     return success();
601 
602   // Handle tuple types.
603   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
604     auto tupleType = type.dyn_cast<ast::TupleType>();
605     if (!tupleType || tupleType.size() != exprTupleType.size())
606       return emitConvertError();
607 
608     // Build a new tuple expression using each of the elements of the current
609     // tuple.
610     SmallVector<ast::Expr *> newExprs;
611     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
612       newExprs.push_back(ast::MemberAccessExpr::create(
613           ctx, expr->getLoc(), expr, llvm::to_string(i),
614           exprTupleType.getElementTypes()[i]));
615 
616       auto diagFn = [&](ast::Diagnostic &diag) {
617         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
618                                       i, exprTupleType));
619         if (noteAttachFn)
620           noteAttachFn(diag);
621       };
622       if (failed(convertExpressionTo(newExprs.back(),
623                                      tupleType.getElementTypes()[i], diagFn)))
624         return failure();
625     }
626     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
627                                   tupleType.getElementNames());
628     return success();
629   }
630 
631   return emitConvertError();
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // Directives
636 
637 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
638   StringRef directive = curToken.getSpelling();
639   if (directive == "#include")
640     return parseInclude(decls);
641 
642   return emitError("unknown directive `" + directive + "`");
643 }
644 
645 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
646   SMRange loc = curToken.getLoc();
647   consumeToken(Token::directive);
648 
649   // Parse the file being included.
650   if (!curToken.isString())
651     return emitError(loc,
652                      "expected string file name after `include` directive");
653   SMRange fileLoc = curToken.getLoc();
654   std::string filenameStr = curToken.getStringValue();
655   StringRef filename = filenameStr;
656   consumeToken();
657 
658   // Check the type of include. If ending with `.pdll`, this is another pdl file
659   // to be parsed along with the current module.
660   if (filename.endswith(".pdll")) {
661     if (failed(lexer.pushInclude(filename)))
662       return emitError(fileLoc,
663                        "unable to open include file `" + filename + "`");
664 
665     // If we added the include successfully, parse it into the current module.
666     // Make sure to save the current token so that we can restore it when we
667     // finish parsing the nested file.
668     Token oldToken = curToken;
669     curToken = lexer.lexToken();
670     LogicalResult result = parseModuleBody(decls);
671     curToken = oldToken;
672     return result;
673   }
674 
675   // Otherwise, this must be a `.td` include.
676   if (filename.endswith(".td"))
677     return parseTdInclude(filename, fileLoc, decls);
678 
679   return emitError(fileLoc,
680                    "expected include filename to end with `.pdll` or `.td`");
681 }
682 
683 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
684                                      SmallVectorImpl<ast::Decl *> &decls) {
685   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
686 
687   // This class provides a context argument for the llvm::SourceMgr diagnostic
688   // handler.
689   struct DiagHandlerContext {
690     Parser &parser;
691     StringRef filename;
692     llvm::SMRange loc;
693   } handlerContext{*this, filename, fileLoc};
694 
695   // Set the diagnostic handler for the tablegen source manager.
696   llvm::SrcMgr.setDiagHandler(
697       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
698         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
699         (void)ctx->parser.emitError(
700             ctx->loc,
701             llvm::formatv("error while processing include file `{0}`: {1}",
702                           ctx->filename, diag.getMessage()));
703       },
704       &handlerContext);
705 
706   // Use the source manager to open the file, but don't yet add it.
707   std::string includedFile;
708   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
709       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
710   if (!includeBuffer)
711     return emitError(fileLoc, "unable to open include file `" + filename + "`");
712 
713   auto processFn = [&](llvm::RecordKeeper &records) {
714     processTdIncludeRecords(records, decls);
715 
716     // After we are done processing, move all of the tablegen source buffers to
717     // the main parser source mgr. This allows for directly using source
718     // locations from the .td files without needing to remap them.
719     parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
720     return false;
721   };
722   if (llvm::TableGenParseFile(std::move(*includeBuffer),
723                               parserSrcMgr.getIncludeDirs(), processFn))
724     return failure();
725 
726   return success();
727 }
728 
729 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
730                                      SmallVectorImpl<ast::Decl *> &decls) {
731   // Return the length kind of the given value.
732   auto getLengthKind = [](const auto &value) {
733     if (value.isOptional())
734       return ods::VariableLengthKind::Optional;
735     return value.isVariadic() ? ods::VariableLengthKind::Variadic
736                               : ods::VariableLengthKind::Single;
737   };
738 
739   // Insert a type constraint into the ODS context.
740   ods::Context &odsContext = ctx.getODSContext();
741   auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
742       -> const ods::TypeConstraint & {
743     return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
744                                            cst.constraint.getSummary(),
745                                            cst.constraint.getCPPClassName());
746   };
747   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
748     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
749   };
750 
751   // Process the parsed tablegen records to build ODS information.
752   /// Operations.
753   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
754     tblgen::Operator op(def);
755 
756     bool inserted = false;
757     ods::Operation *odsOp = nullptr;
758     std::tie(odsOp, inserted) =
759         odsContext.insertOperation(op.getOperationName(), op.getSummary(),
760                                    op.getDescription(), op.getLoc().front());
761 
762     // Ignore operations that have already been added.
763     if (!inserted)
764       continue;
765 
766     for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
767       odsOp->appendAttribute(
768           attr.name, attr.attr.isOptional(),
769           odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
770                                                attr.attr.getSummary(),
771                                                attr.attr.getStorageType()));
772     }
773     for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
774       odsOp->appendOperand(operand.name, getLengthKind(operand),
775                            addTypeConstraint(operand));
776     }
777     for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
778       odsOp->appendResult(result.name, getLengthKind(result),
779                           addTypeConstraint(result));
780     }
781   }
782   /// Attr constraints.
783   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
784     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
785       decls.push_back(
786           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
787               tblgen::AttrConstraint(def),
788               convertLocToRange(def->getLoc().front()), attrTy));
789     }
790   }
791   /// Type constraints.
792   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
793     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
794       decls.push_back(
795           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
796               tblgen::TypeConstraint(def),
797               convertLocToRange(def->getLoc().front()), typeTy));
798     }
799   }
800   /// Interfaces.
801   ast::Type opTy = ast::OperationType::get(ctx);
802   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
803     StringRef name = def->getName();
804     if (def->isAnonymous() || curDeclScope->lookup(name) ||
805         def->isSubClassOf("DeclareInterfaceMethods"))
806       continue;
807     SMRange loc = convertLocToRange(def->getLoc().front());
808 
809     StringRef className = def->getValueAsString("cppClassName");
810     StringRef cppNamespace = def->getValueAsString("cppNamespace");
811     std::string codeBlock =
812         llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
813             .str();
814 
815     if (def->isSubClassOf("OpInterface")) {
816       decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
817           name, codeBlock, loc, opTy));
818     } else if (def->isSubClassOf("AttrInterface")) {
819       decls.push_back(
820           createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
821               name, codeBlock, loc, attrTy));
822     } else if (def->isSubClassOf("TypeInterface")) {
823       decls.push_back(
824           createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
825               name, codeBlock, loc, typeTy));
826     }
827   }
828 }
829 
830 template <typename ConstraintT>
831 ast::Decl *
832 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
833                                           SMRange loc, ast::Type type) {
834   // Build the single input parameter.
835   ast::DeclScope *argScope = pushDeclScope();
836   auto *paramVar = ast::VariableDecl::create(
837       ctx, ast::Name::create(ctx, "self", loc), type,
838       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
839   argScope->add(paramVar);
840   popDeclScope();
841 
842   // Build the native constraint.
843   auto *constraintDecl = ast::UserConstraintDecl::createNative(
844       ctx, ast::Name::create(ctx, name, loc), paramVar,
845       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
846   curDeclScope->add(constraintDecl);
847   return constraintDecl;
848 }
849 
850 template <typename ConstraintT>
851 ast::Decl *
852 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
853                                           SMRange loc, ast::Type type) {
854   // Format the condition template.
855   tblgen::FmtContext fmtContext;
856   fmtContext.withSelf("self");
857   std::string codeBlock =
858       tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);
859 
860   return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
861                                                         codeBlock, loc, type);
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // Decls
866 
867 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
868   FailureOr<ast::Decl *> decl;
869   switch (curToken.getKind()) {
870   case Token::kw_Constraint:
871     decl = parseUserConstraintDecl();
872     break;
873   case Token::kw_Pattern:
874     decl = parsePatternDecl();
875     break;
876   case Token::kw_Rewrite:
877     decl = parseUserRewriteDecl();
878     break;
879   default:
880     return emitError("expected top-level declaration, such as a `Pattern`");
881   }
882   if (failed(decl))
883     return failure();
884 
885   // If the decl has a name, add it to the current scope.
886   if (const ast::Name *name = (*decl)->getName()) {
887     if (failed(checkDefineNamedDecl(*name)))
888       return failure();
889     curDeclScope->add(*decl);
890   }
891   return decl;
892 }
893 
894 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
895   std::string attrNameStr;
896   if (curToken.isString())
897     attrNameStr = curToken.getStringValue();
898   else if (curToken.is(Token::identifier) || curToken.isKeyword())
899     attrNameStr = curToken.getSpelling().str();
900   else
901     return emitError("expected identifier or string attribute name");
902   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
903   consumeToken();
904 
905   // Check for a value of the attribute.
906   ast::Expr *attrValue = nullptr;
907   if (consumeIf(Token::equal)) {
908     FailureOr<ast::Expr *> attrExpr = parseExpr();
909     if (failed(attrExpr))
910       return failure();
911     attrValue = *attrExpr;
912   } else {
913     // If there isn't a concrete value, create an expression representing a
914     // UnitAttr.
915     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
916   }
917 
918   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
919 }
920 
921 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
922     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
923     bool expectTerminalSemicolon) {
924   consumeToken(Token::equal_arrow);
925 
926   // Parse the single statement of the lambda body.
927   SMLoc bodyStartLoc = curToken.getStartLoc();
928   pushDeclScope();
929   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
930   bool failedToParse =
931       failed(singleStatement) || failed(processStatementFn(*singleStatement));
932   popDeclScope();
933   if (failedToParse)
934     return failure();
935 
936   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
937   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
938 }
939 
940 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
941   // Ensure that the argument is named.
942   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
943     return emitError("expected identifier argument name");
944 
945   // Parse the argument similarly to a normal variable.
946   StringRef name = curToken.getSpelling();
947   SMRange nameLoc = curToken.getLoc();
948   consumeToken();
949 
950   if (failed(
951           parseToken(Token::colon, "expected `:` before argument constraint")))
952     return failure();
953 
954   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
955   if (failed(cst))
956     return failure();
957 
958   return createArgOrResultVariableDecl(name, nameLoc, *cst);
959 }
960 
961 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
962   // Check to see if this result is named.
963   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
964     // Check to see if this name actually refers to a Constraint.
965     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
966     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
967       // If yes, and this is a Rewrite, give a nice error message as non-Core
968       // constraints are not supported on Rewrite results.
969       if (parserContext == ParserContext::Rewrite) {
970         return emitError(
971             "`Rewrite` results are only permitted to use core constraints, "
972             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
973       }
974 
975       // Otherwise, parse this as an unnamed result variable.
976     } else {
977       // If it wasn't a constraint, parse the result similarly to a variable. If
978       // there is already an existing decl, we will emit an error when defining
979       // this variable later.
980       StringRef name = curToken.getSpelling();
981       SMRange nameLoc = curToken.getLoc();
982       consumeToken();
983 
984       if (failed(parseToken(Token::colon,
985                             "expected `:` before result constraint")))
986         return failure();
987 
988       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
989       if (failed(cst))
990         return failure();
991 
992       return createArgOrResultVariableDecl(name, nameLoc, *cst);
993     }
994   }
995 
996   // If it isn't named, we parse the constraint directly and create an unnamed
997   // result variable.
998   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
999   if (failed(cst))
1000     return failure();
1001 
1002   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1003 }
1004 
1005 FailureOr<ast::UserConstraintDecl *>
1006 Parser::parseUserConstraintDecl(bool isInline) {
1007   // Constraints and rewrites have very similar formats, dispatch to a shared
1008   // interface for parsing.
1009   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1010       [&](auto &&...args) {
1011         return this->parseUserPDLLConstraintDecl(args...);
1012       },
1013       ParserContext::Constraint, "constraint", isInline);
1014 }
1015 
1016 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1017   FailureOr<ast::UserConstraintDecl *> decl =
1018       parseUserConstraintDecl(/*isInline=*/true);
1019   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1020     return failure();
1021 
1022   curDeclScope->add(*decl);
1023   return decl;
1024 }
1025 
1026 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1027     const ast::Name &name, bool isInline,
1028     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1029     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1030   // Push the argument scope back onto the list, so that the body can
1031   // reference arguments.
1032   pushDeclScope(argumentScope);
1033 
1034   // Parse the body of the constraint. The body is either defined as a compound
1035   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1036   ast::CompoundStmt *body;
1037   if (curToken.is(Token::equal_arrow)) {
1038     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1039         [&](ast::Stmt *&stmt) -> LogicalResult {
1040           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1041           if (!stmtExpr) {
1042             return emitError(stmt->getLoc(),
1043                              "expected `Constraint` lambda body to contain a "
1044                              "single expression");
1045           }
1046           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1047           return success();
1048         },
1049         /*expectTerminalSemicolon=*/!isInline);
1050     if (failed(bodyResult))
1051       return failure();
1052     body = *bodyResult;
1053   } else {
1054     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1055     if (failed(bodyResult))
1056       return failure();
1057     body = *bodyResult;
1058 
1059     // Verify the structure of the body.
1060     auto bodyIt = body->begin(), bodyE = body->end();
1061     for (; bodyIt != bodyE; ++bodyIt)
1062       if (isa<ast::ReturnStmt>(*bodyIt))
1063         break;
1064     if (failed(validateUserConstraintOrRewriteReturn(
1065             "Constraint", body, bodyIt, bodyE, results, resultType)))
1066       return failure();
1067   }
1068   popDeclScope();
1069 
1070   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1071       name, arguments, results, resultType, body);
1072 }
1073 
1074 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1075   // Constraints and rewrites have very similar formats, dispatch to a shared
1076   // interface for parsing.
1077   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1078       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1079       ParserContext::Rewrite, "rewrite", isInline);
1080 }
1081 
1082 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1083   FailureOr<ast::UserRewriteDecl *> decl =
1084       parseUserRewriteDecl(/*isInline=*/true);
1085   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1086     return failure();
1087 
1088   curDeclScope->add(*decl);
1089   return decl;
1090 }
1091 
1092 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1093     const ast::Name &name, bool isInline,
1094     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1095     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1096   // Push the argument scope back onto the list, so that the body can
1097   // reference arguments.
1098   curDeclScope = argumentScope;
1099   ast::CompoundStmt *body;
1100   if (curToken.is(Token::equal_arrow)) {
1101     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1102         [&](ast::Stmt *&statement) -> LogicalResult {
1103           if (isa<ast::OpRewriteStmt>(statement))
1104             return success();
1105 
1106           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1107           if (!statementExpr) {
1108             return emitError(
1109                 statement->getLoc(),
1110                 "expected `Rewrite` lambda body to contain a single expression "
1111                 "or an operation rewrite statement; such as `erase`, "
1112                 "`replace`, or `rewrite`");
1113           }
1114           statement =
1115               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1116           return success();
1117         },
1118         /*expectTerminalSemicolon=*/!isInline);
1119     if (failed(bodyResult))
1120       return failure();
1121     body = *bodyResult;
1122   } else {
1123     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1124     if (failed(bodyResult))
1125       return failure();
1126     body = *bodyResult;
1127   }
1128   popDeclScope();
1129 
1130   // Verify the structure of the body.
1131   auto bodyIt = body->begin(), bodyE = body->end();
1132   for (; bodyIt != bodyE; ++bodyIt)
1133     if (isa<ast::ReturnStmt>(*bodyIt))
1134       break;
1135   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1136                                                    bodyE, results, resultType)))
1137     return failure();
1138   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1139       name, arguments, results, resultType, body);
1140 }
1141 
1142 template <typename T, typename ParseUserPDLLDeclFnT>
1143 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1144     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1145     StringRef anonymousNamePrefix, bool isInline) {
1146   SMRange loc = curToken.getLoc();
1147   consumeToken();
1148   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1149 
1150   // Parse the name of the decl.
1151   const ast::Name *name = nullptr;
1152   if (curToken.isNot(Token::identifier)) {
1153     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1154     // in C++, so being unnamed is fine.
1155     if (!isInline)
1156       return emitError("expected identifier name");
1157 
1158     // Create a unique anonymous name to use, as the name for this decl is not
1159     // important.
1160     std::string anonName =
1161         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1162                       anonymousDeclNameCounter++)
1163             .str();
1164     name = &ast::Name::create(ctx, anonName, loc);
1165   } else {
1166     // If a name was provided, we can use it directly.
1167     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1168     consumeToken(Token::identifier);
1169   }
1170 
1171   // Parse the functional signature of the decl.
1172   SmallVector<ast::VariableDecl *> arguments, results;
1173   ast::DeclScope *argumentScope;
1174   ast::Type resultType;
1175   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1176                                                    argumentScope, resultType)))
1177     return failure();
1178 
1179   // Check to see which type of constraint this is. If the constraint contains a
1180   // compound body, this is a PDLL decl.
1181   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1182     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1183                            resultType);
1184 
1185   // Otherwise, this is a native decl.
1186   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1187                                                    results, resultType);
1188 }
1189 
1190 template <typename T>
1191 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1192     const ast::Name &name, bool isInline,
1193     ArrayRef<ast::VariableDecl *> arguments,
1194     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1195   // If followed by a string, the native code body has also been specified.
1196   std::string codeStrStorage;
1197   Optional<StringRef> optCodeStr;
1198   if (curToken.isString()) {
1199     codeStrStorage = curToken.getStringValue();
1200     optCodeStr = codeStrStorage;
1201     consumeToken();
1202   } else if (isInline) {
1203     return emitError(name.getLoc(),
1204                      "external declarations must be declared in global scope");
1205   }
1206   if (failed(parseToken(Token::semicolon,
1207                         "expected `;` after native declaration")))
1208     return failure();
1209   // TODO: PDL should be able to support constraint results in certain
1210   // situations, we should revise this.
1211   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1212     return emitError(
1213         "native Constraints currently do not support returning results");
1214   }
1215   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1216 }
1217 
1218 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1219     SmallVectorImpl<ast::VariableDecl *> &arguments,
1220     SmallVectorImpl<ast::VariableDecl *> &results,
1221     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1222   // Parse the argument list of the decl.
1223   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1224     return failure();
1225 
1226   argumentScope = pushDeclScope();
1227   if (curToken.isNot(Token::r_paren)) {
1228     do {
1229       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1230       if (failed(argument))
1231         return failure();
1232       arguments.emplace_back(*argument);
1233     } while (consumeIf(Token::comma));
1234   }
1235   popDeclScope();
1236   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1237     return failure();
1238 
1239   // Parse the results of the decl.
1240   pushDeclScope();
1241   if (consumeIf(Token::arrow)) {
1242     auto parseResultFn = [&]() -> LogicalResult {
1243       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1244       if (failed(result))
1245         return failure();
1246       results.emplace_back(*result);
1247       return success();
1248     };
1249 
1250     // Check for a list of results.
1251     if (consumeIf(Token::l_paren)) {
1252       do {
1253         if (failed(parseResultFn()))
1254           return failure();
1255       } while (consumeIf(Token::comma));
1256       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1257         return failure();
1258 
1259       // Otherwise, there is only one result.
1260     } else if (failed(parseResultFn())) {
1261       return failure();
1262     }
1263   }
1264   popDeclScope();
1265 
1266   // Compute the result type of the decl.
1267   resultType = createUserConstraintRewriteResultType(results);
1268 
1269   // Verify that results are only named if there are more than one.
1270   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1271     return emitError(
1272         results.front()->getLoc(),
1273         "cannot create a single-element tuple with an element label");
1274   }
1275   return success();
1276 }
1277 
1278 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1279     StringRef declType, ast::CompoundStmt *body,
1280     ArrayRef<ast::Stmt *>::iterator bodyIt,
1281     ArrayRef<ast::Stmt *>::iterator bodyE,
1282     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1283   // Handle if a `return` was provided.
1284   if (bodyIt != bodyE) {
1285     // Emit an error if we have trailing statements after the return.
1286     if (std::next(bodyIt) != bodyE) {
1287       return emitError(
1288           (*std::next(bodyIt))->getLoc(),
1289           llvm::formatv("`return` terminated the `{0}` body, but found "
1290                         "trailing statements afterwards",
1291                         declType));
1292     }
1293 
1294     // Otherwise if a return wasn't provided, check that no results are
1295     // expected.
1296   } else if (!results.empty()) {
1297     return emitError(
1298         {body->getLoc().End, body->getLoc().End},
1299         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1300                       declType, resultType));
1301   }
1302   return success();
1303 }
1304 
1305 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1306   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1307     if (isa<ast::OpRewriteStmt>(statement))
1308       return success();
1309     return emitError(
1310         statement->getLoc(),
1311         "expected Pattern lambda body to contain a single operation "
1312         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1313   });
1314 }
1315 
1316 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1317   SMRange loc = curToken.getLoc();
1318   consumeToken(Token::kw_Pattern);
1319   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1320                                               ParserContext::PatternMatch);
1321 
1322   // Check for an optional identifier for the pattern name.
1323   const ast::Name *name = nullptr;
1324   if (curToken.is(Token::identifier)) {
1325     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1326     consumeToken(Token::identifier);
1327   }
1328 
1329   // Parse any pattern metadata.
1330   ParsedPatternMetadata metadata;
1331   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1332     return failure();
1333 
1334   // Parse the pattern body.
1335   ast::CompoundStmt *body;
1336 
1337   // Handle a lambda body.
1338   if (curToken.is(Token::equal_arrow)) {
1339     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1340     if (failed(bodyResult))
1341       return failure();
1342     body = *bodyResult;
1343   } else {
1344     if (curToken.isNot(Token::l_brace))
1345       return emitError("expected `{` or `=>` to start pattern body");
1346     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1347     if (failed(bodyResult))
1348       return failure();
1349     body = *bodyResult;
1350 
1351     // Verify the body of the pattern.
1352     auto bodyIt = body->begin(), bodyE = body->end();
1353     for (; bodyIt != bodyE; ++bodyIt) {
1354       if (isa<ast::ReturnStmt>(*bodyIt)) {
1355         return emitError((*bodyIt)->getLoc(),
1356                          "`return` statements are only permitted within a "
1357                          "`Constraint` or `Rewrite` body");
1358       }
1359       // Break when we've found the rewrite statement.
1360       if (isa<ast::OpRewriteStmt>(*bodyIt))
1361         break;
1362     }
1363     if (bodyIt == bodyE) {
1364       return emitError(loc,
1365                        "expected Pattern body to terminate with an operation "
1366                        "rewrite statement, such as `erase`");
1367     }
1368     if (std::next(bodyIt) != bodyE) {
1369       return emitError((*std::next(bodyIt))->getLoc(),
1370                        "Pattern body was terminated by an operation "
1371                        "rewrite statement, but found trailing statements");
1372     }
1373   }
1374 
1375   return createPatternDecl(loc, name, metadata, body);
1376 }
1377 
1378 LogicalResult
1379 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1380   Optional<SMRange> benefitLoc;
1381   Optional<SMRange> hasBoundedRecursionLoc;
1382 
1383   do {
1384     if (curToken.isNot(Token::identifier))
1385       return emitError("expected pattern metadata identifier");
1386     StringRef metadataStr = curToken.getSpelling();
1387     SMRange metadataLoc = curToken.getLoc();
1388     consumeToken(Token::identifier);
1389 
1390     // Parse the benefit metadata: benefit(<integer-value>)
1391     if (metadataStr == "benefit") {
1392       if (benefitLoc) {
1393         return emitErrorAndNote(metadataLoc,
1394                                 "pattern benefit has already been specified",
1395                                 *benefitLoc, "see previous definition here");
1396       }
1397       if (failed(parseToken(Token::l_paren,
1398                             "expected `(` before pattern benefit")))
1399         return failure();
1400 
1401       uint16_t benefitValue = 0;
1402       if (curToken.isNot(Token::integer))
1403         return emitError("expected integral pattern benefit");
1404       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1405         return emitError(
1406             "expected pattern benefit to fit within a 16-bit integer");
1407       consumeToken(Token::integer);
1408 
1409       metadata.benefit = benefitValue;
1410       benefitLoc = metadataLoc;
1411 
1412       if (failed(
1413               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1414         return failure();
1415       continue;
1416     }
1417 
1418     // Parse the bounded recursion metadata: recursion
1419     if (metadataStr == "recursion") {
1420       if (hasBoundedRecursionLoc) {
1421         return emitErrorAndNote(
1422             metadataLoc,
1423             "pattern recursion metadata has already been specified",
1424             *hasBoundedRecursionLoc, "see previous definition here");
1425       }
1426       metadata.hasBoundedRecursion = true;
1427       hasBoundedRecursionLoc = metadataLoc;
1428       continue;
1429     }
1430 
1431     return emitError(metadataLoc, "unknown pattern metadata");
1432   } while (consumeIf(Token::comma));
1433 
1434   return success();
1435 }
1436 
1437 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1438   consumeToken(Token::less);
1439 
1440   FailureOr<ast::Expr *> typeExpr = parseExpr();
1441   if (failed(typeExpr) ||
1442       failed(parseToken(Token::greater,
1443                         "expected `>` after variable type constraint")))
1444     return failure();
1445   return typeExpr;
1446 }
1447 
1448 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1449   assert(curDeclScope && "defining decl outside of a decl scope");
1450   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1451     return emitErrorAndNote(
1452         name.getLoc(), "`" + name.getName() + "` has already been defined",
1453         lastDecl->getName()->getLoc(), "see previous definition here");
1454   }
1455   return success();
1456 }
1457 
1458 FailureOr<ast::VariableDecl *>
1459 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1460                            ast::Expr *initExpr,
1461                            ArrayRef<ast::ConstraintRef> constraints) {
1462   assert(curDeclScope && "defining variable outside of decl scope");
1463   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1464 
1465   // If the name of the variable indicates a special variable, we don't add it
1466   // to the scope. This variable is local to the definition point.
1467   if (name.empty() || name == "_") {
1468     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1469                                      constraints);
1470   }
1471   if (failed(checkDefineNamedDecl(nameDecl)))
1472     return failure();
1473 
1474   auto *varDecl =
1475       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1476   curDeclScope->add(varDecl);
1477   return varDecl;
1478 }
1479 
1480 FailureOr<ast::VariableDecl *>
1481 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1482                            ArrayRef<ast::ConstraintRef> constraints) {
1483   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1484                             constraints);
1485 }
1486 
1487 LogicalResult Parser::parseVariableDeclConstraintList(
1488     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1489   Optional<SMRange> typeConstraint;
1490   auto parseSingleConstraint = [&] {
1491     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1492         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1493     if (failed(constraint))
1494       return failure();
1495     constraints.push_back(*constraint);
1496     return success();
1497   };
1498 
1499   // Check to see if this is a single constraint, or a list.
1500   if (!consumeIf(Token::l_square))
1501     return parseSingleConstraint();
1502 
1503   do {
1504     if (failed(parseSingleConstraint()))
1505       return failure();
1506   } while (consumeIf(Token::comma));
1507   return parseToken(Token::r_square, "expected `]` after constraint list");
1508 }
1509 
1510 FailureOr<ast::ConstraintRef>
1511 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1512                         ArrayRef<ast::ConstraintRef> existingConstraints,
1513                         bool allowInlineTypeConstraints) {
1514   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1515     if (!allowInlineTypeConstraints) {
1516       return emitError(
1517           curToken.getLoc(),
1518           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1519           "permitted on arguments or results");
1520     }
1521     if (typeConstraint)
1522       return emitErrorAndNote(
1523           curToken.getLoc(),
1524           "the type of this variable has already been constrained",
1525           *typeConstraint, "see previous constraint location here");
1526     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1527     if (failed(constraintExpr))
1528       return failure();
1529     typeExpr = *constraintExpr;
1530     typeConstraint = typeExpr->getLoc();
1531     return success();
1532   };
1533 
1534   SMRange loc = curToken.getLoc();
1535   switch (curToken.getKind()) {
1536   case Token::kw_Attr: {
1537     consumeToken(Token::kw_Attr);
1538 
1539     // Check for a type constraint.
1540     ast::Expr *typeExpr = nullptr;
1541     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1542       return failure();
1543     return ast::ConstraintRef(
1544         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1545   }
1546   case Token::kw_Op: {
1547     consumeToken(Token::kw_Op);
1548 
1549     // Parse an optional operation name. If the name isn't provided, this refers
1550     // to "any" operation.
1551     FailureOr<ast::OpNameDecl *> opName =
1552         parseWrappedOperationName(/*allowEmptyName=*/true);
1553     if (failed(opName))
1554       return failure();
1555 
1556     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1557                               loc);
1558   }
1559   case Token::kw_Type:
1560     consumeToken(Token::kw_Type);
1561     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1562   case Token::kw_TypeRange:
1563     consumeToken(Token::kw_TypeRange);
1564     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1565                               loc);
1566   case Token::kw_Value: {
1567     consumeToken(Token::kw_Value);
1568 
1569     // Check for a type constraint.
1570     ast::Expr *typeExpr = nullptr;
1571     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1572       return failure();
1573 
1574     return ast::ConstraintRef(
1575         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1576   }
1577   case Token::kw_ValueRange: {
1578     consumeToken(Token::kw_ValueRange);
1579 
1580     // Check for a type constraint.
1581     ast::Expr *typeExpr = nullptr;
1582     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1583       return failure();
1584 
1585     return ast::ConstraintRef(
1586         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1587   }
1588 
1589   case Token::kw_Constraint: {
1590     // Handle an inline constraint.
1591     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1592     if (failed(decl))
1593       return failure();
1594     return ast::ConstraintRef(*decl, loc);
1595   }
1596   case Token::identifier: {
1597     StringRef constraintName = curToken.getSpelling();
1598     consumeToken(Token::identifier);
1599 
1600     // Lookup the referenced constraint.
1601     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1602     if (!cstDecl) {
1603       return emitError(loc, "unknown reference to constraint `" +
1604                                 constraintName + "`");
1605     }
1606 
1607     // Handle a reference to a proper constraint.
1608     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1609       return ast::ConstraintRef(cst, loc);
1610 
1611     return emitErrorAndNote(
1612         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1613         "see the definition of `" + constraintName + "` here");
1614   }
1615   default:
1616     break;
1617   }
1618   return emitError(loc, "expected identifier constraint");
1619 }
1620 
1621 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1622   Optional<SMRange> typeConstraint;
1623   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1624                          /*allowInlineTypeConstraints=*/false);
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // Exprs
1629 
1630 FailureOr<ast::Expr *> Parser::parseExpr() {
1631   if (curToken.is(Token::underscore))
1632     return parseUnderscoreExpr();
1633 
1634   // Parse the LHS expression.
1635   FailureOr<ast::Expr *> lhsExpr;
1636   switch (curToken.getKind()) {
1637   case Token::kw_attr:
1638     lhsExpr = parseAttributeExpr();
1639     break;
1640   case Token::kw_Constraint:
1641     lhsExpr = parseInlineConstraintLambdaExpr();
1642     break;
1643   case Token::identifier:
1644     lhsExpr = parseIdentifierExpr();
1645     break;
1646   case Token::kw_op:
1647     lhsExpr = parseOperationExpr();
1648     break;
1649   case Token::kw_Rewrite:
1650     lhsExpr = parseInlineRewriteLambdaExpr();
1651     break;
1652   case Token::kw_type:
1653     lhsExpr = parseTypeExpr();
1654     break;
1655   case Token::l_paren:
1656     lhsExpr = parseTupleExpr();
1657     break;
1658   default:
1659     return emitError("expected expression");
1660   }
1661   if (failed(lhsExpr))
1662     return failure();
1663 
1664   // Check for an operator expression.
1665   while (true) {
1666     switch (curToken.getKind()) {
1667     case Token::dot:
1668       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1669       break;
1670     case Token::l_paren:
1671       lhsExpr = parseCallExpr(*lhsExpr);
1672       break;
1673     default:
1674       return lhsExpr;
1675     }
1676     if (failed(lhsExpr))
1677       return failure();
1678   }
1679 }
1680 
1681 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1682   SMRange loc = curToken.getLoc();
1683   consumeToken(Token::kw_attr);
1684 
1685   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1686   // identifier.
1687   if (!consumeIf(Token::less)) {
1688     resetToken(loc);
1689     return parseIdentifierExpr();
1690   }
1691 
1692   if (!curToken.isString())
1693     return emitError("expected string literal containing MLIR attribute");
1694   std::string attrExpr = curToken.getStringValue();
1695   consumeToken();
1696 
1697   if (failed(
1698           parseToken(Token::greater, "expected `>` after attribute literal")))
1699     return failure();
1700   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1701 }
1702 
1703 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1704   SMRange loc = curToken.getLoc();
1705   consumeToken(Token::l_paren);
1706 
1707   // Parse the arguments of the call.
1708   SmallVector<ast::Expr *> arguments;
1709   if (curToken.isNot(Token::r_paren)) {
1710     do {
1711       FailureOr<ast::Expr *> argument = parseExpr();
1712       if (failed(argument))
1713         return failure();
1714       arguments.push_back(*argument);
1715     } while (consumeIf(Token::comma));
1716   }
1717   loc.End = curToken.getEndLoc();
1718   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1719     return failure();
1720 
1721   return createCallExpr(loc, parentExpr, arguments);
1722 }
1723 
1724 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1725   ast::Decl *decl = curDeclScope->lookup(name);
1726   if (!decl)
1727     return emitError(loc, "undefined reference to `" + name + "`");
1728 
1729   return createDeclRefExpr(loc, decl);
1730 }
1731 
1732 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1733   StringRef name = curToken.getSpelling();
1734   SMRange nameLoc = curToken.getLoc();
1735   consumeToken();
1736 
1737   // Check to see if this is a decl ref expression that defines a variable
1738   // inline.
1739   if (consumeIf(Token::colon)) {
1740     SmallVector<ast::ConstraintRef> constraints;
1741     if (failed(parseVariableDeclConstraintList(constraints)))
1742       return failure();
1743     ast::Type type;
1744     if (failed(validateVariableConstraints(constraints, type)))
1745       return failure();
1746     return createInlineVariableExpr(type, name, nameLoc, constraints);
1747   }
1748 
1749   return parseDeclRefExpr(name, nameLoc);
1750 }
1751 
1752 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1753   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1754   if (failed(decl))
1755     return failure();
1756 
1757   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1758                                   ast::ConstraintType::get(ctx));
1759 }
1760 
1761 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1762   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1763   if (failed(decl))
1764     return failure();
1765 
1766   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1767                                   ast::RewriteType::get(ctx));
1768 }
1769 
1770 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1771   SMRange loc = curToken.getLoc();
1772   consumeToken(Token::dot);
1773 
1774   // Parse the member name.
1775   Token memberNameTok = curToken;
1776   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1777       !memberNameTok.isKeyword())
1778     return emitError(loc, "expected identifier or numeric member name");
1779   StringRef memberName = memberNameTok.getSpelling();
1780   consumeToken();
1781 
1782   return createMemberAccessExpr(parentExpr, memberName, loc);
1783 }
1784 
1785 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1786   SMRange loc = curToken.getLoc();
1787 
1788   // Handle the case of an no operation name.
1789   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1790     if (allowEmptyName)
1791       return ast::OpNameDecl::create(ctx, SMRange());
1792     return emitError("expected dialect namespace");
1793   }
1794   StringRef name = curToken.getSpelling();
1795   consumeToken();
1796 
1797   // Otherwise, this is a literal operation name.
1798   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1799     return failure();
1800 
1801   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1802     return emitError("expected operation name after dialect namespace");
1803 
1804   name = StringRef(name.data(), name.size() + 1);
1805   do {
1806     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1807     loc.End = curToken.getEndLoc();
1808     consumeToken();
1809   } while (curToken.isAny(Token::identifier, Token::dot) ||
1810            curToken.isKeyword());
1811   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1812 }
1813 
1814 FailureOr<ast::OpNameDecl *>
1815 Parser::parseWrappedOperationName(bool allowEmptyName) {
1816   if (!consumeIf(Token::less))
1817     return ast::OpNameDecl::create(ctx, SMRange());
1818 
1819   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1820   if (failed(opNameDecl))
1821     return failure();
1822 
1823   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1824     return failure();
1825   return opNameDecl;
1826 }
1827 
1828 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1829   SMRange loc = curToken.getLoc();
1830   consumeToken(Token::kw_op);
1831 
1832   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1833   // identifier.
1834   if (curToken.isNot(Token::less)) {
1835     resetToken(loc);
1836     return parseIdentifierExpr();
1837   }
1838 
1839   // Parse the operation name. The name may be elided, in which case the
1840   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1841   // `Operation*`). Operation names within a rewrite context must be named.
1842   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1843   FailureOr<ast::OpNameDecl *> opNameDecl =
1844       parseWrappedOperationName(allowEmptyName);
1845   if (failed(opNameDecl))
1846     return failure();
1847 
1848   // Functor used to create an implicit range variable, used for implicit "all"
1849   // operand or results variables.
1850   auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
1851     FailureOr<ast::VariableDecl *> rangeVar =
1852         defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
1853     assert(succeeded(rangeVar) && "expected range variable to be valid");
1854     return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
1855   };
1856 
1857   // Check for the optional list of operands.
1858   SmallVector<ast::Expr *> operands;
1859   if (!consumeIf(Token::l_paren)) {
1860     // If the operand list isn't specified and we are in a match context, define
1861     // an inplace unconstrained operand range corresponding to all of the
1862     // operands of the operation. This avoids treating zero operands the same
1863     // way as "unconstrained operands".
1864     if (parserContext != ParserContext::Rewrite) {
1865       operands.push_back(createImplicitRangeVar(
1866           ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
1867     }
1868   } else if (!consumeIf(Token::r_paren)) {
1869     // If the operand list was specified and non-empty, parse the operands.
1870     do {
1871       FailureOr<ast::Expr *> operand = parseExpr();
1872       if (failed(operand))
1873         return failure();
1874       operands.push_back(*operand);
1875     } while (consumeIf(Token::comma));
1876 
1877     if (failed(parseToken(Token::r_paren,
1878                           "expected `)` after operation operand list")))
1879       return failure();
1880   }
1881 
1882   // Check for the optional list of attributes.
1883   SmallVector<ast::NamedAttributeDecl *> attributes;
1884   if (consumeIf(Token::l_brace)) {
1885     do {
1886       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1887       if (failed(decl))
1888         return failure();
1889       attributes.emplace_back(*decl);
1890     } while (consumeIf(Token::comma));
1891 
1892     if (failed(parseToken(Token::r_brace,
1893                           "expected `}` after operation attribute list")))
1894       return failure();
1895   }
1896 
1897   // Check for the optional list of result types.
1898   SmallVector<ast::Expr *> resultTypes;
1899   if (consumeIf(Token::arrow)) {
1900     if (failed(parseToken(Token::l_paren,
1901                           "expected `(` before operation result type list")))
1902       return failure();
1903 
1904     // Handle the case of an empty result list.
1905     if (!consumeIf(Token::r_paren)) {
1906       do {
1907         FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1908         if (failed(resultTypeExpr))
1909           return failure();
1910         resultTypes.push_back(*resultTypeExpr);
1911       } while (consumeIf(Token::comma));
1912 
1913       if (failed(parseToken(Token::r_paren,
1914                             "expected `)` after operation result type list")))
1915         return failure();
1916     }
1917   } else if (parserContext != ParserContext::Rewrite) {
1918     // If the result list isn't specified and we are in a match context, define
1919     // an inplace unconstrained result range corresponding to all of the results
1920     // of the operation. This avoids treating zero results the same way as
1921     // "unconstrained results".
1922     resultTypes.push_back(createImplicitRangeVar(
1923         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
1924   }
1925 
1926   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1927                              resultTypes);
1928 }
1929 
1930 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
1931   SMRange loc = curToken.getLoc();
1932   consumeToken(Token::l_paren);
1933 
1934   DenseMap<StringRef, SMRange> usedNames;
1935   SmallVector<StringRef> elementNames;
1936   SmallVector<ast::Expr *> elements;
1937   if (curToken.isNot(Token::r_paren)) {
1938     do {
1939       // Check for the optional element name assignment before the value.
1940       StringRef elementName;
1941       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1942         Token elementNameTok = curToken;
1943         consumeToken();
1944 
1945         // The element name is only present if followed by an `=`.
1946         if (consumeIf(Token::equal)) {
1947           elementName = elementNameTok.getSpelling();
1948 
1949           // Check to see if this name is already used.
1950           auto elementNameIt =
1951               usedNames.try_emplace(elementName, elementNameTok.getLoc());
1952           if (!elementNameIt.second) {
1953             return emitErrorAndNote(
1954                 elementNameTok.getLoc(),
1955                 llvm::formatv("duplicate tuple element label `{0}`",
1956                               elementName),
1957                 elementNameIt.first->getSecond(),
1958                 "see previous label use here");
1959           }
1960         } else {
1961           // Otherwise, we treat this as part of an expression so reset the
1962           // lexer.
1963           resetToken(elementNameTok.getLoc());
1964         }
1965       }
1966       elementNames.push_back(elementName);
1967 
1968       // Parse the tuple element value.
1969       FailureOr<ast::Expr *> element = parseExpr();
1970       if (failed(element))
1971         return failure();
1972       elements.push_back(*element);
1973     } while (consumeIf(Token::comma));
1974   }
1975   loc.End = curToken.getEndLoc();
1976   if (failed(
1977           parseToken(Token::r_paren, "expected `)` after tuple element list")))
1978     return failure();
1979   return createTupleExpr(loc, elements, elementNames);
1980 }
1981 
1982 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1983   SMRange loc = curToken.getLoc();
1984   consumeToken(Token::kw_type);
1985 
1986   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1987   // identifier.
1988   if (!consumeIf(Token::less)) {
1989     resetToken(loc);
1990     return parseIdentifierExpr();
1991   }
1992 
1993   if (!curToken.isString())
1994     return emitError("expected string literal containing MLIR type");
1995   std::string attrExpr = curToken.getStringValue();
1996   consumeToken();
1997 
1998   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1999     return failure();
2000   return ast::TypeExpr::create(ctx, loc, attrExpr);
2001 }
2002 
2003 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2004   StringRef name = curToken.getSpelling();
2005   SMRange nameLoc = curToken.getLoc();
2006   consumeToken(Token::underscore);
2007 
2008   // Underscore expressions require a constraint list.
2009   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2010     return failure();
2011 
2012   // Parse the constraints for the expression.
2013   SmallVector<ast::ConstraintRef> constraints;
2014   if (failed(parseVariableDeclConstraintList(constraints)))
2015     return failure();
2016 
2017   ast::Type type;
2018   if (failed(validateVariableConstraints(constraints, type)))
2019     return failure();
2020   return createInlineVariableExpr(type, name, nameLoc, constraints);
2021 }
2022 
2023 //===----------------------------------------------------------------------===//
2024 // Stmts
2025 
2026 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2027   FailureOr<ast::Stmt *> stmt;
2028   switch (curToken.getKind()) {
2029   case Token::kw_erase:
2030     stmt = parseEraseStmt();
2031     break;
2032   case Token::kw_let:
2033     stmt = parseLetStmt();
2034     break;
2035   case Token::kw_replace:
2036     stmt = parseReplaceStmt();
2037     break;
2038   case Token::kw_return:
2039     stmt = parseReturnStmt();
2040     break;
2041   case Token::kw_rewrite:
2042     stmt = parseRewriteStmt();
2043     break;
2044   default:
2045     stmt = parseExpr();
2046     break;
2047   }
2048   if (failed(stmt) ||
2049       (expectTerminalSemicolon &&
2050        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2051     return failure();
2052   return stmt;
2053 }
2054 
2055 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2056   SMLoc startLoc = curToken.getStartLoc();
2057   consumeToken(Token::l_brace);
2058 
2059   // Push a new block scope and parse any nested statements.
2060   pushDeclScope();
2061   SmallVector<ast::Stmt *> statements;
2062   while (curToken.isNot(Token::r_brace)) {
2063     FailureOr<ast::Stmt *> statement = parseStmt();
2064     if (failed(statement))
2065       return popDeclScope(), failure();
2066     statements.push_back(*statement);
2067   }
2068   popDeclScope();
2069 
2070   // Consume the end brace.
2071   SMRange location(startLoc, curToken.getEndLoc());
2072   consumeToken(Token::r_brace);
2073 
2074   return ast::CompoundStmt::create(ctx, location, statements);
2075 }
2076 
2077 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2078   if (parserContext == ParserContext::Constraint)
2079     return emitError("`erase` cannot be used within a Constraint");
2080   SMRange loc = curToken.getLoc();
2081   consumeToken(Token::kw_erase);
2082 
2083   // Parse the root operation expression.
2084   FailureOr<ast::Expr *> rootOp = parseExpr();
2085   if (failed(rootOp))
2086     return failure();
2087 
2088   return createEraseStmt(loc, *rootOp);
2089 }
2090 
2091 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2092   SMRange loc = curToken.getLoc();
2093   consumeToken(Token::kw_let);
2094 
2095   // Parse the name of the new variable.
2096   SMRange varLoc = curToken.getLoc();
2097   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2098     // `_` is a reserved variable name.
2099     if (curToken.is(Token::underscore)) {
2100       return emitError(varLoc,
2101                        "`_` may only be used to define \"inline\" variables");
2102     }
2103     return emitError(varLoc,
2104                      "expected identifier after `let` to name a new variable");
2105   }
2106   StringRef varName = curToken.getSpelling();
2107   consumeToken();
2108 
2109   // Parse the optional set of constraints.
2110   SmallVector<ast::ConstraintRef> constraints;
2111   if (consumeIf(Token::colon) &&
2112       failed(parseVariableDeclConstraintList(constraints)))
2113     return failure();
2114 
2115   // Parse the optional initializer expression.
2116   ast::Expr *initializer = nullptr;
2117   if (consumeIf(Token::equal)) {
2118     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2119     if (failed(initOrFailure))
2120       return failure();
2121     initializer = *initOrFailure;
2122 
2123     // Check that the constraints are compatible with having an initializer,
2124     // e.g. type constraints cannot be used with initializers.
2125     for (ast::ConstraintRef constraint : constraints) {
2126       LogicalResult result =
2127           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2128               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2129                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2130                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2131                   return this->emitError(
2132                       constraint.referenceLoc,
2133                       "type constraints are not permitted on variables with "
2134                       "initializers");
2135                 }
2136                 return success();
2137               })
2138               .Default(success());
2139       if (failed(result))
2140         return failure();
2141     }
2142   }
2143 
2144   FailureOr<ast::VariableDecl *> varDecl =
2145       createVariableDecl(varName, varLoc, initializer, constraints);
2146   if (failed(varDecl))
2147     return failure();
2148   return ast::LetStmt::create(ctx, loc, *varDecl);
2149 }
2150 
2151 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2152   if (parserContext == ParserContext::Constraint)
2153     return emitError("`replace` cannot be used within a Constraint");
2154   SMRange loc = curToken.getLoc();
2155   consumeToken(Token::kw_replace);
2156 
2157   // Parse the root operation expression.
2158   FailureOr<ast::Expr *> rootOp = parseExpr();
2159   if (failed(rootOp))
2160     return failure();
2161 
2162   if (failed(
2163           parseToken(Token::kw_with, "expected `with` after root operation")))
2164     return failure();
2165 
2166   // The replacement portion of this statement is within a rewrite context.
2167   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2168                                               ParserContext::Rewrite);
2169 
2170   // Parse the replacement values.
2171   SmallVector<ast::Expr *> replValues;
2172   if (consumeIf(Token::l_paren)) {
2173     if (consumeIf(Token::r_paren)) {
2174       return emitError(
2175           loc, "expected at least one replacement value, consider using "
2176                "`erase` if no replacement values are desired");
2177     }
2178 
2179     do {
2180       FailureOr<ast::Expr *> replExpr = parseExpr();
2181       if (failed(replExpr))
2182         return failure();
2183       replValues.emplace_back(*replExpr);
2184     } while (consumeIf(Token::comma));
2185 
2186     if (failed(parseToken(Token::r_paren,
2187                           "expected `)` after replacement values")))
2188       return failure();
2189   } else {
2190     FailureOr<ast::Expr *> replExpr = parseExpr();
2191     if (failed(replExpr))
2192       return failure();
2193     replValues.emplace_back(*replExpr);
2194   }
2195 
2196   return createReplaceStmt(loc, *rootOp, replValues);
2197 }
2198 
2199 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2200   SMRange loc = curToken.getLoc();
2201   consumeToken(Token::kw_return);
2202 
2203   // Parse the result value.
2204   FailureOr<ast::Expr *> resultExpr = parseExpr();
2205   if (failed(resultExpr))
2206     return failure();
2207 
2208   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2209 }
2210 
2211 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2212   if (parserContext == ParserContext::Constraint)
2213     return emitError("`rewrite` cannot be used within a Constraint");
2214   SMRange loc = curToken.getLoc();
2215   consumeToken(Token::kw_rewrite);
2216 
2217   // Parse the root operation.
2218   FailureOr<ast::Expr *> rootOp = parseExpr();
2219   if (failed(rootOp))
2220     return failure();
2221 
2222   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2223     return failure();
2224 
2225   if (curToken.isNot(Token::l_brace))
2226     return emitError("expected `{` to start rewrite body");
2227 
2228   // The rewrite body of this statement is within a rewrite context.
2229   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2230                                               ParserContext::Rewrite);
2231 
2232   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2233   if (failed(rewriteBody))
2234     return failure();
2235 
2236   // Verify the rewrite body.
2237   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2238     if (isa<ast::ReturnStmt>(stmt)) {
2239       return emitError(stmt->getLoc(),
2240                        "`return` statements are only permitted within a "
2241                        "`Constraint` or `Rewrite` body");
2242     }
2243   }
2244 
2245   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2246 }
2247 
2248 //===----------------------------------------------------------------------===//
2249 // Creation+Analysis
2250 //===----------------------------------------------------------------------===//
2251 
2252 //===----------------------------------------------------------------------===//
2253 // Decls
2254 
2255 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2256   // Unwrap reference expressions.
2257   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2258     node = init->getDecl();
2259   return dyn_cast<ast::CallableDecl>(node);
2260 }
2261 
2262 FailureOr<ast::PatternDecl *>
2263 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2264                           const ParsedPatternMetadata &metadata,
2265                           ast::CompoundStmt *body) {
2266   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2267                                   metadata.hasBoundedRecursion, body);
2268 }
2269 
2270 ast::Type Parser::createUserConstraintRewriteResultType(
2271     ArrayRef<ast::VariableDecl *> results) {
2272   // Single result decls use the type of the single result.
2273   if (results.size() == 1)
2274     return results[0]->getType();
2275 
2276   // Multiple results use a tuple type, with the types and names grabbed from
2277   // the result variable decls.
2278   auto resultTypes = llvm::map_range(
2279       results, [&](const auto *result) { return result->getType(); });
2280   auto resultNames = llvm::map_range(
2281       results, [&](const auto *result) { return result->getName().getName(); });
2282   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2283                              llvm::to_vector(resultNames));
2284 }
2285 
2286 template <typename T>
2287 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2288     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2289     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2290     ast::CompoundStmt *body) {
2291   if (!body->getChildren().empty()) {
2292     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2293       ast::Expr *resultExpr = retStmt->getResultExpr();
2294 
2295       // Process the result of the decl. If no explicit signature results
2296       // were provided, check for return type inference. Otherwise, check that
2297       // the return expression can be converted to the expected type.
2298       if (results.empty())
2299         resultType = resultExpr->getType();
2300       else if (failed(convertExpressionTo(resultExpr, resultType)))
2301         return failure();
2302       else
2303         retStmt->setResultExpr(resultExpr);
2304     }
2305   }
2306   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2307 }
2308 
2309 FailureOr<ast::VariableDecl *>
2310 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2311                            ArrayRef<ast::ConstraintRef> constraints) {
2312   // The type of the variable, which is expected to be inferred by either a
2313   // constraint or an initializer expression.
2314   ast::Type type;
2315   if (failed(validateVariableConstraints(constraints, type)))
2316     return failure();
2317 
2318   if (initializer) {
2319     // Update the variable type based on the initializer, or try to convert the
2320     // initializer to the existing type.
2321     if (!type)
2322       type = initializer->getType();
2323     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2324       type = mergedType;
2325     else if (failed(convertExpressionTo(initializer, type)))
2326       return failure();
2327 
2328     // Otherwise, if there is no initializer check that the type has already
2329     // been resolved from the constraint list.
2330   } else if (!type) {
2331     return emitErrorAndNote(
2332         loc, "unable to infer type for variable `" + name + "`", loc,
2333         "the type of a variable must be inferable from the constraint "
2334         "list or the initializer");
2335   }
2336 
2337   // Constraint types cannot be used when defining variables.
2338   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2339     return emitError(
2340         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2341   }
2342 
2343   // Try to define a variable with the given name.
2344   FailureOr<ast::VariableDecl *> varDecl =
2345       defineVariableDecl(name, loc, type, initializer, constraints);
2346   if (failed(varDecl))
2347     return failure();
2348 
2349   return *varDecl;
2350 }
2351 
2352 FailureOr<ast::VariableDecl *>
2353 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2354                                       const ast::ConstraintRef &constraint) {
2355   // Constraint arguments may apply more complex constraints via the arguments.
2356   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2357   ast::Type argType;
2358   if (failed(validateVariableConstraint(constraint, argType,
2359                                         allowNonCoreConstraints)))
2360     return failure();
2361   return defineVariableDecl(name, loc, argType, constraint);
2362 }
2363 
2364 LogicalResult
2365 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2366                                     ast::Type &inferredType) {
2367   for (const ast::ConstraintRef &ref : constraints)
2368     if (failed(validateVariableConstraint(ref, inferredType)))
2369       return failure();
2370   return success();
2371 }
2372 
2373 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2374                                                  ast::Type &inferredType,
2375                                                  bool allowNonCoreConstraints) {
2376   ast::Type constraintType;
2377   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2378     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2379       if (failed(validateTypeConstraintExpr(typeExpr)))
2380         return failure();
2381     }
2382     constraintType = ast::AttributeType::get(ctx);
2383   } else if (const auto *cst =
2384                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2385     constraintType = ast::OperationType::get(ctx, cst->getName());
2386   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2387     constraintType = typeTy;
2388   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2389     constraintType = typeRangeTy;
2390   } else if (const auto *cst =
2391                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2392     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2393       if (failed(validateTypeConstraintExpr(typeExpr)))
2394         return failure();
2395     }
2396     constraintType = valueTy;
2397   } else if (const auto *cst =
2398                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2399     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2400       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2401         return failure();
2402     }
2403     constraintType = valueRangeTy;
2404   } else if (const auto *cst =
2405                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2406     if (!allowNonCoreConstraints) {
2407       return emitError(ref.referenceLoc,
2408                        "`Rewrite` arguments and results are only permitted to "
2409                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2410                        "`TypeRange`, `Value`, `ValueRange`");
2411     }
2412 
2413     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2414     if (inputs.size() != 1) {
2415       return emitErrorAndNote(ref.referenceLoc,
2416                               "`Constraint`s applied via a variable constraint "
2417                               "list must take a single input, but got " +
2418                                   Twine(inputs.size()),
2419                               cst->getLoc(),
2420                               "see definition of constraint here");
2421     }
2422     constraintType = inputs.front()->getType();
2423   } else {
2424     llvm_unreachable("unknown constraint type");
2425   }
2426 
2427   // Check that the constraint type is compatible with the current inferred
2428   // type.
2429   if (!inferredType) {
2430     inferredType = constraintType;
2431   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2432     inferredType = mergedTy;
2433   } else {
2434     return emitError(ref.referenceLoc,
2435                      llvm::formatv("constraint type `{0}` is incompatible "
2436                                    "with the previously inferred type `{1}`",
2437                                    constraintType, inferredType));
2438   }
2439   return success();
2440 }
2441 
2442 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2443   ast::Type typeExprType = typeExpr->getType();
2444   if (typeExprType != typeTy) {
2445     return emitError(typeExpr->getLoc(),
2446                      "expected expression of `Type` in type constraint");
2447   }
2448   return success();
2449 }
2450 
2451 LogicalResult
2452 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2453   ast::Type typeExprType = typeExpr->getType();
2454   if (typeExprType != typeRangeTy) {
2455     return emitError(typeExpr->getLoc(),
2456                      "expected expression of `TypeRange` in type constraint");
2457   }
2458   return success();
2459 }
2460 
2461 //===----------------------------------------------------------------------===//
2462 // Exprs
2463 
2464 FailureOr<ast::CallExpr *>
2465 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2466                        MutableArrayRef<ast::Expr *> arguments) {
2467   ast::Type parentType = parentExpr->getType();
2468 
2469   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2470   if (!callableDecl) {
2471     return emitError(loc,
2472                      llvm::formatv("expected a reference to a callable "
2473                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2474                                    parentType));
2475   }
2476   if (parserContext == ParserContext::Rewrite) {
2477     if (isa<ast::UserConstraintDecl>(callableDecl))
2478       return emitError(
2479           loc, "unable to invoke `Constraint` within a rewrite section");
2480   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2481     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2482   }
2483 
2484   // Verify the arguments of the call.
2485   /// Handle size mismatch.
2486   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2487   if (callArgs.size() != arguments.size()) {
2488     return emitErrorAndNote(
2489         loc,
2490         llvm::formatv("invalid number of arguments for {0} call; expected "
2491                       "{1}, but got {2}",
2492                       callableDecl->getCallableType(), callArgs.size(),
2493                       arguments.size()),
2494         callableDecl->getLoc(),
2495         llvm::formatv("see the definition of {0} here",
2496                       callableDecl->getName()->getName()));
2497   }
2498 
2499   /// Handle argument type mismatch.
2500   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2501     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2502                                   callableDecl->getName()->getName()),
2503                     callableDecl->getLoc());
2504   };
2505   for (auto it : llvm::zip(callArgs, arguments)) {
2506     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2507                                    attachDiagFn)))
2508       return failure();
2509   }
2510 
2511   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2512                                callableDecl->getResultType());
2513 }
2514 
2515 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2516                                                         ast::Decl *decl) {
2517   // Check the type of decl being referenced.
2518   ast::Type declType;
2519   if (isa<ast::ConstraintDecl>(decl))
2520     declType = ast::ConstraintType::get(ctx);
2521   else if (isa<ast::UserRewriteDecl>(decl))
2522     declType = ast::RewriteType::get(ctx);
2523   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2524     declType = varDecl->getType();
2525   else
2526     return emitError(loc, "invalid reference to `" +
2527                               decl->getName()->getName() + "`");
2528 
2529   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2530 }
2531 
2532 FailureOr<ast::DeclRefExpr *>
2533 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2534                                  ArrayRef<ast::ConstraintRef> constraints) {
2535   FailureOr<ast::VariableDecl *> decl =
2536       defineVariableDecl(name, loc, type, constraints);
2537   if (failed(decl))
2538     return failure();
2539   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2540 }
2541 
2542 FailureOr<ast::MemberAccessExpr *>
2543 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2544                                SMRange loc) {
2545   // Validate the member name for the given parent expression.
2546   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2547   if (failed(memberType))
2548     return failure();
2549 
2550   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2551 }
2552 
2553 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2554                                                   StringRef name, SMRange loc) {
2555   ast::Type parentType = parentExpr->getType();
2556   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2557     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2558       return valueRangeTy;
2559 
2560     // Verify member access based on the operation type.
2561     if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
2562       auto results = odsOp->getResults();
2563 
2564       // Handle indexed results.
2565       unsigned index = 0;
2566       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2567           index < results.size()) {
2568         return results[index].isVariadic() ? valueRangeTy : valueTy;
2569       }
2570 
2571       // Handle named results.
2572       const auto *it = llvm::find_if(results, [&](const auto &result) {
2573         return result.getName() == name;
2574       });
2575       if (it != results.end())
2576         return it->isVariadic() ? valueRangeTy : valueTy;
2577     }
2578 
2579   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2580     // Handle indexed results.
2581     unsigned index = 0;
2582     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2583         index < tupleType.size()) {
2584       return tupleType.getElementTypes()[index];
2585     }
2586 
2587     // Handle named results.
2588     auto elementNames = tupleType.getElementNames();
2589     const auto *it = llvm::find(elementNames, name);
2590     if (it != elementNames.end())
2591       return tupleType.getElementTypes()[it - elementNames.begin()];
2592   }
2593   return emitError(
2594       loc,
2595       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2596                     name, parentType));
2597 }
2598 
2599 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2600     SMRange loc, const ast::OpNameDecl *name,
2601     MutableArrayRef<ast::Expr *> operands,
2602     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2603     MutableArrayRef<ast::Expr *> results) {
2604   Optional<StringRef> opNameRef = name->getName();
2605   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2606 
2607   // Verify the inputs operands.
2608   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2609     return failure();
2610 
2611   // Verify the attribute list.
2612   for (ast::NamedAttributeDecl *attr : attributes) {
2613     // Check for an attribute type, or a type awaiting resolution.
2614     ast::Type attrType = attr->getValue()->getType();
2615     if (!attrType.isa<ast::AttributeType>()) {
2616       return emitError(
2617           attr->getValue()->getLoc(),
2618           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2619     }
2620   }
2621 
2622   // Verify the result types.
2623   if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2624     return failure();
2625 
2626   return ast::OperationExpr::create(ctx, loc, name, operands, results,
2627                                     attributes);
2628 }
2629 
2630 LogicalResult
2631 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2632                                   const ods::Operation *odsOp,
2633                                   MutableArrayRef<ast::Expr *> operands) {
2634   return validateOperationOperandsOrResults(
2635       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2636       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2637       valueRangeTy);
2638 }
2639 
2640 LogicalResult
2641 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2642                                  const ods::Operation *odsOp,
2643                                  MutableArrayRef<ast::Expr *> results) {
2644   return validateOperationOperandsOrResults(
2645       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2646       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2647 }
2648 
2649 LogicalResult Parser::validateOperationOperandsOrResults(
2650     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2651     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2652     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2653     ast::Type rangeTy) {
2654   // All operation types accept a single range parameter.
2655   if (values.size() == 1) {
2656     if (failed(convertExpressionTo(values[0], rangeTy)))
2657       return failure();
2658     return success();
2659   }
2660 
2661   /// If the operation has ODS information, we can more accurately verify the
2662   /// values.
2663   if (odsOpLoc) {
2664     if (odsValues.size() != values.size()) {
2665       return emitErrorAndNote(
2666           loc,
2667           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2668                         "{2}, but got {3}",
2669                         groupName, *name, odsValues.size(), values.size()),
2670           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2671     }
2672     auto diagFn = [&](ast::Diagnostic &diag) {
2673       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2674                       *odsOpLoc);
2675     };
2676     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2677       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2678       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2679         return failure();
2680     }
2681     return success();
2682   }
2683 
2684   // Otherwise, accept the value groups as they have been defined and just
2685   // ensure they are one of the expected types.
2686   for (ast::Expr *&valueExpr : values) {
2687     ast::Type valueExprType = valueExpr->getType();
2688 
2689     // Check if this is one of the expected types.
2690     if (valueExprType == rangeTy || valueExprType == singleTy)
2691       continue;
2692 
2693     // If the operand is an Operation, allow converting to a Value or
2694     // ValueRange. This situations arises quite often with nested operation
2695     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2696     if (singleTy == valueTy) {
2697       if (valueExprType.isa<ast::OperationType>()) {
2698         valueExpr = convertOpToValue(valueExpr);
2699         continue;
2700       }
2701     }
2702 
2703     return emitError(
2704         valueExpr->getLoc(),
2705         llvm::formatv(
2706             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2707             singleTy, rangeTy, valueExprType));
2708   }
2709   return success();
2710 }
2711 
2712 FailureOr<ast::TupleExpr *>
2713 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2714                         ArrayRef<StringRef> elementNames) {
2715   for (const ast::Expr *element : elements) {
2716     ast::Type eleTy = element->getType();
2717     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2718       return emitError(
2719           element->getLoc(),
2720           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2721     }
2722   }
2723   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2724 }
2725 
2726 //===----------------------------------------------------------------------===//
2727 // Stmts
2728 
2729 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2730                                                     ast::Expr *rootOp) {
2731   // Check that root is an Operation.
2732   ast::Type rootType = rootOp->getType();
2733   if (!rootType.isa<ast::OperationType>())
2734     return emitError(rootOp->getLoc(), "expected `Op` expression");
2735 
2736   return ast::EraseStmt::create(ctx, loc, rootOp);
2737 }
2738 
2739 FailureOr<ast::ReplaceStmt *>
2740 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2741                           MutableArrayRef<ast::Expr *> replValues) {
2742   // Check that root is an Operation.
2743   ast::Type rootType = rootOp->getType();
2744   if (!rootType.isa<ast::OperationType>()) {
2745     return emitError(
2746         rootOp->getLoc(),
2747         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2748   }
2749 
2750   // If there are multiple replacement values, we implicitly convert any Op
2751   // expressions to the value form.
2752   bool shouldConvertOpToValues = replValues.size() > 1;
2753   for (ast::Expr *&replExpr : replValues) {
2754     ast::Type replType = replExpr->getType();
2755 
2756     // Check that replExpr is an Operation, Value, or ValueRange.
2757     if (replType.isa<ast::OperationType>()) {
2758       if (shouldConvertOpToValues)
2759         replExpr = convertOpToValue(replExpr);
2760       continue;
2761     }
2762 
2763     if (replType != valueTy && replType != valueRangeTy) {
2764       return emitError(replExpr->getLoc(),
2765                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
2766                                      "expression, but got `{0}`",
2767                                      replType));
2768     }
2769   }
2770 
2771   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
2772 }
2773 
2774 FailureOr<ast::RewriteStmt *>
2775 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
2776                           ast::CompoundStmt *rewriteBody) {
2777   // Check that root is an Operation.
2778   ast::Type rootType = rootOp->getType();
2779   if (!rootType.isa<ast::OperationType>()) {
2780     return emitError(
2781         rootOp->getLoc(),
2782         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
2783   }
2784 
2785   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
2786 }
2787 
2788 //===----------------------------------------------------------------------===//
2789 // Parser
2790 //===----------------------------------------------------------------------===//
2791 
2792 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
2793                                                  llvm::SourceMgr &sourceMgr) {
2794   Parser parser(ctx, sourceMgr);
2795   return parser.parseModule();
2796 }
2797