1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/LogicalResult.h"
12 #include "mlir/Tools/PDLL/AST/Context.h"
13 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
14 #include "mlir/Tools/PDLL/AST/Nodes.h"
15 #include "mlir/Tools/PDLL/AST/Types.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/SaveAndRestore.h"
20 #include "llvm/Support/ScopedPrinter.h"
21 #include <string>
22 
23 using namespace mlir;
24 using namespace mlir::pdll;
25 
26 //===----------------------------------------------------------------------===//
27 // Parser
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 class Parser {
32 public:
33   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
34       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
35         curToken(lexer.lexToken()), curDeclScope(nullptr),
36         valueTy(ast::ValueType::get(ctx)),
37         valueRangeTy(ast::ValueRangeType::get(ctx)),
38         typeTy(ast::TypeType::get(ctx)),
39         typeRangeTy(ast::TypeRangeType::get(ctx)) {}
40 
41   /// Try to parse a new module. Returns nullptr in the case of failure.
42   FailureOr<ast::Module *> parseModule();
43 
44 private:
45   /// The current context of the parser. It allows for the parser to know a bit
46   /// about the construct it is nested within during parsing. This is used
47   /// specifically to provide additional verification during parsing, e.g. to
48   /// prevent using rewrites within a match context, matcher constraints within
49   /// a rewrite section, etc.
50   enum class ParserContext {
51     /// The parser is in the global context.
52     Global,
53     /// The parser is currently within the matcher portion of a Pattern, which
54     /// is allows a terminal operation rewrite statement but no other rewrite
55     /// transformations.
56     PatternMatch,
57     /// The parser is currently within a Rewrite, which disallows calls to
58     /// constraints, requires operation expressions to have names, etc.
59     Rewrite,
60   };
61 
62   //===--------------------------------------------------------------------===//
63   // Parsing
64   //===--------------------------------------------------------------------===//
65 
66   /// Push a new decl scope onto the lexer.
67   ast::DeclScope *pushDeclScope() {
68     ast::DeclScope *newScope =
69         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
70     return (curDeclScope = newScope);
71   }
72   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
73 
74   /// Pop the last decl scope from the lexer.
75   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
76 
77   /// Parse the body of an AST module.
78   LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
79 
80   /// Try to convert the given expression to `type`. Returns failure and emits
81   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
82   /// invoked to attach notes to the emitted error diagnostic. On success,
83   /// `expr` is updated to the expression used to convert to `type`.
84   LogicalResult convertExpressionTo(
85       ast::Expr *&expr, ast::Type type,
86       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
87 
88   /// Given an operation expression, convert it to a Value or ValueRange
89   /// typed expression.
90   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
91 
92   //===--------------------------------------------------------------------===//
93   // Directives
94 
95   LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
96   LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
97 
98   //===--------------------------------------------------------------------===//
99   // Decls
100 
101   /// This structure contains the set of pattern metadata that may be parsed.
102   struct ParsedPatternMetadata {
103     Optional<uint16_t> benefit;
104     bool hasBoundedRecursion = false;
105   };
106 
107   FailureOr<ast::Decl *> parseTopLevelDecl();
108   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
109   FailureOr<ast::CompoundStmt *>
110   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
111                   bool expectTerminalSemicolon = true);
112   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
113   FailureOr<ast::Decl *> parsePatternDecl();
114   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
115 
116   /// Check to see if a decl has already been defined with the given name, if
117   /// one has emit and error and return failure. Returns success otherwise.
118   LogicalResult checkDefineNamedDecl(const ast::Name &name);
119 
120   /// Try to define a variable decl with the given components, returns the
121   /// variable on success.
122   FailureOr<ast::VariableDecl *>
123   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
124                      ast::Expr *initExpr,
125                      ArrayRef<ast::ConstraintRef> constraints);
126   FailureOr<ast::VariableDecl *>
127   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
128                      ArrayRef<ast::ConstraintRef> constraints);
129 
130   /// Parse the constraint reference list for a variable decl.
131   LogicalResult parseVariableDeclConstraintList(
132       SmallVectorImpl<ast::ConstraintRef> &constraints);
133 
134   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
135   FailureOr<ast::Expr *> parseTypeConstraintExpr();
136 
137   /// Try to parse a single reference to a constraint. `typeConstraint` is the
138   /// location of a previously parsed type constraint for the entity that will
139   /// be constrained by the parsed constraint. `existingConstraints` are any
140   /// existing constraints that have already been parsed for the same entity
141   /// that will be constrained by this constraint.
142   FailureOr<ast::ConstraintRef>
143   parseConstraint(Optional<SMRange> &typeConstraint,
144                   ArrayRef<ast::ConstraintRef> existingConstraints);
145 
146   //===--------------------------------------------------------------------===//
147   // Exprs
148 
149   FailureOr<ast::Expr *> parseExpr();
150 
151   /// Identifier expressions.
152   FailureOr<ast::Expr *> parseAttributeExpr();
153   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
154   FailureOr<ast::Expr *> parseIdentifierExpr();
155   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
156   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
157   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
158   FailureOr<ast::Expr *> parseOperationExpr();
159   FailureOr<ast::Expr *> parseTupleExpr();
160   FailureOr<ast::Expr *> parseTypeExpr();
161   FailureOr<ast::Expr *> parseUnderscoreExpr();
162 
163   //===--------------------------------------------------------------------===//
164   // Stmts
165 
166   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
167   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
168   FailureOr<ast::EraseStmt *> parseEraseStmt();
169   FailureOr<ast::LetStmt *> parseLetStmt();
170   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
171   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
172 
173   //===--------------------------------------------------------------------===//
174   // Creation+Analysis
175   //===--------------------------------------------------------------------===//
176 
177   //===--------------------------------------------------------------------===//
178   // Decls
179 
180   /// Try to create a pattern decl with the given components, returning the
181   /// Pattern on success.
182   FailureOr<ast::PatternDecl *>
183   createPatternDecl(SMRange loc, const ast::Name *name,
184                     const ParsedPatternMetadata &metadata,
185                     ast::CompoundStmt *body);
186 
187   /// Try to create a variable decl with the given components, returning the
188   /// Variable on success.
189   FailureOr<ast::VariableDecl *>
190   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
191                      ArrayRef<ast::ConstraintRef> constraints);
192 
193   /// Validate the constraints used to constraint a variable decl.
194   /// `inferredType` is the type of the variable inferred by the constraints
195   /// within the list, and is updated to the most refined type as determined by
196   /// the constraints. Returns success if the constraint list is valid, failure
197   /// otherwise.
198   LogicalResult
199   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
200                               ast::Type &inferredType);
201   /// Validate a single reference to a constraint. `inferredType` contains the
202   /// currently inferred variabled type and is refined within the type defined
203   /// by the constraint. Returns success if the constraint is valid, failure
204   /// otherwise.
205   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
206                                            ast::Type &inferredType);
207   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
208   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
209 
210   //===--------------------------------------------------------------------===//
211   // Exprs
212 
213   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
214                                                   ast::Decl *decl);
215   FailureOr<ast::DeclRefExpr *>
216   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
217                            ArrayRef<ast::ConstraintRef> constraints);
218   FailureOr<ast::MemberAccessExpr *>
219   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
220                          SMRange loc);
221 
222   /// Validate the member access `name` into the given parent expression. On
223   /// success, this also returns the type of the member accessed.
224   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
225                                             StringRef name, SMRange loc);
226   FailureOr<ast::OperationExpr *>
227   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
228                       MutableArrayRef<ast::Expr *> operands,
229                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
230                       MutableArrayRef<ast::Expr *> results);
231   LogicalResult
232   validateOperationOperands(SMRange loc, Optional<StringRef> name,
233                             MutableArrayRef<ast::Expr *> operands);
234   LogicalResult validateOperationResults(SMRange loc,
235                                          Optional<StringRef> name,
236                                          MutableArrayRef<ast::Expr *> results);
237   LogicalResult
238   validateOperationOperandsOrResults(SMRange loc,
239                                      Optional<StringRef> name,
240                                      MutableArrayRef<ast::Expr *> values,
241                                      ast::Type singleTy, ast::Type rangeTy);
242   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
243                                               ArrayRef<ast::Expr *> elements,
244                                               ArrayRef<StringRef> elementNames);
245 
246   //===--------------------------------------------------------------------===//
247   // Stmts
248 
249   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
250                                               ast::Expr *rootOp);
251   FailureOr<ast::ReplaceStmt *>
252   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
253                     MutableArrayRef<ast::Expr *> replValues);
254   FailureOr<ast::RewriteStmt *>
255   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
256                     ast::CompoundStmt *rewriteBody);
257 
258   //===--------------------------------------------------------------------===//
259   // Lexer Utilities
260   //===--------------------------------------------------------------------===//
261 
262   /// If the current token has the specified kind, consume it and return true.
263   /// If not, return false.
264   bool consumeIf(Token::Kind kind) {
265     if (curToken.isNot(kind))
266       return false;
267     consumeToken(kind);
268     return true;
269   }
270 
271   /// Advance the current lexer onto the next token.
272   void consumeToken() {
273     assert(curToken.isNot(Token::eof, Token::error) &&
274            "shouldn't advance past EOF or errors");
275     curToken = lexer.lexToken();
276   }
277 
278   /// Advance the current lexer onto the next token, asserting what the expected
279   /// current token is. This is preferred to the above method because it leads
280   /// to more self-documenting code with better checking.
281   void consumeToken(Token::Kind kind) {
282     assert(curToken.is(kind) && "consumed an unexpected token");
283     consumeToken();
284   }
285 
286   /// Reset the lexer to the location at the given position.
287   void resetToken(SMRange tokLoc) {
288     lexer.resetPointer(tokLoc.Start.getPointer());
289     curToken = lexer.lexToken();
290   }
291 
292   /// Consume the specified token if present and return success. On failure,
293   /// output a diagnostic and return failure.
294   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
295     if (curToken.getKind() != kind)
296       return emitError(curToken.getLoc(), msg);
297     consumeToken();
298     return success();
299   }
300   LogicalResult emitError(SMRange loc, const Twine &msg) {
301     lexer.emitError(loc, msg);
302     return failure();
303   }
304   LogicalResult emitError(const Twine &msg) {
305     return emitError(curToken.getLoc(), msg);
306   }
307   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg,
308                                  SMRange noteLoc, const Twine &note) {
309     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
310     return failure();
311   }
312 
313   //===--------------------------------------------------------------------===//
314   // Fields
315   //===--------------------------------------------------------------------===//
316 
317   /// The owning AST context.
318   ast::Context &ctx;
319 
320   /// The lexer of this parser.
321   Lexer lexer;
322 
323   /// The current token within the lexer.
324   Token curToken;
325 
326   /// The most recently defined decl scope.
327   ast::DeclScope *curDeclScope;
328   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
329 
330   /// The current context of the parser.
331   ParserContext parserContext = ParserContext::Global;
332 
333   /// Cached types to simplify verification and expression creation.
334   ast::Type valueTy, valueRangeTy;
335   ast::Type typeTy, typeRangeTy;
336 };
337 } // namespace
338 
339 FailureOr<ast::Module *> Parser::parseModule() {
340   SMLoc moduleLoc = curToken.getStartLoc();
341   pushDeclScope();
342 
343   // Parse the top-level decls of the module.
344   SmallVector<ast::Decl *> decls;
345   if (failed(parseModuleBody(decls)))
346     return popDeclScope(), failure();
347 
348   popDeclScope();
349   return ast::Module::create(ctx, moduleLoc, decls);
350 }
351 
352 LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
353   while (curToken.isNot(Token::eof)) {
354     if (curToken.is(Token::directive)) {
355       if (failed(parseDirective(decls)))
356         return failure();
357       continue;
358     }
359 
360     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
361     if (failed(decl))
362       return failure();
363     decls.push_back(*decl);
364   }
365   return success();
366 }
367 
368 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
369   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
370                                                  valueRangeTy);
371 }
372 
373 LogicalResult Parser::convertExpressionTo(
374     ast::Expr *&expr, ast::Type type,
375     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
376   ast::Type exprType = expr->getType();
377   if (exprType == type)
378     return success();
379 
380   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
381     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
382         expr->getLoc(), llvm::formatv("unable to convert expression of type "
383                                       "`{0}` to the expected type of "
384                                       "`{1}`",
385                                       exprType, type));
386     if (noteAttachFn)
387       noteAttachFn(*diag);
388     return diag;
389   };
390 
391   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
392     // Two operation types are compatible if they have the same name, or if the
393     // expected type is more general.
394     if (auto opType = type.dyn_cast<ast::OperationType>()) {
395       if (opType.getName())
396         return emitConvertError();
397       return success();
398     }
399 
400     // An operation can always convert to a ValueRange.
401     if (type == valueRangeTy) {
402       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
403                                                      valueRangeTy);
404       return success();
405     }
406 
407     // Allow conversion to a single value by constraining the result range.
408     if (type == valueTy) {
409       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
410                                                      valueTy);
411       return success();
412     }
413     return emitConvertError();
414   }
415 
416   // FIXME: Decide how to allow/support converting a single result to multiple,
417   // and multiple to a single result. For now, we just allow Single->Range,
418   // but this isn't something really supported in the PDL dialect. We should
419   // figure out some way to support both.
420   if ((exprType == valueTy || exprType == valueRangeTy) &&
421       (type == valueTy || type == valueRangeTy))
422     return success();
423   if ((exprType == typeTy || exprType == typeRangeTy) &&
424       (type == typeTy || type == typeRangeTy))
425     return success();
426 
427   // Handle tuple types.
428   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
429     auto tupleType = type.dyn_cast<ast::TupleType>();
430     if (!tupleType || tupleType.size() != exprTupleType.size())
431       return emitConvertError();
432 
433     // Build a new tuple expression using each of the elements of the current
434     // tuple.
435     SmallVector<ast::Expr *> newExprs;
436     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
437       newExprs.push_back(ast::MemberAccessExpr::create(
438           ctx, expr->getLoc(), expr, llvm::to_string(i),
439           exprTupleType.getElementTypes()[i]));
440 
441       auto diagFn = [&](ast::Diagnostic &diag) {
442         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
443                                       i, exprTupleType));
444         if (noteAttachFn)
445           noteAttachFn(diag);
446       };
447       if (failed(convertExpressionTo(newExprs.back(),
448                                      tupleType.getElementTypes()[i], diagFn)))
449         return failure();
450     }
451     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
452                                   tupleType.getElementNames());
453     return success();
454   }
455 
456   return emitConvertError();
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // Directives
461 
462 LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
463   StringRef directive = curToken.getSpelling();
464   if (directive == "#include")
465     return parseInclude(decls);
466 
467   return emitError("unknown directive `" + directive + "`");
468 }
469 
470 LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
471   SMRange loc = curToken.getLoc();
472   consumeToken(Token::directive);
473 
474   // Parse the file being included.
475   if (!curToken.isString())
476     return emitError(loc,
477                      "expected string file name after `include` directive");
478   SMRange fileLoc = curToken.getLoc();
479   std::string filenameStr = curToken.getStringValue();
480   StringRef filename = filenameStr;
481   consumeToken();
482 
483   // Check the type of include. If ending with `.pdll`, this is another pdl file
484   // to be parsed along with the current module.
485   if (filename.endswith(".pdll")) {
486     if (failed(lexer.pushInclude(filename)))
487       return emitError(fileLoc,
488                        "unable to open include file `" + filename + "`");
489 
490     // If we added the include successfully, parse it into the current module.
491     // Make sure to save the current token so that we can restore it when we
492     // finish parsing the nested file.
493     Token oldToken = curToken;
494     curToken = lexer.lexToken();
495     LogicalResult result = parseModuleBody(decls);
496     curToken = oldToken;
497     return result;
498   }
499 
500   return emitError(fileLoc, "expected include filename to end with `.pdll`");
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // Decls
505 
506 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
507   FailureOr<ast::Decl *> decl;
508   switch (curToken.getKind()) {
509   case Token::kw_Pattern:
510     decl = parsePatternDecl();
511     break;
512   default:
513     return emitError("expected top-level declaration, such as a `Pattern`");
514   }
515   if (failed(decl))
516     return failure();
517 
518   // If the decl has a name, add it to the current scope.
519   if (const ast::Name *name = (*decl)->getName()) {
520     if (failed(checkDefineNamedDecl(*name)))
521       return failure();
522     curDeclScope->add(*decl);
523   }
524   return decl;
525 }
526 
527 FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
528   std::string attrNameStr;
529   if (curToken.isString())
530     attrNameStr = curToken.getStringValue();
531   else if (curToken.is(Token::identifier) || curToken.isKeyword())
532     attrNameStr = curToken.getSpelling().str();
533   else
534     return emitError("expected identifier or string attribute name");
535   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
536   consumeToken();
537 
538   // Check for a value of the attribute.
539   ast::Expr *attrValue = nullptr;
540   if (consumeIf(Token::equal)) {
541     FailureOr<ast::Expr *> attrExpr = parseExpr();
542     if (failed(attrExpr))
543       return failure();
544     attrValue = *attrExpr;
545   } else {
546     // If there isn't a concrete value, create an expression representing a
547     // UnitAttr.
548     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
549   }
550 
551   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
552 }
553 
554 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
555     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
556     bool expectTerminalSemicolon) {
557   consumeToken(Token::equal_arrow);
558 
559   // Parse the single statement of the lambda body.
560   SMLoc bodyStartLoc = curToken.getStartLoc();
561   pushDeclScope();
562   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
563   bool failedToParse =
564       failed(singleStatement) || failed(processStatementFn(*singleStatement));
565   popDeclScope();
566   if (failedToParse)
567     return failure();
568 
569   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
570   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
571 }
572 
573 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
574   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
575     if (isa<ast::OpRewriteStmt>(statement))
576       return success();
577     return emitError(
578         statement->getLoc(),
579         "expected Pattern lambda body to contain a single operation "
580         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
581   });
582 }
583 
584 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
585   SMRange loc = curToken.getLoc();
586   consumeToken(Token::kw_Pattern);
587   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
588                                               ParserContext::PatternMatch);
589 
590   // Check for an optional identifier for the pattern name.
591   const ast::Name *name = nullptr;
592   if (curToken.is(Token::identifier)) {
593     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
594     consumeToken(Token::identifier);
595   }
596 
597   // Parse any pattern metadata.
598   ParsedPatternMetadata metadata;
599   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
600     return failure();
601 
602   // Parse the pattern body.
603   ast::CompoundStmt *body;
604 
605   // Handle a lambda body.
606   if (curToken.is(Token::equal_arrow)) {
607     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
608     if (failed(bodyResult))
609       return failure();
610     body = *bodyResult;
611   } else {
612     if (curToken.isNot(Token::l_brace))
613       return emitError("expected `{` or `=>` to start pattern body");
614     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
615     if (failed(bodyResult))
616       return failure();
617     body = *bodyResult;
618 
619     // Verify the body of the pattern.
620     auto bodyIt = body->begin(), bodyE = body->end();
621     for (; bodyIt != bodyE; ++bodyIt) {
622       // Break when we've found the rewrite statement.
623       if (isa<ast::OpRewriteStmt>(*bodyIt))
624         break;
625     }
626     if (bodyIt == bodyE) {
627       return emitError(loc,
628                        "expected Pattern body to terminate with an operation "
629                        "rewrite statement, such as `erase`");
630     }
631     if (std::next(bodyIt) != bodyE) {
632       return emitError((*std::next(bodyIt))->getLoc(),
633                        "Pattern body was terminated by an operation "
634                        "rewrite statement, but found trailing statements");
635     }
636   }
637 
638   return createPatternDecl(loc, name, metadata, body);
639 }
640 
641 LogicalResult
642 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
643   Optional<SMRange> benefitLoc;
644   Optional<SMRange> hasBoundedRecursionLoc;
645 
646   do {
647     if (curToken.isNot(Token::identifier))
648       return emitError("expected pattern metadata identifier");
649     StringRef metadataStr = curToken.getSpelling();
650     SMRange metadataLoc = curToken.getLoc();
651     consumeToken(Token::identifier);
652 
653     // Parse the benefit metadata: benefit(<integer-value>)
654     if (metadataStr == "benefit") {
655       if (benefitLoc) {
656         return emitErrorAndNote(metadataLoc,
657                                 "pattern benefit has already been specified",
658                                 *benefitLoc, "see previous definition here");
659       }
660       if (failed(parseToken(Token::l_paren,
661                             "expected `(` before pattern benefit")))
662         return failure();
663 
664       uint16_t benefitValue = 0;
665       if (curToken.isNot(Token::integer))
666         return emitError("expected integral pattern benefit");
667       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
668         return emitError(
669             "expected pattern benefit to fit within a 16-bit integer");
670       consumeToken(Token::integer);
671 
672       metadata.benefit = benefitValue;
673       benefitLoc = metadataLoc;
674 
675       if (failed(
676               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
677         return failure();
678       continue;
679     }
680 
681     // Parse the bounded recursion metadata: recursion
682     if (metadataStr == "recursion") {
683       if (hasBoundedRecursionLoc) {
684         return emitErrorAndNote(
685             metadataLoc,
686             "pattern recursion metadata has already been specified",
687             *hasBoundedRecursionLoc, "see previous definition here");
688       }
689       metadata.hasBoundedRecursion = true;
690       hasBoundedRecursionLoc = metadataLoc;
691       continue;
692     }
693 
694     return emitError(metadataLoc, "unknown pattern metadata");
695   } while (consumeIf(Token::comma));
696 
697   return success();
698 }
699 
700 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
701   consumeToken(Token::less);
702 
703   FailureOr<ast::Expr *> typeExpr = parseExpr();
704   if (failed(typeExpr) ||
705       failed(parseToken(Token::greater,
706                         "expected `>` after variable type constraint")))
707     return failure();
708   return typeExpr;
709 }
710 
711 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
712   assert(curDeclScope && "defining decl outside of a decl scope");
713   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
714     return emitErrorAndNote(
715         name.getLoc(), "`" + name.getName() + "` has already been defined",
716         lastDecl->getName()->getLoc(), "see previous definition here");
717   }
718   return success();
719 }
720 
721 FailureOr<ast::VariableDecl *>
722 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
723                            ast::Type type, ast::Expr *initExpr,
724                            ArrayRef<ast::ConstraintRef> constraints) {
725   assert(curDeclScope && "defining variable outside of decl scope");
726   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
727 
728   // If the name of the variable indicates a special variable, we don't add it
729   // to the scope. This variable is local to the definition point.
730   if (name.empty() || name == "_") {
731     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
732                                      constraints);
733   }
734   if (failed(checkDefineNamedDecl(nameDecl)))
735     return failure();
736 
737   auto *varDecl =
738       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
739   curDeclScope->add(varDecl);
740   return varDecl;
741 }
742 
743 FailureOr<ast::VariableDecl *>
744 Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
745                            ast::Type type,
746                            ArrayRef<ast::ConstraintRef> constraints) {
747   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
748                             constraints);
749 }
750 
751 LogicalResult Parser::parseVariableDeclConstraintList(
752     SmallVectorImpl<ast::ConstraintRef> &constraints) {
753   Optional<SMRange> typeConstraint;
754   auto parseSingleConstraint = [&] {
755     FailureOr<ast::ConstraintRef> constraint =
756         parseConstraint(typeConstraint, constraints);
757     if (failed(constraint))
758       return failure();
759     constraints.push_back(*constraint);
760     return success();
761   };
762 
763   // Check to see if this is a single constraint, or a list.
764   if (!consumeIf(Token::l_square))
765     return parseSingleConstraint();
766 
767   do {
768     if (failed(parseSingleConstraint()))
769       return failure();
770   } while (consumeIf(Token::comma));
771   return parseToken(Token::r_square, "expected `]` after constraint list");
772 }
773 
774 FailureOr<ast::ConstraintRef>
775 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
776                         ArrayRef<ast::ConstraintRef> existingConstraints) {
777   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
778     if (typeConstraint)
779       return emitErrorAndNote(
780           curToken.getLoc(),
781           "the type of this variable has already been constrained",
782           *typeConstraint, "see previous constraint location here");
783     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
784     if (failed(constraintExpr))
785       return failure();
786     typeExpr = *constraintExpr;
787     typeConstraint = typeExpr->getLoc();
788     return success();
789   };
790 
791   SMRange loc = curToken.getLoc();
792   switch (curToken.getKind()) {
793   case Token::kw_Attr: {
794     consumeToken(Token::kw_Attr);
795 
796     // Check for a type constraint.
797     ast::Expr *typeExpr = nullptr;
798     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
799       return failure();
800     return ast::ConstraintRef(
801         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
802   }
803   case Token::kw_Op: {
804     consumeToken(Token::kw_Op);
805 
806     // Parse an optional operation name. If the name isn't provided, this refers
807     // to "any" operation.
808     FailureOr<ast::OpNameDecl *> opName =
809         parseWrappedOperationName(/*allowEmptyName=*/true);
810     if (failed(opName))
811       return failure();
812 
813     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
814                               loc);
815   }
816   case Token::kw_Type:
817     consumeToken(Token::kw_Type);
818     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
819   case Token::kw_TypeRange:
820     consumeToken(Token::kw_TypeRange);
821     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
822                               loc);
823   case Token::kw_Value: {
824     consumeToken(Token::kw_Value);
825 
826     // Check for a type constraint.
827     ast::Expr *typeExpr = nullptr;
828     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
829       return failure();
830 
831     return ast::ConstraintRef(
832         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
833   }
834   case Token::kw_ValueRange: {
835     consumeToken(Token::kw_ValueRange);
836 
837     // Check for a type constraint.
838     ast::Expr *typeExpr = nullptr;
839     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
840       return failure();
841 
842     return ast::ConstraintRef(
843         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
844   }
845   case Token::identifier: {
846     StringRef constraintName = curToken.getSpelling();
847     consumeToken(Token::identifier);
848 
849     // Lookup the referenced constraint.
850     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
851     if (!cstDecl) {
852       return emitError(loc, "unknown reference to constraint `" +
853                                 constraintName + "`");
854     }
855 
856     // Handle a reference to a proper constraint.
857     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
858       return ast::ConstraintRef(cst, loc);
859 
860     return emitErrorAndNote(
861         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
862         "see the definition of `" + constraintName + "` here");
863   }
864   default:
865     break;
866   }
867   return emitError(loc, "expected identifier constraint");
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // Exprs
872 
873 FailureOr<ast::Expr *> Parser::parseExpr() {
874   if (curToken.is(Token::underscore))
875     return parseUnderscoreExpr();
876 
877   // Parse the LHS expression.
878   FailureOr<ast::Expr *> lhsExpr;
879   switch (curToken.getKind()) {
880   case Token::kw_attr:
881     lhsExpr = parseAttributeExpr();
882     break;
883   case Token::identifier:
884     lhsExpr = parseIdentifierExpr();
885     break;
886   case Token::kw_op:
887     lhsExpr = parseOperationExpr();
888     break;
889   case Token::kw_type:
890     lhsExpr = parseTypeExpr();
891     break;
892   case Token::l_paren:
893     lhsExpr = parseTupleExpr();
894     break;
895   default:
896     return emitError("expected expression");
897   }
898   if (failed(lhsExpr))
899     return failure();
900 
901   // Check for an operator expression.
902   while (true) {
903     switch (curToken.getKind()) {
904     case Token::dot:
905       lhsExpr = parseMemberAccessExpr(*lhsExpr);
906       break;
907     default:
908       return lhsExpr;
909     }
910     if (failed(lhsExpr))
911       return failure();
912   }
913 }
914 
915 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
916   SMRange loc = curToken.getLoc();
917   consumeToken(Token::kw_attr);
918 
919   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
920   // identifier.
921   if (!consumeIf(Token::less)) {
922     resetToken(loc);
923     return parseIdentifierExpr();
924   }
925 
926   if (!curToken.isString())
927     return emitError("expected string literal containing MLIR attribute");
928   std::string attrExpr = curToken.getStringValue();
929   consumeToken();
930 
931   if (failed(
932           parseToken(Token::greater, "expected `>` after attribute literal")))
933     return failure();
934   return ast::AttributeExpr::create(ctx, loc, attrExpr);
935 }
936 
937 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
938                                                 SMRange loc) {
939   ast::Decl *decl = curDeclScope->lookup(name);
940   if (!decl)
941     return emitError(loc, "undefined reference to `" + name + "`");
942 
943   return createDeclRefExpr(loc, decl);
944 }
945 
946 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
947   StringRef name = curToken.getSpelling();
948   SMRange nameLoc = curToken.getLoc();
949   consumeToken();
950 
951   // Check to see if this is a decl ref expression that defines a variable
952   // inline.
953   if (consumeIf(Token::colon)) {
954     SmallVector<ast::ConstraintRef> constraints;
955     if (failed(parseVariableDeclConstraintList(constraints)))
956       return failure();
957     ast::Type type;
958     if (failed(validateVariableConstraints(constraints, type)))
959       return failure();
960     return createInlineVariableExpr(type, name, nameLoc, constraints);
961   }
962 
963   return parseDeclRefExpr(name, nameLoc);
964 }
965 
966 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
967   SMRange loc = curToken.getLoc();
968   consumeToken(Token::dot);
969 
970   // Parse the member name.
971   Token memberNameTok = curToken;
972   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
973       !memberNameTok.isKeyword())
974     return emitError(loc, "expected identifier or numeric member name");
975   StringRef memberName = memberNameTok.getSpelling();
976   consumeToken();
977 
978   return createMemberAccessExpr(parentExpr, memberName, loc);
979 }
980 
981 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
982   SMRange loc = curToken.getLoc();
983 
984   // Handle the case of an no operation name.
985   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
986     if (allowEmptyName)
987       return ast::OpNameDecl::create(ctx, SMRange());
988     return emitError("expected dialect namespace");
989   }
990   StringRef name = curToken.getSpelling();
991   consumeToken();
992 
993   // Otherwise, this is a literal operation name.
994   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
995     return failure();
996 
997   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
998     return emitError("expected operation name after dialect namespace");
999 
1000   name = StringRef(name.data(), name.size() + 1);
1001   do {
1002     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1003     loc.End = curToken.getEndLoc();
1004     consumeToken();
1005   } while (curToken.isAny(Token::identifier, Token::dot) ||
1006            curToken.isKeyword());
1007   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1008 }
1009 
1010 FailureOr<ast::OpNameDecl *>
1011 Parser::parseWrappedOperationName(bool allowEmptyName) {
1012   if (!consumeIf(Token::less))
1013     return ast::OpNameDecl::create(ctx, SMRange());
1014 
1015   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1016   if (failed(opNameDecl))
1017     return failure();
1018 
1019   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1020     return failure();
1021   return opNameDecl;
1022 }
1023 
1024 FailureOr<ast::Expr *> Parser::parseOperationExpr() {
1025   SMRange loc = curToken.getLoc();
1026   consumeToken(Token::kw_op);
1027 
1028   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1029   // identifier.
1030   if (curToken.isNot(Token::less)) {
1031     resetToken(loc);
1032     return parseIdentifierExpr();
1033   }
1034 
1035   // Parse the operation name. The name may be elided, in which case the
1036   // operation refers to "any" operation(i.e. a difference between `MyOp` and
1037   // `Operation*`). Operation names within a rewrite context must be named.
1038   bool allowEmptyName = parserContext != ParserContext::Rewrite;
1039   FailureOr<ast::OpNameDecl *> opNameDecl =
1040       parseWrappedOperationName(allowEmptyName);
1041   if (failed(opNameDecl))
1042     return failure();
1043 
1044   // Check for the optional list of operands.
1045   SmallVector<ast::Expr *> operands;
1046   if (consumeIf(Token::l_paren)) {
1047     do {
1048       FailureOr<ast::Expr *> operand = parseExpr();
1049       if (failed(operand))
1050         return failure();
1051       operands.push_back(*operand);
1052     } while (consumeIf(Token::comma));
1053 
1054     if (failed(parseToken(Token::r_paren,
1055                           "expected `)` after operation operand list")))
1056       return failure();
1057   }
1058 
1059   // Check for the optional list of attributes.
1060   SmallVector<ast::NamedAttributeDecl *> attributes;
1061   if (consumeIf(Token::l_brace)) {
1062     do {
1063       FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
1064       if (failed(decl))
1065         return failure();
1066       attributes.emplace_back(*decl);
1067     } while (consumeIf(Token::comma));
1068 
1069     if (failed(parseToken(Token::r_brace,
1070                           "expected `}` after operation attribute list")))
1071       return failure();
1072   }
1073 
1074   // Check for the optional list of result types.
1075   SmallVector<ast::Expr *> resultTypes;
1076   if (consumeIf(Token::arrow)) {
1077     if (failed(parseToken(Token::l_paren,
1078                           "expected `(` before operation result type list")))
1079       return failure();
1080 
1081     do {
1082       FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
1083       if (failed(resultTypeExpr))
1084         return failure();
1085       resultTypes.push_back(*resultTypeExpr);
1086     } while (consumeIf(Token::comma));
1087 
1088     if (failed(parseToken(Token::r_paren,
1089                           "expected `)` after operation result type list")))
1090       return failure();
1091   }
1092 
1093   return createOperationExpr(loc, *opNameDecl, operands, attributes,
1094                              resultTypes);
1095 }
1096 
1097 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
1098   SMRange loc = curToken.getLoc();
1099   consumeToken(Token::l_paren);
1100 
1101   DenseMap<StringRef, SMRange> usedNames;
1102   SmallVector<StringRef> elementNames;
1103   SmallVector<ast::Expr *> elements;
1104   if (curToken.isNot(Token::r_paren)) {
1105     do {
1106       // Check for the optional element name assignment before the value.
1107       StringRef elementName;
1108       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1109         Token elementNameTok = curToken;
1110         consumeToken();
1111 
1112         // The element name is only present if followed by an `=`.
1113         if (consumeIf(Token::equal)) {
1114           elementName = elementNameTok.getSpelling();
1115 
1116           // Check to see if this name is already used.
1117           auto elementNameIt =
1118               usedNames.try_emplace(elementName, elementNameTok.getLoc());
1119           if (!elementNameIt.second) {
1120             return emitErrorAndNote(
1121                 elementNameTok.getLoc(),
1122                 llvm::formatv("duplicate tuple element label `{0}`",
1123                               elementName),
1124                 elementNameIt.first->getSecond(),
1125                 "see previous label use here");
1126           }
1127         } else {
1128           // Otherwise, we treat this as part of an expression so reset the
1129           // lexer.
1130           resetToken(elementNameTok.getLoc());
1131         }
1132       }
1133       elementNames.push_back(elementName);
1134 
1135       // Parse the tuple element value.
1136       FailureOr<ast::Expr *> element = parseExpr();
1137       if (failed(element))
1138         return failure();
1139       elements.push_back(*element);
1140     } while (consumeIf(Token::comma));
1141   }
1142   loc.End = curToken.getEndLoc();
1143   if (failed(
1144           parseToken(Token::r_paren, "expected `)` after tuple element list")))
1145     return failure();
1146   return createTupleExpr(loc, elements, elementNames);
1147 }
1148 
1149 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
1150   SMRange loc = curToken.getLoc();
1151   consumeToken(Token::kw_type);
1152 
1153   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
1154   // identifier.
1155   if (!consumeIf(Token::less)) {
1156     resetToken(loc);
1157     return parseIdentifierExpr();
1158   }
1159 
1160   if (!curToken.isString())
1161     return emitError("expected string literal containing MLIR type");
1162   std::string attrExpr = curToken.getStringValue();
1163   consumeToken();
1164 
1165   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
1166     return failure();
1167   return ast::TypeExpr::create(ctx, loc, attrExpr);
1168 }
1169 
1170 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
1171   StringRef name = curToken.getSpelling();
1172   SMRange nameLoc = curToken.getLoc();
1173   consumeToken(Token::underscore);
1174 
1175   // Underscore expressions require a constraint list.
1176   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
1177     return failure();
1178 
1179   // Parse the constraints for the expression.
1180   SmallVector<ast::ConstraintRef> constraints;
1181   if (failed(parseVariableDeclConstraintList(constraints)))
1182     return failure();
1183 
1184   ast::Type type;
1185   if (failed(validateVariableConstraints(constraints, type)))
1186     return failure();
1187   return createInlineVariableExpr(type, name, nameLoc, constraints);
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // Stmts
1192 
1193 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
1194   FailureOr<ast::Stmt *> stmt;
1195   switch (curToken.getKind()) {
1196   case Token::kw_erase:
1197     stmt = parseEraseStmt();
1198     break;
1199   case Token::kw_let:
1200     stmt = parseLetStmt();
1201     break;
1202   case Token::kw_replace:
1203     stmt = parseReplaceStmt();
1204     break;
1205   case Token::kw_rewrite:
1206     stmt = parseRewriteStmt();
1207     break;
1208   default:
1209     stmt = parseExpr();
1210     break;
1211   }
1212   if (failed(stmt) ||
1213       (expectTerminalSemicolon &&
1214        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
1215     return failure();
1216   return stmt;
1217 }
1218 
1219 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
1220   SMLoc startLoc = curToken.getStartLoc();
1221   consumeToken(Token::l_brace);
1222 
1223   // Push a new block scope and parse any nested statements.
1224   pushDeclScope();
1225   SmallVector<ast::Stmt *> statements;
1226   while (curToken.isNot(Token::r_brace)) {
1227     FailureOr<ast::Stmt *> statement = parseStmt();
1228     if (failed(statement))
1229       return popDeclScope(), failure();
1230     statements.push_back(*statement);
1231   }
1232   popDeclScope();
1233 
1234   // Consume the end brace.
1235   SMRange location(startLoc, curToken.getEndLoc());
1236   consumeToken(Token::r_brace);
1237 
1238   return ast::CompoundStmt::create(ctx, location, statements);
1239 }
1240 
1241 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
1242   SMRange loc = curToken.getLoc();
1243   consumeToken(Token::kw_erase);
1244 
1245   // Parse the root operation expression.
1246   FailureOr<ast::Expr *> rootOp = parseExpr();
1247   if (failed(rootOp))
1248     return failure();
1249 
1250   return createEraseStmt(loc, *rootOp);
1251 }
1252 
1253 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
1254   SMRange loc = curToken.getLoc();
1255   consumeToken(Token::kw_let);
1256 
1257   // Parse the name of the new variable.
1258   SMRange varLoc = curToken.getLoc();
1259   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
1260     // `_` is a reserved variable name.
1261     if (curToken.is(Token::underscore)) {
1262       return emitError(varLoc,
1263                        "`_` may only be used to define \"inline\" variables");
1264     }
1265     return emitError(varLoc,
1266                      "expected identifier after `let` to name a new variable");
1267   }
1268   StringRef varName = curToken.getSpelling();
1269   consumeToken();
1270 
1271   // Parse the optional set of constraints.
1272   SmallVector<ast::ConstraintRef> constraints;
1273   if (consumeIf(Token::colon) &&
1274       failed(parseVariableDeclConstraintList(constraints)))
1275     return failure();
1276 
1277   // Parse the optional initializer expression.
1278   ast::Expr *initializer = nullptr;
1279   if (consumeIf(Token::equal)) {
1280     FailureOr<ast::Expr *> initOrFailure = parseExpr();
1281     if (failed(initOrFailure))
1282       return failure();
1283     initializer = *initOrFailure;
1284 
1285     // Check that the constraints are compatible with having an initializer,
1286     // e.g. type constraints cannot be used with initializers.
1287     for (ast::ConstraintRef constraint : constraints) {
1288       LogicalResult result =
1289           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
1290               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
1291                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
1292                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
1293                   return this->emitError(
1294                       constraint.referenceLoc,
1295                       "type constraints are not permitted on variables with "
1296                       "initializers");
1297                 }
1298                 return success();
1299               })
1300               .Default(success());
1301       if (failed(result))
1302         return failure();
1303     }
1304   }
1305 
1306   FailureOr<ast::VariableDecl *> varDecl =
1307       createVariableDecl(varName, varLoc, initializer, constraints);
1308   if (failed(varDecl))
1309     return failure();
1310   return ast::LetStmt::create(ctx, loc, *varDecl);
1311 }
1312 
1313 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
1314   SMRange loc = curToken.getLoc();
1315   consumeToken(Token::kw_replace);
1316 
1317   // Parse the root operation expression.
1318   FailureOr<ast::Expr *> rootOp = parseExpr();
1319   if (failed(rootOp))
1320     return failure();
1321 
1322   if (failed(
1323           parseToken(Token::kw_with, "expected `with` after root operation")))
1324     return failure();
1325 
1326   // The replacement portion of this statement is within a rewrite context.
1327   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1328                                               ParserContext::Rewrite);
1329 
1330   // Parse the replacement values.
1331   SmallVector<ast::Expr *> replValues;
1332   if (consumeIf(Token::l_paren)) {
1333     if (consumeIf(Token::r_paren)) {
1334       return emitError(
1335           loc, "expected at least one replacement value, consider using "
1336                "`erase` if no replacement values are desired");
1337     }
1338 
1339     do {
1340       FailureOr<ast::Expr *> replExpr = parseExpr();
1341       if (failed(replExpr))
1342         return failure();
1343       replValues.emplace_back(*replExpr);
1344     } while (consumeIf(Token::comma));
1345 
1346     if (failed(parseToken(Token::r_paren,
1347                           "expected `)` after replacement values")))
1348       return failure();
1349   } else {
1350     FailureOr<ast::Expr *> replExpr = parseExpr();
1351     if (failed(replExpr))
1352       return failure();
1353     replValues.emplace_back(*replExpr);
1354   }
1355 
1356   return createReplaceStmt(loc, *rootOp, replValues);
1357 }
1358 
1359 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
1360   SMRange loc = curToken.getLoc();
1361   consumeToken(Token::kw_rewrite);
1362 
1363   // Parse the root operation.
1364   FailureOr<ast::Expr *> rootOp = parseExpr();
1365   if (failed(rootOp))
1366     return failure();
1367 
1368   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
1369     return failure();
1370 
1371   if (curToken.isNot(Token::l_brace))
1372     return emitError("expected `{` to start rewrite body");
1373 
1374   // The rewrite body of this statement is within a rewrite context.
1375   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1376                                               ParserContext::Rewrite);
1377 
1378   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
1379   if (failed(rewriteBody))
1380     return failure();
1381 
1382   return createRewriteStmt(loc, *rootOp, *rewriteBody);
1383 }
1384 
1385 //===----------------------------------------------------------------------===//
1386 // Creation+Analysis
1387 //===----------------------------------------------------------------------===//
1388 
1389 //===----------------------------------------------------------------------===//
1390 // Decls
1391 
1392 FailureOr<ast::PatternDecl *>
1393 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
1394                           const ParsedPatternMetadata &metadata,
1395                           ast::CompoundStmt *body) {
1396   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
1397                                   metadata.hasBoundedRecursion, body);
1398 }
1399 
1400 FailureOr<ast::VariableDecl *>
1401 Parser::createVariableDecl(StringRef name, SMRange loc,
1402                            ast::Expr *initializer,
1403                            ArrayRef<ast::ConstraintRef> constraints) {
1404   // The type of the variable, which is expected to be inferred by either a
1405   // constraint or an initializer expression.
1406   ast::Type type;
1407   if (failed(validateVariableConstraints(constraints, type)))
1408     return failure();
1409 
1410   if (initializer) {
1411     // Update the variable type based on the initializer, or try to convert the
1412     // initializer to the existing type.
1413     if (!type)
1414       type = initializer->getType();
1415     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
1416       type = mergedType;
1417     else if (failed(convertExpressionTo(initializer, type)))
1418       return failure();
1419 
1420     // Otherwise, if there is no initializer check that the type has already
1421     // been resolved from the constraint list.
1422   } else if (!type) {
1423     return emitErrorAndNote(
1424         loc, "unable to infer type for variable `" + name + "`", loc,
1425         "the type of a variable must be inferable from the constraint "
1426         "list or the initializer");
1427   }
1428 
1429   // Try to define a variable with the given name.
1430   FailureOr<ast::VariableDecl *> varDecl =
1431       defineVariableDecl(name, loc, type, initializer, constraints);
1432   if (failed(varDecl))
1433     return failure();
1434 
1435   return *varDecl;
1436 }
1437 
1438 LogicalResult
1439 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
1440                                     ast::Type &inferredType) {
1441   for (const ast::ConstraintRef &ref : constraints)
1442     if (failed(validateVariableConstraint(ref, inferredType)))
1443       return failure();
1444   return success();
1445 }
1446 
1447 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
1448                                                  ast::Type &inferredType) {
1449   ast::Type constraintType;
1450   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
1451     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1452       if (failed(validateTypeConstraintExpr(typeExpr)))
1453         return failure();
1454     }
1455     constraintType = ast::AttributeType::get(ctx);
1456   } else if (const auto *cst =
1457                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
1458     constraintType = ast::OperationType::get(ctx, cst->getName());
1459   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
1460     constraintType = typeTy;
1461   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
1462     constraintType = typeRangeTy;
1463   } else if (const auto *cst =
1464                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
1465     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1466       if (failed(validateTypeConstraintExpr(typeExpr)))
1467         return failure();
1468     }
1469     constraintType = valueTy;
1470   } else if (const auto *cst =
1471                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
1472     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
1473       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
1474         return failure();
1475     }
1476     constraintType = valueRangeTy;
1477   } else {
1478     llvm_unreachable("unknown constraint type");
1479   }
1480 
1481   // Check that the constraint type is compatible with the current inferred
1482   // type.
1483   if (!inferredType) {
1484     inferredType = constraintType;
1485   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
1486     inferredType = mergedTy;
1487   } else {
1488     return emitError(ref.referenceLoc,
1489                      llvm::formatv("constraint type `{0}` is incompatible "
1490                                    "with the previously inferred type `{1}`",
1491                                    constraintType, inferredType));
1492   }
1493   return success();
1494 }
1495 
1496 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
1497   ast::Type typeExprType = typeExpr->getType();
1498   if (typeExprType != typeTy) {
1499     return emitError(typeExpr->getLoc(),
1500                      "expected expression of `Type` in type constraint");
1501   }
1502   return success();
1503 }
1504 
1505 LogicalResult
1506 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
1507   ast::Type typeExprType = typeExpr->getType();
1508   if (typeExprType != typeRangeTy) {
1509     return emitError(typeExpr->getLoc(),
1510                      "expected expression of `TypeRange` in type constraint");
1511   }
1512   return success();
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // Exprs
1517 
1518 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
1519                                                         ast::Decl *decl) {
1520   // Check the type of decl being referenced.
1521   ast::Type declType;
1522   if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
1523     declType = varDecl->getType();
1524   else
1525     return emitError(loc, "invalid reference to `" +
1526                               decl->getName()->getName() + "`");
1527 
1528   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
1529 }
1530 
1531 FailureOr<ast::DeclRefExpr *>
1532 Parser::createInlineVariableExpr(ast::Type type, StringRef name,
1533                                  SMRange loc,
1534                                  ArrayRef<ast::ConstraintRef> constraints) {
1535   FailureOr<ast::VariableDecl *> decl =
1536       defineVariableDecl(name, loc, type, constraints);
1537   if (failed(decl))
1538     return failure();
1539   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
1540 }
1541 
1542 FailureOr<ast::MemberAccessExpr *>
1543 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
1544                                SMRange loc) {
1545   // Validate the member name for the given parent expression.
1546   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
1547   if (failed(memberType))
1548     return failure();
1549 
1550   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
1551 }
1552 
1553 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
1554                                                   StringRef name,
1555                                                   SMRange loc) {
1556   ast::Type parentType = parentExpr->getType();
1557   if (parentType.isa<ast::OperationType>()) {
1558     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
1559       return valueRangeTy;
1560   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
1561     // Handle indexed results.
1562     unsigned index = 0;
1563     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
1564         index < tupleType.size()) {
1565       return tupleType.getElementTypes()[index];
1566     }
1567 
1568     // Handle named results.
1569     auto elementNames = tupleType.getElementNames();
1570     const auto *it = llvm::find(elementNames, name);
1571     if (it != elementNames.end())
1572       return tupleType.getElementTypes()[it - elementNames.begin()];
1573   }
1574   return emitError(
1575       loc,
1576       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
1577                     name, parentType));
1578 }
1579 
1580 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
1581     SMRange loc, const ast::OpNameDecl *name,
1582     MutableArrayRef<ast::Expr *> operands,
1583     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
1584     MutableArrayRef<ast::Expr *> results) {
1585   Optional<StringRef> opNameRef = name->getName();
1586 
1587   // Verify the inputs operands.
1588   if (failed(validateOperationOperands(loc, opNameRef, operands)))
1589     return failure();
1590 
1591   // Verify the attribute list.
1592   for (ast::NamedAttributeDecl *attr : attributes) {
1593     // Check for an attribute type, or a type awaiting resolution.
1594     ast::Type attrType = attr->getValue()->getType();
1595     if (!attrType.isa<ast::AttributeType>()) {
1596       return emitError(
1597           attr->getValue()->getLoc(),
1598           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
1599     }
1600   }
1601 
1602   // Verify the result types.
1603   if (failed(validateOperationResults(loc, opNameRef, results)))
1604     return failure();
1605 
1606   return ast::OperationExpr::create(ctx, loc, name, operands, results,
1607                                     attributes);
1608 }
1609 
1610 LogicalResult
1611 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
1612                                   MutableArrayRef<ast::Expr *> operands) {
1613   return validateOperationOperandsOrResults(loc, name, operands, valueTy,
1614                                             valueRangeTy);
1615 }
1616 
1617 LogicalResult
1618 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
1619                                  MutableArrayRef<ast::Expr *> results) {
1620   return validateOperationOperandsOrResults(loc, name, results, typeTy,
1621                                             typeRangeTy);
1622 }
1623 
1624 LogicalResult Parser::validateOperationOperandsOrResults(
1625     SMRange loc, Optional<StringRef> name,
1626     MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
1627     ast::Type rangeTy) {
1628   // All operation types accept a single range parameter.
1629   if (values.size() == 1) {
1630     if (failed(convertExpressionTo(values[0], rangeTy)))
1631       return failure();
1632     return success();
1633   }
1634 
1635   // Otherwise, accept the value groups as they have been defined and just
1636   // ensure they are one of the expected types.
1637   for (ast::Expr *&valueExpr : values) {
1638     ast::Type valueExprType = valueExpr->getType();
1639 
1640     // Check if this is one of the expected types.
1641     if (valueExprType == rangeTy || valueExprType == singleTy)
1642       continue;
1643 
1644     // If the operand is an Operation, allow converting to a Value or
1645     // ValueRange. This situations arises quite often with nested operation
1646     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
1647     if (singleTy == valueTy) {
1648       if (valueExprType.isa<ast::OperationType>()) {
1649         valueExpr = convertOpToValue(valueExpr);
1650         continue;
1651       }
1652     }
1653 
1654     return emitError(
1655         valueExpr->getLoc(),
1656         llvm::formatv(
1657             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
1658             singleTy, rangeTy, valueExprType));
1659   }
1660   return success();
1661 }
1662 
1663 FailureOr<ast::TupleExpr *>
1664 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
1665                         ArrayRef<StringRef> elementNames) {
1666   for (const ast::Expr *element : elements) {
1667     ast::Type eleTy = element->getType();
1668     if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
1669       return emitError(
1670           element->getLoc(),
1671           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
1672     }
1673   }
1674   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
1675 }
1676 
1677 //===----------------------------------------------------------------------===//
1678 // Stmts
1679 
1680 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
1681                                                     ast::Expr *rootOp) {
1682   // Check that root is an Operation.
1683   ast::Type rootType = rootOp->getType();
1684   if (!rootType.isa<ast::OperationType>())
1685     return emitError(rootOp->getLoc(), "expected `Op` expression");
1686 
1687   return ast::EraseStmt::create(ctx, loc, rootOp);
1688 }
1689 
1690 FailureOr<ast::ReplaceStmt *>
1691 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
1692                           MutableArrayRef<ast::Expr *> replValues) {
1693   // Check that root is an Operation.
1694   ast::Type rootType = rootOp->getType();
1695   if (!rootType.isa<ast::OperationType>()) {
1696     return emitError(
1697         rootOp->getLoc(),
1698         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
1699   }
1700 
1701   // If there are multiple replacement values, we implicitly convert any Op
1702   // expressions to the value form.
1703   bool shouldConvertOpToValues = replValues.size() > 1;
1704   for (ast::Expr *&replExpr : replValues) {
1705     ast::Type replType = replExpr->getType();
1706 
1707     // Check that replExpr is an Operation, Value, or ValueRange.
1708     if (replType.isa<ast::OperationType>()) {
1709       if (shouldConvertOpToValues)
1710         replExpr = convertOpToValue(replExpr);
1711       continue;
1712     }
1713 
1714     if (replType != valueTy && replType != valueRangeTy) {
1715       return emitError(replExpr->getLoc(),
1716                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
1717                                      "expression, but got `{0}`",
1718                                      replType));
1719     }
1720   }
1721 
1722   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
1723 }
1724 
1725 FailureOr<ast::RewriteStmt *>
1726 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
1727                           ast::CompoundStmt *rewriteBody) {
1728   // Check that root is an Operation.
1729   ast::Type rootType = rootOp->getType();
1730   if (!rootType.isa<ast::OperationType>()) {
1731     return emitError(
1732         rootOp->getLoc(),
1733         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
1734   }
1735 
1736   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // Parser
1741 //===----------------------------------------------------------------------===//
1742 
1743 FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
1744                                                  llvm::SourceMgr &sourceMgr) {
1745   Parser parser(ctx, sourceMgr);
1746   return parser.parseModule();
1747 }
1748