//===- Parser.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <string>

using namespace mlir;
using namespace mlir::pdll;

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

namespace {
/// Parser for PDLL source. It pulls tokens from a `Lexer` one at a time
/// (tracked in `curToken`), maintains a chain of declaration scopes, and
/// builds AST nodes inside the provided `ast::Context`.
class Parser {
public:
  /// Construct a parser over `sourceMgr`. The first token is lexed eagerly so
  /// that `curToken` is valid as soon as construction completes, and the
  /// commonly used AST types are looked up once and cached.
  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
      : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
        curToken(lexer.lexToken()), curDeclScope(nullptr),
        valueTy(ast::ValueType::get(ctx)),
        valueRangeTy(ast::ValueRangeType::get(ctx)),
        typeTy(ast::TypeType::get(ctx)),
        typeRangeTy(ast::TypeRangeType::get(ctx)) {}

  /// Try to parse a new module. Returns failure if the module could not be
  /// parsed.
  FailureOr<ast::Module *> parseModule();

private:
  /// The current context of the parser. It allows for the parser to know a bit
  /// about the construct it is nested within during parsing. This is used
  /// specifically to provide additional verification during parsing, e.g. to
  /// prevent using rewrites within a match context, matcher constraints within
  /// a rewrite section, etc.
  enum class ParserContext {
    /// The parser is in the global context.
    Global,
    /// The parser is currently within the matcher portion of a Pattern, which
    /// allows a terminal operation rewrite statement but no other rewrite
    /// transformations.
    PatternMatch,
    /// The parser is currently within a Rewrite, which disallows calls to
    /// constraints, requires operation expressions to have names, etc.
    Rewrite,
  };

  //===--------------------------------------------------------------------===//
  // Parsing
  //===--------------------------------------------------------------------===//

  /// Allocate a new decl scope whose parent is the current scope, make it the
  /// current scope of the parser, and return it.
  ast::DeclScope *pushDeclScope() {
    ast::DeclScope *newScope =
        new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
    return (curDeclScope = newScope);
  }
  /// Make the existing `scope` the current decl scope of the parser. The
  /// previous current scope is not recorded here; a later `popDeclScope`
  /// moves to the parent of `scope`.
  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }

  /// Pop the current decl scope, making its parent scope current again.
  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

  /// Parse the body of an AST module. Parsed top-level decls are appended to
  /// `decls`.
  LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);

  /// Try to convert the given expression to `type`. Returns failure and emits
  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
  /// invoked to attach notes to the emitted error diagnostic. On success,
  /// `expr` is updated to the expression used to convert to `type`.
  LogicalResult convertExpressionTo(
      ast::Expr *&expr, ast::Type type,
      function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});

  /// Given an operation expression, convert it to a Value or ValueRange
  /// typed expression.
  ast::Expr *convertOpToValue(const ast::Expr *opExpr);

  //===--------------------------------------------------------------------===//
  // Directives

  /// Parse the directive at the current token, dispatching on its spelling.
  /// Emits an error for directives that are not recognized.
  LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
  /// Parse an `#include` directive. Included `.pdll` files are parsed directly
  /// into `decls`.
  LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);

  //===--------------------------------------------------------------------===//
  // Decls

  /// This structure contains the set of pattern metadata that may be parsed.
  struct ParsedPatternMetadata {
    /// The benefit of the pattern, if one was specified.
    Optional<uint16_t> benefit;
    /// Set when the `recursion` metadata was specified on the pattern.
    bool hasBoundedRecursion = false;
  };

  /// Parse a top-level declaration and, if it is named, register it within the
  /// current decl scope.
  FailureOr<ast::Decl *> parseTopLevelDecl();
  /// Parse a named attribute of the form `name` or `name = expr`. When no
  /// value is provided, a `unit` attribute expression is used as the value.
  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
  /// Parse a `Pattern` declaration, including its optional name, optional
  /// metadata, and body.
  FailureOr<ast::Decl *> parsePatternDecl();
  /// Parse the comma separated metadata list of a pattern into `metadata`.
  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);

  /// Check to see if a decl has already been defined with the given name, if
  /// one has emit an error and return failure. Returns success otherwise.
  LogicalResult checkDefineNamedDecl(const ast::Name &name);

  /// Try to define a variable decl with the given components, returns the
  /// variable on success.
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
                     ast::Expr *initExpr,
                     ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::VariableDecl *>
  defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Parse the constraint reference list for a variable decl.
  LogicalResult parseVariableDeclConstraintList(
      SmallVectorImpl<ast::ConstraintRef> &constraints);

  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
  FailureOr<ast::Expr *> parseTypeConstraintExpr();

  /// Try to parse a single reference to a constraint. `typeConstraint` is the
  /// location of a previously parsed type constraint for the entity that will
  /// be constrained by the parsed constraint. `existingConstraints` are any
  /// existing constraints that have already been parsed for the same entity
  /// that will be constrained by this constraint.
  FailureOr<ast::ConstraintRef>
  parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                  ArrayRef<ast::ConstraintRef> existingConstraints);

  //===--------------------------------------------------------------------===//
  // Exprs

  /// Parse an expression, including any trailing member accesses.
  FailureOr<ast::Expr *> parseExpr();

  /// Identifier expressions.
  FailureOr<ast::Expr *> parseAttributeExpr();
  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
  FailureOr<ast::Expr *> parseIdentifierExpr();
  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
  FailureOr<ast::Expr *> parseOperationExpr();
  FailureOr<ast::Expr *> parseTupleExpr();
  FailureOr<ast::Expr *> parseTypeExpr();
  FailureOr<ast::Expr *> parseUnderscoreExpr();

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
  FailureOr<ast::EraseStmt *> parseEraseStmt();
  FailureOr<ast::LetStmt *> parseLetStmt();
  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
  FailureOr<ast::RewriteStmt *> parseRewriteStmt();

  //===--------------------------------------------------------------------===//
  // Creation+Analysis
  //===--------------------------------------------------------------------===//

  //===--------------------------------------------------------------------===//
  // Decls

  /// Try to create a pattern decl with the given components, returning the
  /// Pattern on success.
  FailureOr<ast::PatternDecl *>
  createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                    const ParsedPatternMetadata &metadata,
                    ast::CompoundStmt *body);

  /// Try to create a variable decl with the given components, returning the
  /// Variable on success.
  FailureOr<ast::VariableDecl *>
  createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer,
                     ArrayRef<ast::ConstraintRef> constraints);

  /// Validate the constraints used to constrain a variable decl.
  /// `inferredType` is the type of the variable inferred by the constraints
  /// within the list, and is updated to the most refined type as determined by
  /// the constraints. Returns success if the constraint list is valid, failure
  /// otherwise.
  LogicalResult
  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                              ast::Type &inferredType);
  /// Validate a single reference to a constraint. `inferredType` contains the
  /// currently inferred variable type and is refined within the type defined
  /// by the constraint. Returns success if the constraint is valid, failure
  /// otherwise.
  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
                                           ast::Type &inferredType);
  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);

  //===--------------------------------------------------------------------===//
  // Exprs

  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc,
                                                  ast::Decl *decl);
  FailureOr<ast::DeclRefExpr *>
  createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc,
                           ArrayRef<ast::ConstraintRef> constraints);
  FailureOr<ast::MemberAccessExpr *>
  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                         llvm::SMRange loc);

  /// Validate the member access `name` into the given parent expression. On
  /// success, this also returns the type of the member accessed.
  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                            StringRef name, llvm::SMRange loc);
  FailureOr<ast::OperationExpr *>
  createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
                      MutableArrayRef<ast::Expr *> operands,
                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
                      MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
                            MutableArrayRef<ast::Expr *> operands);
  LogicalResult validateOperationResults(llvm::SMRange loc,
                                         Optional<StringRef> name,
                                         MutableArrayRef<ast::Expr *> results);
  LogicalResult
  validateOperationOperandsOrResults(llvm::SMRange loc,
                                     Optional<StringRef> name,
                                     MutableArrayRef<ast::Expr *> values,
                                     ast::Type singleTy, ast::Type rangeTy);
  FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc,
                                              ArrayRef<ast::Expr *> elements,
                                              ArrayRef<StringRef> elementNames);

  //===--------------------------------------------------------------------===//
  // Stmts

  FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
                                              ast::Expr *rootOp);
  FailureOr<ast::ReplaceStmt *>
  createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
                    MutableArrayRef<ast::Expr *> replValues);
  FailureOr<ast::RewriteStmt *>
  createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
                    ast::CompoundStmt *rewriteBody);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Reset the lexer to the start of the given token location and re-lex the
  /// token found there into `curToken`.
  void resetToken(llvm::SMRange tokLoc) {
    lexer.resetPointer(tokLoc.Start.getPointer());
    curToken = lexer.lexToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return success();
  }
  /// Emit an error with message `msg` at `loc`. Always returns failure, so
  /// that callers can write `return emitError(...)`.
  LogicalResult emitError(llvm::SMRange loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return failure();
  }
  /// Emit an error with message `msg` at the location of the current token.
  /// Always returns failure.
  LogicalResult emitError(const Twine &msg) {
    return emitError(curToken.getLoc(), msg);
  }
  /// Emit an error with message `msg` at `loc`, together with the note `note`
  /// at `noteLoc`. Always returns failure.
  LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
                                 llvm::SMRange noteLoc, const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, noteLoc, note);
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  // NOTE: the constructor's initializer list follows the declaration order of
  // the fields below (`curToken` is initialized from `lexer`); keep the two in
  // sync when adding or reordering fields.

  /// The owning AST context.
  ast::Context &ctx;

  /// The lexer of this parser.
  Lexer lexer;

  /// The current token within the lexer.
  Token curToken;

  /// The most recently defined decl scope.
  ast::DeclScope *curDeclScope;
  /// The allocator providing the storage for scopes created by
  /// `pushDeclScope()`.
  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;

  /// The current context of the parser.
  ParserContext parserContext = ParserContext::Global;

  /// Cached types to simplify verification and expression creation.
  ast::Type valueTy, valueRangeTy;
  ast::Type typeTy, typeRangeTy;
};
} // namespace

/// Parse a complete module. A fresh decl scope is opened for the duration of
/// the parse and is popped again on both the success and the failure path.
/// Returns the created module, or failure if the module body failed to parse.
FailureOr<ast::Module *> Parser::parseModule() {
  // Record where the module starts before any of its tokens are consumed.
  llvm::SMLoc startLoc = curToken.getStartLoc();

  // Parse every top-level construct within a scope owned by this module.
  pushDeclScope();
  SmallVector<ast::Decl *> moduleDecls;
  LogicalResult bodyResult = parseModuleBody(moduleDecls);
  popDeclScope();

  if (failed(bodyResult))
    return failure();
  return ast::Module::create(ctx, startLoc, moduleDecls);
}

/// Parse directives and top-level declarations until the end of the current
/// input is reached. Each successfully parsed declaration is appended to
/// `decls`; the first failure aborts the parse.
LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
  while (!curToken.is(Token::eof)) {
    // Anything that isn't a directive must be a top-level declaration.
    if (!curToken.is(Token::directive)) {
      FailureOr<ast::Decl *> topLevelDecl = parseTopLevelDecl();
      if (failed(topLevelDecl))
        return failure();
      decls.push_back(*topLevelDecl);
      continue;
    }

    // Directives may add decls of their own (e.g. via includes).
    if (failed(parseDirective(decls)))
      return failure();
  }
  return success();
}

/// Wrap `opExpr` in an access of all of its results, typed as a ValueRange.
/// The new expression reuses the source location of `opExpr`.
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
  llvm::SMRange opLoc = opExpr->getLoc();
  return ast::AllResultsMemberAccessExpr::create(ctx, opLoc, opExpr,
                                                 valueRangeTy);
}

/// Convert `expr` so that it has the type `type`, rewriting `expr` in place
/// when a wrapper expression is required. If no conversion applies, an error
/// is emitted at the location of `expr` (with `noteAttachFn`, when provided,
/// given the chance to attach notes) and failure is returned.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  // An expression that already has the requested type needs no conversion.
  ast::Type srcType = expr->getType();
  if (srcType == type)
    return success();

  // Emit the common "unable to convert" error, letting the caller decorate
  // the diagnostic with additional notes.
  auto reportMismatch = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      srcType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  // Conversions starting from an operation.
  if (auto srcOpType = srcType.dyn_cast<ast::OperationType>()) {
    // Operation -> Operation: identical types were accepted above, so the
    // only remaining valid target is the more general, unnamed operation
    // type.
    if (auto dstOpType = type.dyn_cast<ast::OperationType>()) {
      if (!dstOpType.getName())
        return success();
      return reportMismatch();
    }

    // Operation -> ValueRange is always possible, and Operation -> Value is
    // allowed by constraining the result range. Both are expressed as an
    // access of all of the results, typed with the requested type.
    if (type == valueRangeTy || type == valueTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     type);
      return success();
    }
    return reportMismatch();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  auto isEither = [](ast::Type candidate, ast::Type first, ast::Type second) {
    return candidate == first || candidate == second;
  };
  if (isEither(srcType, valueTy, valueRangeTy) &&
      isEither(type, valueTy, valueRangeTy))
    return success();
  if (isEither(srcType, typeTy, typeRangeTy) &&
      isEither(type, typeTy, typeRangeTy))
    return success();

  // The only remaining convertible kind is a tuple, which must convert to a
  // tuple with the same number of elements.
  auto srcTupleType = srcType.dyn_cast<ast::TupleType>();
  if (!srcTupleType)
    return reportMismatch();
  auto dstTupleType = type.dyn_cast<ast::TupleType>();
  if (!dstTupleType || dstTupleType.size() != srcTupleType.size())
    return reportMismatch();

  // Convert element by element: access each element of the source tuple,
  // convert it to the matching element type, and collect the results.
  SmallVector<ast::Expr *> elementExprs;
  for (unsigned idx = 0, numElements = srcTupleType.size(); idx != numElements;
       ++idx) {
    ast::Expr *elementExpr = ast::MemberAccessExpr::create(
        ctx, expr->getLoc(), expr, llvm::to_string(idx),
        srcTupleType.getElementTypes()[idx]);

    // Errors on an element mention which element was being converted before
    // deferring to the notes of the caller.
    auto attachElementNote = [&](ast::Diagnostic &diag) {
      diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                    idx, srcTupleType));
      if (noteAttachFn)
        noteAttachFn(diag);
    };
    if (failed(convertExpressionTo(elementExpr,
                                   dstTupleType.getElementTypes()[idx],
                                   attachElementNote)))
      return failure();
    elementExprs.push_back(elementExpr);
  }

  // Rebuild the tuple from the converted elements, using the element names of
  // the destination type.
  expr = ast::TupleExpr::create(ctx, expr->getLoc(), elementExprs,
                                dstTupleType.getElementNames());
  return success();
}

//===----------------------------------------------------------------------===//
// Directives

/// Dispatch on the spelling of the directive at the current token. `#include`
/// is the only directive handled; anything else is reported as an error at
/// the current token.
LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
  StringRef spelling = curToken.getSpelling();
  if (spelling != "#include")
    return emitError("unknown directive `" + spelling + "`");
  return parseInclude(decls);
}

/// Parse an `#include "<file>"` directive. Only files ending in `.pdll` are
/// accepted; such a file is pushed onto the lexer and its contents are parsed
/// directly into `decls`.
LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
  llvm::SMRange directiveLoc = curToken.getLoc();
  consumeToken(Token::directive);

  // The directive must be followed by the name of the file, as a string.
  if (!curToken.isString())
    return emitError(directiveLoc,
                     "expected string file name after `include` directive");
  llvm::SMRange fileLoc = curToken.getLoc();
  std::string filenameStorage = curToken.getStringValue();
  StringRef filename = filenameStorage;
  consumeToken();

  // The extension determines the type of include. Only `.pdll` files, i.e.
  // other pdl files to be parsed along with the current module, are handled.
  if (!filename.endswith(".pdll"))
    return emitError(fileLoc, "expected include filename to end with `.pdll`");

  if (failed(lexer.pushInclude(filename)))
    return emitError(fileLoc,
                     "unable to open include file `" + filename + "`");

  // The include was added successfully: parse it into the current module.
  // The token following the directive was already lexed, so stash it while
  // the nested file is parsed and put it back afterwards, regardless of
  // whether the nested parse succeeded.
  Token resumeToken = curToken;
  curToken = lexer.lexToken();
  LogicalResult nestedResult = parseModuleBody(decls);
  curToken = resumeToken;
  return nestedResult;
}

//===----------------------------------------------------------------------===//
// Decls

/// Parse a single top-level declaration. `Pattern` is the only kind of
/// declaration accepted here. A declaration that has a name is checked for
/// redefinition and then added to the current decl scope.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  if (curToken.isNot(Token::kw_Pattern))
    return emitError("expected top-level declaration, such as a `Pattern`");

  FailureOr<ast::Decl *> decl = parsePatternDecl();
  if (failed(decl))
    return failure();

  // Unnamed declarations can't be referenced, so they never enter the scope.
  const ast::Name *declName = (*decl)->getName();
  if (!declName)
    return decl;

  if (failed(checkDefineNamedDecl(*declName)))
    return failure();
  curDeclScope->add(*decl);
  return decl;
}

/// Parse a named attribute: a name (string literal, identifier, or keyword)
/// optionally followed by `= <expr>`. When the value is omitted, the
/// attribute is given a `unit` attribute expression as its value.
FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
  // Strings contribute their value as the name; identifiers and keywords
  // contribute their spelling.
  bool isStringName = curToken.isString();
  if (!isStringName && curToken.isNot(Token::identifier) &&
      !curToken.isKeyword())
    return emitError("expected identifier or string attribute name");

  std::string nameStorage =
      isStringName ? curToken.getStringValue() : curToken.getSpelling().str();
  const ast::Name &attrName =
      ast::Name::create(ctx, nameStorage, curToken.getLoc());
  consumeToken();

  ast::Expr *value = nullptr;
  if (!consumeIf(Token::equal)) {
    // No concrete value was provided: create an expression representing a
    // UnitAttr, located at the name.
    value = ast::AttributeExpr::create(ctx, attrName.getLoc(), "unit");
  } else {
    FailureOr<ast::Expr *> valueExpr = parseExpr();
    if (failed(valueExpr))
      return failure();
    value = *valueExpr;
  }

  return ast::NamedAttributeDecl::create(ctx, attrName, value);
}

/// Parse a pattern declaration:
///   `Pattern` <optional-name> (`with` <metadata>)? `{` <body> `}`
/// The body must contain exactly one operation rewrite statement, and it must
/// be the final statement of the body.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  llvm::SMRange patternLoc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);

  // Everything nested within the pattern is parsed in the match context; the
  // previous context is restored when this function returns.
  llvm::SaveAndRestore<ParserContext> contextGuard(parserContext,
                                                   ParserContext::PatternMatch);

  // The name of the pattern is optional.
  const ast::Name *patternName = nullptr;
  if (curToken.is(Token::identifier)) {
    patternName =
        &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // Metadata is optional, and is introduced by `with`.
  ParsedPatternMetadata metadata;
  if (consumeIf(Token::kw_with)) {
    if (failed(parsePatternDeclMetadata(metadata)))
      return failure();
  }

  // Parse the braced body of the pattern.
  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start pattern body");
  FailureOr<ast::CompoundStmt *> parsedBody = parseCompoundStmt();
  if (failed(parsedBody))
    return failure();
  ast::CompoundStmt *body = *parsedBody;

  // Locate the first operation rewrite statement within the body.
  auto rewriteIt = body->begin(), bodyEnd = body->end();
  while (rewriteIt != bodyEnd && !isa<ast::OpRewriteStmt>(*rewriteIt))
    ++rewriteIt;

  // A body without a rewrite statement is invalid...
  if (rewriteIt == bodyEnd) {
    return emitError(patternLoc,
                     "expected Pattern body to terminate with an operation "
                     "rewrite statement, such as `erase`");
  }
  // ...as is one with anything following the rewrite statement.
  auto trailingIt = std::next(rewriteIt);
  if (trailingIt != bodyEnd) {
    return emitError((*trailingIt)->getLoc(),
                     "Pattern body was terminated by an operation "
                     "rewrite statement, but found trailing statements");
  }

  return createPatternDecl(patternLoc, patternName, metadata, body);
}

/// Parse a comma separated list of pattern metadata into `metadata`. The
/// supported entries are `benefit(<integer>)` and `recursion`, and each may
/// be specified at most once.
LogicalResult
Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
  // Locations of the entries seen so far, used to diagnose duplicates.
  Optional<llvm::SMRange> seenBenefitLoc;
  Optional<llvm::SMRange> seenRecursionLoc;

  do {
    // Each entry starts with an identifier naming the metadata.
    if (curToken.isNot(Token::identifier))
      return emitError("expected pattern metadata identifier");
    StringRef entryName = curToken.getSpelling();
    llvm::SMRange entryLoc = curToken.getLoc();
    consumeToken(Token::identifier);

    if (entryName == "benefit") {
      // Parse the benefit metadata: benefit(<integer-value>)
      if (seenBenefitLoc) {
        return emitErrorAndNote(entryLoc,
                                "pattern benefit has already been specified",
                                *seenBenefitLoc,
                                "see previous definition here");
      }
      if (failed(parseToken(Token::l_paren,
                            "expected `(` before pattern benefit")))
        return failure();

      // The benefit is a base 10 integer that has to fit in 16 bits.
      if (curToken.isNot(Token::integer))
        return emitError("expected integral pattern benefit");
      uint16_t parsedBenefit = 0;
      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, parsedBenefit))
        return emitError(
            "expected pattern benefit to fit within a 16-bit integer");
      consumeToken(Token::integer);

      metadata.benefit = parsedBenefit;
      seenBenefitLoc = entryLoc;

      if (failed(
              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
        return failure();
    } else if (entryName == "recursion") {
      // Parse the bounded recursion metadata: recursion
      if (seenRecursionLoc) {
        return emitErrorAndNote(
            entryLoc, "pattern recursion metadata has already been specified",
            *seenRecursionLoc, "see previous definition here");
      }
      metadata.hasBoundedRecursion = true;
      seenRecursionLoc = entryLoc;
    } else {
      return emitError(entryLoc, "unknown pattern metadata");
    }
  } while (consumeIf(Token::comma));

  return success();
}

/// Parse the `<` expr `>` portion of a type constraint. The current token is
/// expected to be the opening `<`; the parsed inner expression is returned.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> innerExpr = parseExpr();
  if (failed(innerExpr))
    return failure();

  if (failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return innerExpr;
}

/// Check that `name` can be defined, i.e. that a lookup of it from the
/// current scope finds nothing. If a decl is found, an error pointing at
/// `name` is emitted together with a note at the name of the existing decl.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");

  ast::Decl *previousDecl = curDeclScope->lookup(name.getName());
  if (!previousDecl)
    return success();

  return emitErrorAndNote(
      name.getLoc(), "`" + name.getName() + "` has already been defined",
      previousDecl->getName()->getLoc(), "see previous definition here");
}

/// Create a variable decl from the given components and, unless its name is
/// empty or `_`, register it within the current decl scope. Registering a
/// name that is already defined is an error.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type, ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &varName = ast::Name::create(ctx, name, nameLoc);

  // Variables with a special name are local to their definition point: they
  // are neither checked against, nor added to, the scope.
  bool isSpecialName = name.empty() || name == "_";
  if (!isSpecialName && failed(checkDefineNamedDecl(varName)))
    return failure();

  ast::VariableDecl *newVar =
      ast::VariableDecl::create(ctx, varName, type, initExpr, constraints);
  if (!isSpecialName)
    curDeclScope->add(newVar);
  return newVar;
}

/// Convenience overload that defines a variable without an initializer
/// expression.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Expr *noInitExpr = nullptr;
  return defineVariableDecl(name, nameLoc, type, noInitExpr, constraints);
}

/// Parse the constraints of a variable decl, which are either a single
/// constraint or a `[`...`]` wrapped, comma separated list of constraints.
/// Every parsed constraint is appended to `constraints`.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  // Location of the type constraint seen so far, if any; shared between all
  // of the constraints in the list.
  Optional<llvm::SMRange> typeConstraintLoc;

  // Parses one constraint and records it on success.
  auto parseAndRecordConstraint = [&]() -> LogicalResult {
    FailureOr<ast::ConstraintRef> parsedRef =
        parseConstraint(typeConstraintLoc, constraints);
    if (failed(parsedRef))
      return failure();
    constraints.push_back(*parsedRef);
    return success();
  };

  // Without an opening `[`, there is exactly one constraint.
  bool isBracketedList = consumeIf(Token::l_square);
  if (!isBracketedList)
    return parseAndRecordConstraint();

  // Otherwise, parse constraints for as long as they are comma separated.
  for (;;) {
    if (failed(parseAndRecordConstraint()))
      return failure();
    if (!consumeIf(Token::comma))
      break;
  }
  return parseToken(Token::r_square, "expected `]` after constraint list");
}

/// Parse a single constraint reference: one of the builtin constraints
/// (`Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`) or an identifier
/// naming a previously defined constraint. `typeConstraint` holds the location
/// of the type constraint already applied to the constrained entity, if any,
/// and is updated when this constraint provides one.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints) {
  // Parses the optional `<type-expr>` that may follow a builtin constraint.
  // `typeExpr` is left untouched when no `<` is present. Only one type
  // constraint may be applied per entity.
  auto parseOptionalTypeConstraint =
      [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (curToken.isNot(Token::less))
      return success();
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");

    FailureOr<ast::Expr *> parsedTypeExpr = parseTypeConstraintExpr();
    if (failed(parsedTypeExpr))
      return failure();
    typeExpr = *parsedTypeExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  llvm::SMRange cstLoc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);

    // Attr may carry a type constraint.
    ast::Expr *attrTypeExpr = nullptr;
    if (failed(parseOptionalTypeConstraint(attrTypeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, cstLoc, attrTypeExpr), cstLoc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);

    // The operation name is optional. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opNameDecl =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(opNameDecl))
      return failure();

    return ast::ConstraintRef(
        ast::OpConstraintDecl::create(ctx, cstLoc, *opNameDecl), cstLoc);
  }
  case Token::kw_Type: {
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, cstLoc),
                              cstLoc);
  }
  case Token::kw_TypeRange: {
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(
        ast::TypeRangeConstraintDecl::create(ctx, cstLoc), cstLoc);
  }
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);

    // Value may carry a type constraint.
    ast::Expr *valueTypeExpr = nullptr;
    if (failed(parseOptionalTypeConstraint(valueTypeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, cstLoc, valueTypeExpr), cstLoc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);

    // ValueRange may carry a type constraint.
    ast::Expr *rangeTypeExpr = nullptr;
    if (failed(parseOptionalTypeConstraint(rangeTypeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, cstLoc, rangeTypeExpr),
        cstLoc);
  }
  case Token::identifier: {
    StringRef refName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // The identifier has to resolve to a decl...
    ast::Decl *referencedDecl = curDeclScope->lookup<ast::Decl>(refName);
    if (!referencedDecl) {
      return emitError(cstLoc, "unknown reference to constraint `" + refName +
                                   "`");
    }

    // ...and that decl has to be a proper constraint.
    auto *constraintDecl = dyn_cast<ast::ConstraintDecl>(referencedDecl);
    if (!constraintDecl) {
      return emitErrorAndNote(
          cstLoc, "invalid reference to non-constraint",
          referencedDecl->getLoc(),
          "see the definition of `" + refName + "` here");
    }
    return ast::ConstraintRef(constraintDecl, cstLoc);
  }
  default:
    break;
  }
  return emitError(cstLoc, "expected identifier constraint");
}

//===----------------------------------------------------------------------===//
// Exprs

/// Parse an expression. The leading token selects which primary expression is
/// parsed; any number of `.member` accesses may then follow it. A leading `_`
/// is handed to parseUnderscoreExpr and its result is returned directly.
FailureOr<ast::Expr *> Parser::parseExpr() {
  // Select the primary expression based on the current token.
  FailureOr<ast::Expr *> expr;
  switch (curToken.getKind()) {
  case Token::underscore:
    return parseUnderscoreExpr();
  case Token::kw_attr:
    expr = parseAttributeExpr();
    break;
  case Token::identifier:
    expr = parseIdentifierExpr();
    break;
  case Token::kw_op:
    expr = parseOperationExpr();
    break;
  case Token::kw_type:
    expr = parseTypeExpr();
    break;
  case Token::l_paren:
    expr = parseTupleExpr();
    break;
  default:
    return emitError("expected expression");
  }

  // Keep wrapping the expression in member accesses for as long as a `.`
  // follows, stopping as soon as anything fails.
  while (!failed(expr) && curToken.is(Token::dot))
    expr = parseMemberAccessExpr(*expr);

  if (failed(expr))
    return failure();
  return expr;
}

/// Parse an attribute literal of the form `attr<"...">`. When `attr` is not
/// followed by `<`, the keyword is instead parsed as an identifier expression.
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
  llvm::SMRange keywordLoc = curToken.getLoc();
  consumeToken(Token::kw_attr);

  const bool hasLiteralBody = consumeIf(Token::less);
  if (!hasLiteralBody) {
    // Not a literal: reset to the keyword and treat it as a normal identifier.
    resetToken(keywordLoc);
    return parseIdentifierExpr();
  }

  // The body of the literal must be a string token.
  if (!curToken.isString())
    return emitError("expected string literal containing MLIR attribute");
  // Grab the string value before moving past the token.
  std::string attrText = curToken.getStringValue();
  consumeToken();

  if (failed(
          parseToken(Token::greater, "expected `>` after attribute literal")))
    return failure();
  return ast::AttributeExpr::create(ctx, keywordLoc, attrText);
}

/// Resolve `name` in the current declaration scope and build a reference to
/// the decl that was found. Emits an error at `loc` if the lookup finds
/// nothing.
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
                                                llvm::SMRange loc) {
  if (ast::Decl *referencedDecl = curDeclScope->lookup(name))
    return createDeclRefExpr(loc, referencedDecl);

  return emitError(loc, "undefined reference to `" + name + "`");
}

/// Parse an expression that starts with an identifier-like token. This is
/// either a reference to an existing decl (`name`), or an inline variable
/// definition (`name: <constraints>`).
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
  // Record the spelling and location before moving past the token.
  StringRef identifier = curToken.getSpelling();
  llvm::SMRange identifierLoc = curToken.getLoc();
  consumeToken();

  // Without a trailing `:`, this is a plain reference to a decl.
  if (!consumeIf(Token::colon))
    return parseDeclRefExpr(identifier, identifierLoc);

  // Otherwise this defines a variable inline: parse the constraint list and
  // infer the variable type from it.
  SmallVector<ast::ConstraintRef> constraints;
  ast::Type inferredType;
  if (failed(parseVariableDeclConstraintList(constraints)) ||
      failed(validateVariableConstraints(constraints, inferredType)))
    return failure();
  return createInlineVariableExpr(inferredType, identifier, identifierLoc,
                                  constraints);
}

/// Parse a `.member` access applied to `parentExpr`. The member name may be
/// an identifier, an integer, or a keyword token.
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
  // The location of the `.` is used for the resulting expression and for the
  // diagnostic below.
  llvm::SMRange dotLoc = curToken.getLoc();
  consumeToken(Token::dot);

  // Take a copy of the token that should name the member.
  Token nameTok = curToken;
  if (nameTok.isNot(Token::identifier, Token::integer) &&
      !nameTok.isKeyword()) {
    return emitError(dotLoc, "expected identifier or numeric member name");
  }
  consumeToken();

  return createMemberAccessExpr(parentExpr, nameTok.getSpelling(), dotLoc);
}

/// Parse an operation name of the form `<dialect-namespace>.<op-name>`, where
/// the op-name portion may itself contain further `.`-separated components.
///
/// If the current token cannot start a name, an unnamed OpNameDecl is
/// returned when `allowEmptyName` is true, and an error is emitted otherwise.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  llvm::SMRange loc = curToken.getLoc();

  // Handle the case of no operation name.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, llvm::SMRange());
    return emitError("expected dialect namespace");
  }
  // The first component is the dialect namespace.
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();

  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Grow `name` in place over the source text: first by one character to
  // cover the `.` consumed above, then by the spelling length of each
  // following identifier, keyword, or `.` token. `loc` is extended to the end
  // of the last token consumed.
  // NOTE(review): this presumably relies on the tokens being directly
  // adjacent in the underlying buffer (no whitespace or comments between
  // them) -- confirm against the lexer.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
}

/// Parse an optional operation name wrapped in angle brackets, i.e.
/// `<dialect.op>`. If no `<` is present, an unnamed OpNameDecl is returned.
/// `allowEmptyName` is forwarded to parseOperationName for the bracketed
/// case.
FailureOr<ast::OpNameDecl *>
Parser::parseWrappedOperationName(bool allowEmptyName) {
  // No opening bracket means there is no name to parse.
  const bool hasBrackets = consumeIf(Token::less);
  if (!hasBrackets)
    return ast::OpNameDecl::create(ctx, llvm::SMRange());

  // Parse the name itself followed by the closing bracket.
  FailureOr<ast::OpNameDecl *> wrappedName = parseOperationName(allowEmptyName);
  if (failed(wrappedName) ||
      failed(parseToken(Token::greater, "expected `>` after operation name")))
    return failure();
  return wrappedName;
}

/// Parse an operation expression of the form:
///   `op<name>` [`(` operands `)`] [`{` attributes `}`] [`->` `(` types `)`]
/// If `op` is not followed by `<`, the keyword is instead parsed as an
/// identifier expression.
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
  llvm::SMRange opLoc = curToken.getLoc();
  consumeToken(Token::kw_op);

  // Without a `<`, the `op` keyword is handled as a normal identifier.
  if (curToken.isNot(Token::less)) {
    resetToken(opLoc);
    return parseIdentifierExpr();
  }

  // The name may be elided (meaning "any" operation) everywhere except within
  // a rewrite context, where operations must be named.
  const bool nameMayBeEmpty = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opName =
      parseWrappedOperationName(nameMayBeEmpty);
  if (failed(opName))
    return failure();

  // Parses a non-empty, comma separated list of expressions into `exprs`.
  auto parseExprList = [&](SmallVector<ast::Expr *> &exprs) -> LogicalResult {
    do {
      FailureOr<ast::Expr *> expr = parseExpr();
      if (failed(expr))
        return failure();
      exprs.push_back(*expr);
    } while (consumeIf(Token::comma));
    return success();
  };

  // Optional operand list.
  SmallVector<ast::Expr *> operands;
  if (consumeIf(Token::l_paren)) {
    if (failed(parseExprList(operands)) ||
        failed(parseToken(Token::r_paren,
                          "expected `)` after operation operand list")))
      return failure();
  }

  // Optional attribute list.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> attr = parseNamedAttributeDecl();
      if (failed(attr))
        return failure();
      attributes.push_back(*attr);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_brace,
                          "expected `}` after operation attribute list")))
      return failure();
  }

  // Optional result type list, introduced by `->`.
  SmallVector<ast::Expr *> resultTypes;
  if (consumeIf(Token::arrow)) {
    if (failed(parseToken(Token::l_paren,
                          "expected `(` before operation result type list")) ||
        failed(parseExprList(resultTypes)) ||
        failed(parseToken(Token::r_paren,
                          "expected `)` after operation result type list")))
      return failure();
  }

  return createOperationExpr(opLoc, *opName, operands, attributes,
                             resultTypes);
}

/// Parse a tuple expression: a parenthesized, possibly empty, comma separated
/// list of element expressions. Each element may optionally be labeled using
/// `name = <expr>`. Reusing a label within the same tuple is an error.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  llvm::SMRange loc = curToken.getLoc();
  consumeToken(Token::l_paren);

  // `usedNames` maps each label seen so far to the location of its first use,
  // which is reported in the note of the duplicate-label diagnostic.
  DenseMap<StringRef, llvm::SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
        // Tentatively consume the token; whether it is a label is only known
        // once the following token has been inspected.
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(elementName, elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                elementNameTok.getLoc(),
                llvm::formatv("duplicate tuple element label `{0}`",
                              elementName),
                elementNameIt.first->getSecond(),
                "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer.
          resetToken(elementNameTok.getLoc());
        }
      }
      // Unlabeled elements are recorded with an empty name so that
      // `elementNames` and `elements` receive one entry per element.
      elementNames.push_back(elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(element))
        return failure();
      elements.push_back(*element);
    } while (consumeIf(Token::comma));
  }
  // Extend the tuple location to the end of the current token, which is
  // expected to be the closing `)` checked just below.
  loc.End = curToken.getEndLoc();
  if (failed(
          parseToken(Token::r_paren, "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}

/// Parse a type literal of the form `type<"...">`. When `type` is not
/// followed by `<`, the keyword is instead parsed as an identifier expression.
FailureOr<ast::Expr *> Parser::parseTypeExpr() {
  llvm::SMRange keywordLoc = curToken.getLoc();
  consumeToken(Token::kw_type);

  const bool hasLiteralBody = consumeIf(Token::less);
  if (!hasLiteralBody) {
    // Not a literal: reset to the keyword and treat it as a normal identifier.
    resetToken(keywordLoc);
    return parseIdentifierExpr();
  }

  // The body of the literal must be a string token.
  if (!curToken.isString())
    return emitError("expected string literal containing MLIR type");
  // Grab the string value before moving past the token.
  std::string typeText = curToken.getStringValue();
  consumeToken();

  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
    return failure();
  return ast::TypeExpr::create(ctx, keywordLoc, typeText);
}

/// Parse an inline variable definition named `_`, i.e. `_: <constraints>`.
/// Unlike a named identifier, the constraint list is mandatory here.
FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
  // Record the spelling and location before moving past the token.
  StringRef underscoreName = curToken.getSpelling();
  llvm::SMRange underscoreLoc = curToken.getLoc();
  consumeToken(Token::underscore);

  // A `:` followed by the constraint list is required.
  SmallVector<ast::ConstraintRef> constraints;
  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")) ||
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  // Infer the variable type from the constraints and define the variable.
  ast::Type inferredType;
  if (failed(validateVariableConstraints(constraints, inferredType)))
    return failure();
  return createInlineVariableExpr(inferredType, underscoreName, underscoreLoc,
                                  constraints);
}

//===----------------------------------------------------------------------===//
// Stmts

/// Parse a single statement. Statements introduced by `erase`, `let`,
/// `replace`, or `rewrite` are dispatched to their dedicated parsers; anything
/// else is parsed as an expression. When `expectTerminalSemicolon` is set, a
/// trailing `;` is required after the statement.
FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
  FailureOr<ast::Stmt *> parsedStmt;
  switch (curToken.getKind()) {
  case Token::kw_erase:
    parsedStmt = parseEraseStmt();
    break;
  case Token::kw_let:
    parsedStmt = parseLetStmt();
    break;
  case Token::kw_replace:
    parsedStmt = parseReplaceStmt();
    break;
  case Token::kw_rewrite:
    parsedStmt = parseRewriteStmt();
    break;
  default:
    parsedStmt = parseExpr();
    break;
  }
  if (failed(parsedStmt))
    return failure();

  // Nothing more to check if the caller handles the terminator itself.
  if (!expectTerminalSemicolon)
    return parsedStmt;

  if (failed(parseToken(Token::semicolon, "expected `;` after statement")))
    return failure();
  return parsedStmt;
}

/// Parse a `{ ... }` block of statements. The statements are parsed inside a
/// fresh declaration scope, which is popped again on both the success and the
/// failure path.
FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
  llvm::SMLoc blockStart = curToken.getStartLoc();
  consumeToken(Token::l_brace);

  // Parse the nested statements within their own scope.
  pushDeclScope();
  SmallVector<ast::Stmt *> body;
  while (!curToken.is(Token::r_brace)) {
    FailureOr<ast::Stmt *> nested = parseStmt();
    if (failed(nested)) {
      popDeclScope();
      return failure();
    }
    body.push_back(*nested);
  }
  popDeclScope();

  // The block location runs from the opening brace through the closing one.
  llvm::SMRange blockLoc(blockStart, curToken.getEndLoc());
  consumeToken(Token::r_brace);

  return ast::CompoundStmt::create(ctx, blockLoc, body);
}

/// Parse an erase statement of the form `erase <expr>`.
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
  llvm::SMRange eraseLoc = curToken.getLoc();
  consumeToken(Token::kw_erase);

  // The expression following the keyword names the operation to erase.
  FailureOr<ast::Expr *> target = parseExpr();
  if (failed(target))
    return failure();
  return createEraseStmt(eraseLoc, *target);
}

/// Parse a let statement of the form:
///   `let <name> [: <constraints>] [= <initializer>]`
/// The variable name must be an identifier or a dependent keyword; `_` is
/// rejected. Constraints carrying a type expression may not be combined with
/// an initializer.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  llvm::SMRange letLoc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // Parse the name of the new variable.
  llvm::SMRange nameLoc = curToken.getLoc();
  const bool hasValidName =
      curToken.is(Token::identifier) || curToken.isDependentKeyword();
  if (!hasValidName) {
    // `_` is a reserved variable name.
    if (curToken.is(Token::underscore)) {
      return emitError(nameLoc,
                       "`_` may only be used to define \"inline\" variables");
    }
    return emitError(nameLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Optional constraint list, introduced by `:`.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon)) {
    if (failed(parseVariableDeclConstraintList(constraints)))
      return failure();
  }

  // Optional initializer, introduced by `=`.
  ast::Expr *initExpr = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> parsedInit = parseExpr();
    if (failed(parsedInit))
      return failure();
    initExpr = *parsedInit;

    // Type constraints cannot be used together with an initializer, so reject
    // any Attr/Value/ValueRange constraint that carries a type expression.
    for (ast::ConstraintRef constraint : constraints) {
      const bool hasTypeExpr =
          TypeSwitch<const ast::Node *, bool>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([](const auto *cst) {
                return cst->getTypeExpr() != nullptr;
              })
              .Default(false);
      if (hasTypeExpr) {
        return emitError(constraint.referenceLoc,
                         "type constraints are not permitted on variables with "
                         "initializers");
      }
    }
  }

  FailureOr<ast::VariableDecl *> variable =
      createVariableDecl(name, nameLoc, initExpr, constraints);
  if (failed(variable))
    return failure();
  return ast::LetStmt::create(ctx, letLoc, *variable);
}

/// Parse a replace statement of the form:
///   `replace <root> with <expr>`  or  `replace <root> with (<expr>, ...)`
/// The replacement values are parsed within a rewrite context. An empty
/// parenthesized list is rejected.
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
  llvm::SMRange replaceLoc = curToken.getLoc();
  consumeToken(Token::kw_replace);

  // Parse the root operation expression followed by `with`.
  FailureOr<ast::Expr *> root = parseExpr();
  if (failed(root))
    return failure();
  if (failed(
          parseToken(Token::kw_with, "expected `with` after root operation")))
    return failure();

  // Everything after `with` is within a rewrite context; the previous context
  // is restored when this function returns.
  llvm::SaveAndRestore<ParserContext> rewriteCtxGuard(parserContext,
                                                      ParserContext::Rewrite);

  SmallVector<ast::Expr *> replacements;
  if (!consumeIf(Token::l_paren)) {
    // A single, unparenthesized replacement value.
    FailureOr<ast::Expr *> value = parseExpr();
    if (failed(value))
      return failure();
    replacements.push_back(*value);
  } else {
    // A parenthesized list, which must contain at least one value.
    if (consumeIf(Token::r_paren)) {
      return emitError(
          replaceLoc,
          "expected at least one replacement value, consider using "
          "`erase` if no replacement values are desired");
    }

    do {
      FailureOr<ast::Expr *> value = parseExpr();
      if (failed(value))
        return failure();
      replacements.push_back(*value);
    } while (consumeIf(Token::comma));

    if (failed(parseToken(Token::r_paren,
                          "expected `)` after replacement values")))
      return failure();
  }

  return createReplaceStmt(replaceLoc, *root, replacements);
}

/// Parse a rewrite statement of the form `rewrite <root> with { ... }`. The
/// body is parsed within a rewrite context.
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
  llvm::SMRange rewriteLoc = curToken.getLoc();
  consumeToken(Token::kw_rewrite);

  // Parse the root operation followed by `with`.
  FailureOr<ast::Expr *> root = parseExpr();
  if (failed(root))
    return failure();
  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
    return failure();

  // The body must be a braced block; the brace itself is consumed by
  // parseCompoundStmt.
  if (!curToken.is(Token::l_brace))
    return emitError("expected `{` to start rewrite body");

  // Switch to a rewrite context for the body; the previous context is
  // restored when this function returns.
  llvm::SaveAndRestore<ParserContext> rewriteCtxGuard(parserContext,
                                                      ParserContext::Rewrite);

  FailureOr<ast::CompoundStmt *> body = parseCompoundStmt();
  if (failed(body))
    return failure();
  return createRewriteStmt(rewriteLoc, *root, *body);
}

//===----------------------------------------------------------------------===//
// Creation+Analysis
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Decls

/// Build a PatternDecl from an already parsed name, metadata, and body. The
/// benefit and bounded-recursion flag are taken from `metadata`.
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                          const ParsedPatternMetadata &metadata,
                          ast::CompoundStmt *body) {
  const auto &patternBenefit = metadata.benefit;
  const auto &boundedRecursion = metadata.hasBoundedRecursion;
  return ast::PatternDecl::create(ctx, loc, name, patternBenefit,
                                  boundedRecursion, body);
}

/// Create and define a new variable. The variable type is inferred from the
/// constraint list and/or the initializer; it is an error if neither of them
/// provides a type.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, llvm::SMRange loc,
                           ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  // Start with whatever type the constraints imply (possibly none).
  ast::Type varType;
  if (failed(validateVariableConstraints(constraints, varType)))
    return failure();

  if (!initializer) {
    // With no initializer, the constraints must have resolved the type.
    if (!varType) {
      return emitErrorAndNote(
          loc, "unable to infer type for variable `" + name + "`", loc,
          "the type of a variable must be inferable from the constraint "
          "list or the initializer");
    }
  } else if (!varType) {
    // No constraint type: adopt the type of the initializer.
    varType = initializer->getType();
  } else if (ast::Type refined = varType.refineWith(initializer->getType())) {
    // The two types can be merged into a refined type.
    varType = refined;
  } else if (failed(convertExpressionTo(initializer, varType))) {
    // As a last resort the initializer is converted to the constraint type.
    return failure();
  }

  // Try to define a variable with the given name.
  FailureOr<ast::VariableDecl *> defined =
      defineVariableDecl(name, loc, varType, initializer, constraints);
  if (failed(defined))
    return failure();
  return *defined;
}

/// Validate every constraint in `constraints` in order, accumulating the
/// inferred variable type in `inferredType`. Stops at the first constraint
/// that fails validation.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType) {
  for (const ast::ConstraintRef &constraint : constraints) {
    if (failed(validateVariableConstraint(constraint, inferredType)))
      return failure();
  }
  return success();
}

/// Validate a single variable constraint and fold the type it implies into
/// `inferredType`. Any type expression attached to an Attr, Value, or
/// ValueRange constraint is validated first. An error is emitted if the
/// constraint type cannot be merged with the type inferred so far.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  // Determine the type implied by the kind of constraint.
  ast::Type impliedType;
  if (const auto *attrCst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = attrCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = ast::AttributeType::get(ctx);
  } else if (const auto *opCst =
                 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
    impliedType = ast::OperationType::get(ctx, opCst->getName());
  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
    impliedType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
    impliedType = typeRangeTy;
  } else if (const auto *valueCst =
                 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = valueCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueTy;
  } else if (const auto *rangeCst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
    const ast::Expr *typeExpr = rangeCst->getTypeExpr();
    if (typeExpr && failed(validateTypeRangeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueRangeTy;
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // The first constraint simply establishes the inferred type.
  if (!inferredType) {
    inferredType = impliedType;
    return success();
  }

  // Later constraints must be mergeable with what has been inferred so far.
  if (ast::Type refined = inferredType.refineWith(impliedType)) {
    inferredType = refined;
    return success();
  }
  return emitError(ref.referenceLoc,
                   llvm::formatv("constraint type `{0}` is incompatible "
                                 "with the previously inferred type `{1}`",
                                 impliedType, inferredType));
}

/// Check that the expression used as a type constraint has type `Type`,
/// emitting an error at the expression otherwise.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  ast::Type exprType = typeExpr->getType();
  const bool isMismatch = exprType != typeTy;
  if (isMismatch)
    return emitError(typeExpr->getLoc(),
                     "expected expression of `Type` in type constraint");
  return success();
}

/// Check that the expression used as a type constraint has type `TypeRange`,
/// emitting an error at the expression otherwise.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  ast::Type exprType = typeExpr->getType();
  const bool isMismatch = exprType != typeRangeTy;
  if (isMismatch)
    return emitError(typeExpr->getLoc(),
                     "expected expression of `TypeRange` in type constraint");
  return success();
}

//===----------------------------------------------------------------------===//
// Exprs

/// Create a reference to `decl`. Only variable decls may be referenced; any
/// other kind of decl results in an error at `loc`.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc,
                                                        ast::Decl *decl) {
  auto *variable = dyn_cast<ast::VariableDecl>(decl);
  if (!variable) {
    return emitError(loc, "invalid reference to `" +
                              decl->getName()->getName() + "`");
  }

  // The reference takes on the type of the referenced variable.
  ast::Type refType = variable->getType();
  return ast::DeclRefExpr::create(ctx, loc, decl, refType);
}

/// Define a new variable with the given type, name, and constraints, and
/// return a reference expression to it.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name,
                                 llvm::SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> variable =
      defineVariableDecl(name, loc, type, constraints);
  if (failed(variable))
    return failure();

  ast::VariableDecl *definedVar = *variable;
  return ast::DeclRefExpr::create(ctx, loc, definedVar, type);
}

/// Create an access of member `name` on `parentExpr`. The access is validated
/// first, which also provides the type of the resulting expression.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               llvm::SMRange loc) {
  FailureOr<ast::Type> accessType = validateMemberAccess(parentExpr, name, loc);
  if (failed(accessType))
    return failure();

  ast::Type resultType = *accessType;
  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, resultType);
}

/// Validate an access of member `name` on `parentExpr` and return the type of
/// the accessed member.
///
/// Supported accesses:
///  * On an operation: the "all results" member, which has type `ValueRange`.
///  * On a tuple: an in-bounds numeric index, or the label of a named element.
///
/// Anything else, including an empty `name`, emits an "invalid member access"
/// error at `loc`.
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name,
                                                  llvm::SMRange loc) {
  ast::Type parentType = parentExpr->getType();
  if (parentType.isa<ast::OperationType>()) {
    if (name == ast::AllResultsMemberAccessExpr::getMemberName())
      return valueRangeTy;
  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
    // An empty name can never denote a tuple member. It must be rejected
    // before the lookups below: indexing `name[0]` would be out of bounds,
    // and an empty name would compare equal to the empty label recorded for
    // unnamed tuple elements.
    if (!name.empty()) {
      // Handle indexed results. `getAsInteger` returns true on failure.
      unsigned index = 0;
      if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
          index < tupleType.size()) {
        return tupleType.getElementTypes()[index];
      }

      // Handle named results.
      auto elementNames = tupleType.getElementNames();
      const auto *it = llvm::find(elementNames, name);
      if (it != elementNames.end())
        return tupleType.getElementTypes()[it - elementNames.begin()];
    }
  }
  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}

/// Create an operation expression after validating its components: first the
/// operands, then that each attribute value is an `Attr` expression, and
/// finally the result types.
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
    llvm::SMRange loc, const ast::OpNameDecl *name,
    MutableArrayRef<ast::Expr *> operands,
    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
    MutableArrayRef<ast::Expr *> results) {
  Optional<StringRef> opName = name->getName();

  // Validate the operand list.
  if (failed(validateOperationOperands(loc, opName, operands)))
    return failure();

  // Every attribute value must be of attribute type.
  for (ast::NamedAttributeDecl *attrDecl : attributes) {
    ast::Type valueType = attrDecl->getValue()->getType();
    if (valueType.isa<ast::AttributeType>())
      continue;
    return emitError(
        attrDecl->getValue()->getLoc(),
        llvm::formatv("expected `Attr` expression, but got `{0}`", valueType));
  }

  // Validate the result type list.
  if (failed(validateOperationResults(loc, opName, results)))
    return failure();

  return ast::OperationExpr::create(ctx, loc, name, operands, results,
                                    attributes);
}

/// Validate the operands of an operation expression. Operands are checked
/// against the `Value` (single) and `ValueRange` (range) types.
LogicalResult
Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
                                  MutableArrayRef<ast::Expr *> operands) {
  ast::Type singleOperandTy = valueTy;
  ast::Type operandRangeTy = valueRangeTy;
  return validateOperationOperandsOrResults(loc, name, operands,
                                            singleOperandTy, operandRangeTy);
}

/// Validate the result types of an operation expression. Results are checked
/// against the `Type` (single) and `TypeRange` (range) types.
LogicalResult
Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
                                 MutableArrayRef<ast::Expr *> results) {
  ast::Type singleResultTy = typeTy;
  ast::Type resultRangeTy = typeRangeTy;
  return validateOperationOperandsOrResults(loc, name, results, singleResultTy,
                                            resultRangeTy);
}

/// Shared validation for the operand and result lists of an operation
/// expression. A list with exactly one entry is converted to `rangeTy`.
/// Otherwise each entry must already be of `singleTy` or `rangeTy`; when
/// validating operands (`singleTy` is the value type), operation expressions
/// are additionally accepted and converted with convertOpToValue.
LogicalResult Parser::validateOperationOperandsOrResults(
    llvm::SMRange loc, Optional<StringRef> name,
    MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
    ast::Type rangeTy) {
  // All operation types accept a single range parameter.
  if (values.size() == 1)
    return failed(convertExpressionTo(values[0], rangeTy)) ? failure()
                                                           : success();

  // Operation expressions are only convertible when validating operands. This
  // situation arises quite often with nested operation expressions:
  // `op<my_dialect.foo>(op<my_dialect.bar>)`
  const bool acceptsOpExprs = singleTy == valueTy;

  // Otherwise, take the groups as written and only check their types.
  for (ast::Expr *&entry : values) {
    ast::Type entryType = entry->getType();
    if (entryType == rangeTy || entryType == singleTy)
      continue;

    if (acceptsOpExprs && entryType.isa<ast::OperationType>()) {
      entry = convertOpToValue(entry);
      continue;
    }

    return emitError(
        entry->getLoc(),
        llvm::formatv(
            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
            singleTy, rangeTy, entryType));
  }
  return success();
}

/// Create a tuple expression from the given elements and element names.
/// Elements of constraint or tuple type are rejected.
FailureOr<ast::TupleExpr *>
Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements,
                        ArrayRef<StringRef> elementNames) {
  for (const ast::Expr *elementExpr : elements) {
    ast::Type elementType = elementExpr->getType();
    if (!elementType.isa<ast::ConstraintType, ast::TupleType>())
      continue;
    return emitError(
        elementExpr->getLoc(),
        llvm::formatv("unable to build a tuple with `{0}` element",
                      elementType));
  }
  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
}

//===----------------------------------------------------------------------===//
// Stmts

/// Create an erase statement. The root expression must be of operation type.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc,
                                                    ast::Expr *rootOp) {
  ast::Type rootOpType = rootOp->getType();
  const bool rootIsOperation = rootOpType.isa<ast::OperationType>();
  if (!rootIsOperation)
    return emitError(rootOp->getLoc(), "expected `Op` expression");

  return ast::EraseStmt::create(ctx, loc, rootOp);
}

/// Create a replace statement. The root must be of operation type, and each
/// replacement must be an operation, a `Value`, or a `ValueRange`. When more
/// than one replacement is given, operation replacements are converted with
/// convertOpToValue.
FailureOr<ast::ReplaceStmt *>
Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          MutableArrayRef<ast::Expr *> replValues) {
  // The root of the replacement has to be an operation.
  ast::Type rootOpType = rootOp->getType();
  if (!rootOpType.isa<ast::OperationType>()) {
    return emitError(
        rootOp->getLoc(),
        llvm::formatv("expected `Op` expression, but got `{0}`", rootOpType));
  }

  // Operation replacements are only converted to their value form when there
  // are multiple replacement values.
  const bool convertOps = replValues.size() > 1;
  for (ast::Expr *&replacement : replValues) {
    ast::Type replacementType = replacement->getType();

    if (!replacementType.isa<ast::OperationType>()) {
      // Non-operation replacements must be a Value or a ValueRange.
      if (replacementType != valueTy && replacementType != valueRangeTy) {
        return emitError(replacement->getLoc(),
                         llvm::formatv("expected `Op`, `Value` or `ValueRange` "
                                       "expression, but got `{0}`",
                                       replacementType));
      }
      continue;
    }

    if (convertOps)
      replacement = convertOpToValue(replacement);
  }

  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}

/// Create a rewrite statement from a root expression and a rewrite body. The
/// root expression must be of operation type.
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
                          ast::CompoundStmt *rewriteBody) {
  ast::Type rootOpType = rootOp->getType();
  if (rootOpType.isa<ast::OperationType>())
    return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);

  return emitError(
      rootOp->getLoc(),
      llvm::formatv("expected `Op` expression, but got `{0}`", rootOpType));
}

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// Entry point of the parser: construct a Parser over `sourceMgr` within
/// `ctx` and parse a module with it.
FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
                                                 llvm::SourceMgr &sourceMgr) {
  return Parser(ctx, sourceMgr).parseModule();
}
