//===- Parser.cpp ---------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/Parser/Parser.h" #include "Lexer.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Diagnostic.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Types.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" #include using namespace mlir; using namespace mlir::pdll; //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// namespace { class Parser { public: Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr) : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()), curToken(lexer.lexToken()), curDeclScope(nullptr), valueTy(ast::ValueType::get(ctx)), valueRangeTy(ast::ValueRangeType::get(ctx)), typeTy(ast::TypeType::get(ctx)), typeRangeTy(ast::TypeRangeType::get(ctx)) {} /// Try to parse a new module. Returns failure if the module could not be /// parsed. FailureOr parseModule(); private: /// The current context of the parser. It allows for the parser to know a bit /// about the construct it is nested within during parsing. This is used /// specifically to provide additional verification during parsing, e.g. to /// prevent using rewrites within a match context, matcher constraints within /// a rewrite section, etc. enum class ParserContext { /// The parser is in the global context. 
Global, /// The parser is currently within the matcher portion of a Pattern, which /// allows a terminal operation rewrite statement but no other rewrite /// transformations. PatternMatch, /// The parser is currently within a Rewrite, which disallows calls to /// constraints, requires operation expressions to have names, etc. Rewrite, }; //===--------------------------------------------------------------------===// // Parsing //===--------------------------------------------------------------------===// /// Push a new decl scope, parented to the current one, and make it the /// current scope. ast::DeclScope *pushDeclScope() { ast::DeclScope *newScope = new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); return (curDeclScope = newScope); } void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } /// Pop the current decl scope, making its parent scope current. void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } /// Parse the body of an AST module. LogicalResult parseModuleBody(SmallVector &decls); /// Try to convert the given expression to `type`. Returns failure and emits /// an error if a conversion is not viable. On failure, `noteAttachFn` is /// invoked to attach notes to the emitted error diagnostic. On success, /// `expr` is updated to the expression used to convert to `type`. LogicalResult convertExpressionTo( ast::Expr *&expr, ast::Type type, function_ref noteAttachFn = {}); /// Given an operation expression, convert it to a Value or ValueRange /// typed expression. ast::Expr *convertOpToValue(const ast::Expr *opExpr); //===--------------------------------------------------------------------===// // Directives LogicalResult parseDirective(SmallVector &decls); LogicalResult parseInclude(SmallVector &decls); //===--------------------------------------------------------------------===// // Decls /// This structure contains the set of pattern metadata that may be parsed. 
struct ParsedPatternMetadata { Optional benefit; bool hasBoundedRecursion = false; }; FailureOr parseTopLevelDecl(); FailureOr parseNamedAttributeDecl(); FailureOr parsePatternDecl(); LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); /// Check to see if a decl has already been defined with the given name; if /// one has, emit an error and return failure. Returns success otherwise. LogicalResult checkDefineNamedDecl(const ast::Name &name); /// Try to define a variable decl with the given components, returns the /// variable on success. FailureOr defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, ast::Expr *initExpr, ArrayRef constraints); FailureOr defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, ArrayRef constraints); /// Parse the constraint reference list for a variable decl. LogicalResult parseVariableDeclConstraintList( SmallVectorImpl &constraints); /// Parse the expression used within a type constraint, e.g. the /// `<type-expr>` in `Attr<type-expr>`. FailureOr parseTypeConstraintExpr(); /// Try to parse a single reference to a constraint. `typeConstraint` is the /// location of a previously parsed type constraint for the entity that will /// be constrained by the parsed constraint. `existingConstraints` are any /// existing constraints that have already been parsed for the same entity /// that will be constrained by this constraint. FailureOr parseConstraint(Optional &typeConstraint, ArrayRef existingConstraints); //===--------------------------------------------------------------------===// // Exprs FailureOr parseExpr(); /// Identifier expressions. 
FailureOr parseAttributeExpr(); FailureOr parseDeclRefExpr(StringRef name, llvm::SMRange loc); FailureOr parseIdentifierExpr(); FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); FailureOr parseOperationExpr(); FailureOr parseTupleExpr(); FailureOr parseTypeExpr(); FailureOr parseUnderscoreExpr(); //===--------------------------------------------------------------------===// // Stmts FailureOr parseStmt(bool expectTerminalSemicolon = true); FailureOr parseCompoundStmt(); FailureOr parseEraseStmt(); FailureOr parseLetStmt(); FailureOr parseReplaceStmt(); FailureOr parseRewriteStmt(); //===--------------------------------------------------------------------===// // Creation+Analysis //===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===// // Decls /// Try to create a pattern decl with the given components, returning the /// Pattern on success. FailureOr createPatternDecl(llvm::SMRange loc, const ast::Name *name, const ParsedPatternMetadata &metadata, ast::CompoundStmt *body); /// Try to create a variable decl with the given components, returning the /// Variable on success. FailureOr createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer, ArrayRef constraints); /// Validate the constraints used to constrain a variable decl. /// `inferredType` is the type of the variable inferred by the constraints /// within the list, and is updated to the most refined type as determined by /// the constraints. Returns success if the constraint list is valid, failure /// otherwise. LogicalResult validateVariableConstraints(ArrayRef constraints, ast::Type &inferredType); /// Validate a single reference to a constraint. 
`inferredType` contains the /// currently inferred variable type and is refined within the type defined /// by the constraint. Returns success if the constraint is valid, failure /// otherwise. LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, ast::Type &inferredType); LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); //===--------------------------------------------------------------------===// // Exprs FailureOr createDeclRefExpr(llvm::SMRange loc, ast::Decl *decl); FailureOr createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc, ArrayRef constraints); FailureOr createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, llvm::SMRange loc); /// Validate the member access `name` into the given parent expression. On /// success, this also returns the type of the member accessed. FailureOr validateMemberAccess(ast::Expr *parentExpr, StringRef name, llvm::SMRange loc); FailureOr createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name, MutableArrayRef operands, MutableArrayRef attributes, MutableArrayRef results); LogicalResult validateOperationOperands(llvm::SMRange loc, Optional name, MutableArrayRef operands); LogicalResult validateOperationResults(llvm::SMRange loc, Optional name, MutableArrayRef results); LogicalResult validateOperationOperandsOrResults(llvm::SMRange loc, Optional name, MutableArrayRef values, ast::Type singleTy, ast::Type rangeTy); FailureOr createTupleExpr(llvm::SMRange loc, ArrayRef elements, ArrayRef elementNames); //===--------------------------------------------------------------------===// // Stmts FailureOr createEraseStmt(llvm::SMRange loc, ast::Expr *rootOp); FailureOr createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp, MutableArrayRef replValues); FailureOr createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, ast::CompoundStmt *rewriteBody); 
//===--------------------------------------------------------------------===// // Lexer Utilities //===--------------------------------------------------------------------===// /// If the current token has the specified kind, consume it and return true. /// If not, return false. bool consumeIf(Token::Kind kind) { if (curToken.isNot(kind)) return false; consumeToken(kind); return true; } /// Advance the current lexer onto the next token. void consumeToken() { assert(curToken.isNot(Token::eof, Token::error) && "shouldn't advance past EOF or errors"); curToken = lexer.lexToken(); } /// Advance the current lexer onto the next token, asserting what the expected /// current token is. This is preferred to the above method because it leads /// to more self-documenting code with better checking. void consumeToken(Token::Kind kind) { assert(curToken.is(kind) && "consumed an unexpected token"); consumeToken(); } /// Reset the lexer to the location at the given position. void resetToken(llvm::SMRange tokLoc) { lexer.resetPointer(tokLoc.Start.getPointer()); curToken = lexer.lexToken(); } /// Consume the specified token if present and return success. On failure, /// output a diagnostic and return failure. LogicalResult parseToken(Token::Kind kind, const Twine &msg) { if (curToken.getKind() != kind) return emitError(curToken.getLoc(), msg); consumeToken(); return success(); } LogicalResult emitError(llvm::SMRange loc, const Twine &msg) { lexer.emitError(loc, msg); return failure(); } LogicalResult emitError(const Twine &msg) { return emitError(curToken.getLoc(), msg); } LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg, llvm::SMRange noteLoc, const Twine ¬e) { lexer.emitErrorAndNote(loc, msg, noteLoc, note); return failure(); } //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// /// The owning AST context. 
ast::Context &ctx; /// The lexer of this parser. Lexer lexer; /// The current token within the lexer. Token curToken; /// The most recently defined decl scope. ast::DeclScope *curDeclScope; llvm::SpecificBumpPtrAllocator scopeAllocator; /// The current context of the parser. ParserContext parserContext = ParserContext::Global; /// Cached types to simplify verification and expression creation. ast::Type valueTy, valueRangeTy; ast::Type typeTy, typeRangeTy; }; } // namespace FailureOr Parser::parseModule() { llvm::SMLoc moduleLoc = curToken.getStartLoc(); pushDeclScope(); // Parse the top-level decls of the module. SmallVector decls; if (failed(parseModuleBody(decls))) return popDeclScope(), failure(); popDeclScope(); return ast::Module::create(ctx, moduleLoc, decls); } LogicalResult Parser::parseModuleBody(SmallVector &decls) { while (curToken.isNot(Token::eof)) { if (curToken.is(Token::directive)) { if (failed(parseDirective(decls))) return failure(); continue; } FailureOr decl = parseTopLevelDecl(); if (failed(decl)) return failure(); decls.push_back(*decl); } return success(); } ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr, valueRangeTy); } LogicalResult Parser::convertExpressionTo( ast::Expr *&expr, ast::Type type, function_ref noteAttachFn) { ast::Type exprType = expr->getType(); if (exprType == type) return success(); auto emitConvertError = [&]() -> ast::InFlightDiagnostic { ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( expr->getLoc(), llvm::formatv("unable to convert expression of type " "`{0}` to the expected type of " "`{1}`", exprType, type)); if (noteAttachFn) noteAttachFn(*diag); return diag; }; if (auto exprOpType = exprType.dyn_cast()) { // Two operation types are compatible if they have the same name, or if the // expected type is more general. 
if (auto opType = type.dyn_cast()) { if (opType.getName()) return emitConvertError(); return success(); } // An operation can always convert to a ValueRange. if (type == valueRangeTy) { expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, valueRangeTy); return success(); } // Allow conversion to a single value by constraining the result range. if (type == valueTy) { expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, valueTy); return success(); } return emitConvertError(); } // FIXME: Decide how to allow/support converting a single result to multiple, // and multiple to a single result. For now, we just allow Single->Range, // but this isn't something really supported in the PDL dialect. We should // figure out some way to support both. if ((exprType == valueTy || exprType == valueRangeTy) && (type == valueTy || type == valueRangeTy)) return success(); if ((exprType == typeTy || exprType == typeRangeTy) && (type == typeTy || type == typeRangeTy)) return success(); // Handle tuple types. if (auto exprTupleType = exprType.dyn_cast()) { auto tupleType = type.dyn_cast(); if (!tupleType || tupleType.size() != exprTupleType.size()) return emitConvertError(); // Build a new tuple expression using each of the elements of the current // tuple. 
SmallVector newExprs; for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { newExprs.push_back(ast::MemberAccessExpr::create( ctx, expr->getLoc(), expr, llvm::to_string(i), exprTupleType.getElementTypes()[i])); auto diagFn = [&](ast::Diagnostic &diag) { diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", i, exprTupleType)); if (noteAttachFn) noteAttachFn(diag); }; if (failed(convertExpressionTo(newExprs.back(), tupleType.getElementTypes()[i], diagFn))) return failure(); } expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs, tupleType.getElementNames()); return success(); } return emitConvertError(); } //===----------------------------------------------------------------------===// // Directives LogicalResult Parser::parseDirective(SmallVector &decls) { StringRef directive = curToken.getSpelling(); if (directive == "#include") return parseInclude(decls); return emitError("unknown directive `" + directive + "`"); } LogicalResult Parser::parseInclude(SmallVector &decls) { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::directive); // Parse the file being included. if (!curToken.isString()) return emitError(loc, "expected string file name after `include` directive"); llvm::SMRange fileLoc = curToken.getLoc(); std::string filenameStr = curToken.getStringValue(); StringRef filename = filenameStr; consumeToken(); // Check the type of include. If ending with `.pdll`, this is another pdl file // to be parsed along with the current module. if (filename.endswith(".pdll")) { if (failed(lexer.pushInclude(filename))) return emitError(fileLoc, "unable to open include file `" + filename + "`"); // If we added the include successfully, parse it into the current module. // Make sure to save the current token so that we can restore it when we // finish parsing the nested file. 
Token oldToken = curToken; curToken = lexer.lexToken(); LogicalResult result = parseModuleBody(decls); curToken = oldToken; return result; } return emitError(fileLoc, "expected include filename to end with `.pdll`"); } //===----------------------------------------------------------------------===// // Decls FailureOr Parser::parseTopLevelDecl() { FailureOr decl; switch (curToken.getKind()) { case Token::kw_Pattern: decl = parsePatternDecl(); break; default: return emitError("expected top-level declaration, such as a `Pattern`"); } if (failed(decl)) return failure(); // If the decl has a name, add it to the current scope. if (const ast::Name *name = (*decl)->getName()) { if (failed(checkDefineNamedDecl(*name))) return failure(); curDeclScope->add(*decl); } return decl; } FailureOr Parser::parseNamedAttributeDecl() { std::string attrNameStr; if (curToken.isString()) attrNameStr = curToken.getStringValue(); else if (curToken.is(Token::identifier) || curToken.isKeyword()) attrNameStr = curToken.getSpelling().str(); else return emitError("expected identifier or string attribute name"); const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc()); consumeToken(); // Check for a value of the attribute. ast::Expr *attrValue = nullptr; if (consumeIf(Token::equal)) { FailureOr attrExpr = parseExpr(); if (failed(attrExpr)) return failure(); attrValue = *attrExpr; } else { // If there isn't a concrete value, create an expression representing a // UnitAttr. attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit"); } return ast::NamedAttributeDecl::create(ctx, name, attrValue); } FailureOr Parser::parsePatternDecl() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_Pattern); llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch); // Check for an optional identifier for the pattern name. 
const ast::Name *name = nullptr; if (curToken.is(Token::identifier)) { name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); consumeToken(Token::identifier); } // Parse any pattern metadata. ParsedPatternMetadata metadata; if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata))) return failure(); // Parse the pattern body. ast::CompoundStmt *body; if (curToken.isNot(Token::l_brace)) return emitError("expected `{` to start pattern body"); FailureOr bodyResult = parseCompoundStmt(); if (failed(bodyResult)) return failure(); body = *bodyResult; // Verify the body of the pattern. auto bodyIt = body->begin(), bodyE = body->end(); for (; bodyIt != bodyE; ++bodyIt) { // Break when we've found the rewrite statement. if (isa(*bodyIt)) break; } if (bodyIt == bodyE) { return emitError(loc, "expected Pattern body to terminate with an operation " "rewrite statement, such as `erase`"); } if (std::next(bodyIt) != bodyE) { return emitError((*std::next(bodyIt))->getLoc(), "Pattern body was terminated by an operation " "rewrite statement, but found trailing statements"); } return createPatternDecl(loc, name, metadata, body); } LogicalResult Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { Optional benefitLoc; Optional hasBoundedRecursionLoc; do { if (curToken.isNot(Token::identifier)) return emitError("expected pattern metadata identifier"); StringRef metadataStr = curToken.getSpelling(); llvm::SMRange metadataLoc = curToken.getLoc(); consumeToken(Token::identifier); // Parse the benefit metadata: benefit(<integer-value>) if (metadataStr == "benefit") { if (benefitLoc) { return emitErrorAndNote(metadataLoc, "pattern benefit has already been specified", *benefitLoc, "see previous definition here"); } if (failed(parseToken(Token::l_paren, "expected `(` before pattern benefit"))) return failure(); uint16_t benefitValue = 0; if (curToken.isNot(Token::integer)) return emitError("expected integral pattern benefit"); if 
(curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue)) return emitError( "expected pattern benefit to fit within a 16-bit integer"); consumeToken(Token::integer); metadata.benefit = benefitValue; benefitLoc = metadataLoc; if (failed( parseToken(Token::r_paren, "expected `)` after pattern benefit"))) return failure(); continue; } // Parse the bounded recursion metadata: recursion if (metadataStr == "recursion") { if (hasBoundedRecursionLoc) { return emitErrorAndNote( metadataLoc, "pattern recursion metadata has already been specified", *hasBoundedRecursionLoc, "see previous definition here"); } metadata.hasBoundedRecursion = true; hasBoundedRecursionLoc = metadataLoc; continue; } return emitError(metadataLoc, "unknown pattern metadata"); } while (consumeIf(Token::comma)); return success(); } FailureOr Parser::parseTypeConstraintExpr() { consumeToken(Token::less); FailureOr typeExpr = parseExpr(); if (failed(typeExpr) || failed(parseToken(Token::greater, "expected `>` after variable type constraint"))) return failure(); return typeExpr; } LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { assert(curDeclScope && "defining decl outside of a decl scope"); if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) { return emitErrorAndNote( name.getLoc(), "`" + name.getName() + "` has already been defined", lastDecl->getName()->getLoc(), "see previous definition here"); } return success(); } FailureOr Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, ast::Expr *initExpr, ArrayRef constraints) { assert(curDeclScope && "defining variable outside of decl scope"); const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); // If the name of the variable indicates a special variable, we don't add it // to the scope. This variable is local to the definition point. 
if (name.empty() || name == "_") { return ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); } if (failed(checkDefineNamedDecl(nameDecl))) return failure(); auto *varDecl = ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints); curDeclScope->add(varDecl); return varDecl; } FailureOr Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, ArrayRef constraints) { return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, constraints); } LogicalResult Parser::parseVariableDeclConstraintList( SmallVectorImpl &constraints) { Optional typeConstraint; auto parseSingleConstraint = [&] { FailureOr constraint = parseConstraint(typeConstraint, constraints); if (failed(constraint)) return failure(); constraints.push_back(*constraint); return success(); }; // Check to see if this is a single constraint, or a list. if (!consumeIf(Token::l_square)) return parseSingleConstraint(); do { if (failed(parseSingleConstraint())) return failure(); } while (consumeIf(Token::comma)); return parseToken(Token::r_square, "expected `]` after constraint list"); } FailureOr Parser::parseConstraint(Optional &typeConstraint, ArrayRef existingConstraints) { auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { if (typeConstraint) return emitErrorAndNote( curToken.getLoc(), "the type of this variable has already been constrained", *typeConstraint, "see previous constraint location here"); FailureOr constraintExpr = parseTypeConstraintExpr(); if (failed(constraintExpr)) return failure(); typeExpr = *constraintExpr; typeConstraint = typeExpr->getLoc(); return success(); }; llvm::SMRange loc = curToken.getLoc(); switch (curToken.getKind()) { case Token::kw_Attr: { consumeToken(Token::kw_Attr); // Check for a type constraint. 
ast::Expr *typeExpr = nullptr; if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) return failure(); return ast::ConstraintRef( ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); } case Token::kw_Op: { consumeToken(Token::kw_Op); // Parse an optional operation name. If the name isn't provided, this refers // to "any" operation. FailureOr opName = parseWrappedOperationName(/*allowEmptyName=*/true); if (failed(opName)) return failure(); return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName), loc); } case Token::kw_Type: consumeToken(Token::kw_Type); return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); case Token::kw_TypeRange: consumeToken(Token::kw_TypeRange); return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), loc); case Token::kw_Value: { consumeToken(Token::kw_Value); // Check for a type constraint. ast::Expr *typeExpr = nullptr; if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) return failure(); return ast::ConstraintRef( ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); } case Token::kw_ValueRange: { consumeToken(Token::kw_ValueRange); // Check for a type constraint. ast::Expr *typeExpr = nullptr; if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) return failure(); return ast::ConstraintRef( ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); } case Token::identifier: { StringRef constraintName = curToken.getSpelling(); consumeToken(Token::identifier); // Lookup the referenced constraint. ast::Decl *cstDecl = curDeclScope->lookup(constraintName); if (!cstDecl) { return emitError(loc, "unknown reference to constraint `" + constraintName + "`"); } // Handle a reference to a proper constraint. 
if (auto *cst = dyn_cast(cstDecl)) return ast::ConstraintRef(cst, loc); return emitErrorAndNote( loc, "invalid reference to non-constraint", cstDecl->getLoc(), "see the definition of `" + constraintName + "` here"); } default: break; } return emitError(loc, "expected identifier constraint"); } //===----------------------------------------------------------------------===// // Exprs FailureOr Parser::parseExpr() { if (curToken.is(Token::underscore)) return parseUnderscoreExpr(); // Parse the LHS expression. FailureOr lhsExpr; switch (curToken.getKind()) { case Token::kw_attr: lhsExpr = parseAttributeExpr(); break; case Token::identifier: lhsExpr = parseIdentifierExpr(); break; case Token::kw_op: lhsExpr = parseOperationExpr(); break; case Token::kw_type: lhsExpr = parseTypeExpr(); break; case Token::l_paren: lhsExpr = parseTupleExpr(); break; default: return emitError("expected expression"); } if (failed(lhsExpr)) return failure(); // Check for an operator expression. while (true) { switch (curToken.getKind()) { case Token::dot: lhsExpr = parseMemberAccessExpr(*lhsExpr); break; default: return lhsExpr; } if (failed(lhsExpr)) return failure(); } } FailureOr Parser::parseAttributeExpr() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_attr); // If we aren't followed by a `<`, the `attr` keyword is treated as a normal // identifier. 
if (!consumeIf(Token::less)) { resetToken(loc); return parseIdentifierExpr(); } if (!curToken.isString()) return emitError("expected string literal containing MLIR attribute"); std::string attrExpr = curToken.getStringValue(); consumeToken(); if (failed( parseToken(Token::greater, "expected `>` after attribute literal"))) return failure(); return ast::AttributeExpr::create(ctx, loc, attrExpr); } FailureOr Parser::parseDeclRefExpr(StringRef name, llvm::SMRange loc) { ast::Decl *decl = curDeclScope->lookup(name); if (!decl) return emitError(loc, "undefined reference to `" + name + "`"); return createDeclRefExpr(loc, decl); } FailureOr Parser::parseIdentifierExpr() { StringRef name = curToken.getSpelling(); llvm::SMRange nameLoc = curToken.getLoc(); consumeToken(); // Check to see if this is a decl ref expression that defines a variable // inline. if (consumeIf(Token::colon)) { SmallVector constraints; if (failed(parseVariableDeclConstraintList(constraints))) return failure(); ast::Type type; if (failed(validateVariableConstraints(constraints, type))) return failure(); return createInlineVariableExpr(type, name, nameLoc, constraints); } return parseDeclRefExpr(name, nameLoc); } FailureOr Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::dot); // Parse the member name. Token memberNameTok = curToken; if (memberNameTok.isNot(Token::identifier, Token::integer) && !memberNameTok.isKeyword()) return emitError(loc, "expected identifier or numeric member name"); StringRef memberName = memberNameTok.getSpelling(); consumeToken(); return createMemberAccessExpr(parentExpr, memberName, loc); } FailureOr Parser::parseOperationName(bool allowEmptyName) { llvm::SMRange loc = curToken.getLoc(); // Handle the case of no operation name. 
if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) { if (allowEmptyName) return ast::OpNameDecl::create(ctx, llvm::SMRange()); return emitError("expected dialect namespace"); } StringRef name = curToken.getSpelling(); consumeToken(); // Otherwise, this is a literal operation name. if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) return failure(); if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) return emitError("expected operation name after dialect namespace"); name = StringRef(name.data(), name.size() + 1); do { name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); loc.End = curToken.getEndLoc(); consumeToken(); } while (curToken.isAny(Token::identifier, Token::dot) || curToken.isKeyword()); return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc)); } FailureOr Parser::parseWrappedOperationName(bool allowEmptyName) { if (!consumeIf(Token::less)) return ast::OpNameDecl::create(ctx, llvm::SMRange()); FailureOr opNameDecl = parseOperationName(allowEmptyName); if (failed(opNameDecl)) return failure(); if (failed(parseToken(Token::greater, "expected `>` after operation name"))) return failure(); return opNameDecl; } FailureOr Parser::parseOperationExpr() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_op); // If it isn't followed by a `<`, the `op` keyword is treated as a normal // identifier. if (curToken.isNot(Token::less)) { resetToken(loc); return parseIdentifierExpr(); } // Parse the operation name. The name may be elided, in which case the // operation refers to "any" operation (i.e. a difference between `MyOp` and // `Operation*`). Operation names within a rewrite context must be named. bool allowEmptyName = parserContext != ParserContext::Rewrite; FailureOr opNameDecl = parseWrappedOperationName(allowEmptyName); if (failed(opNameDecl)) return failure(); // Check for the optional list of operands. 
SmallVector operands; if (consumeIf(Token::l_paren)) { do { FailureOr operand = parseExpr(); if (failed(operand)) return failure(); operands.push_back(*operand); } while (consumeIf(Token::comma)); if (failed(parseToken(Token::r_paren, "expected `)` after operation operand list"))) return failure(); } // Check for the optional list of attributes. SmallVector attributes; if (consumeIf(Token::l_brace)) { do { FailureOr decl = parseNamedAttributeDecl(); if (failed(decl)) return failure(); attributes.emplace_back(*decl); } while (consumeIf(Token::comma)); if (failed(parseToken(Token::r_brace, "expected `}` after operation attribute list"))) return failure(); } // Check for the optional list of result types. SmallVector resultTypes; if (consumeIf(Token::arrow)) { if (failed(parseToken(Token::l_paren, "expected `(` before operation result type list"))) return failure(); do { FailureOr resultTypeExpr = parseExpr(); if (failed(resultTypeExpr)) return failure(); resultTypes.push_back(*resultTypeExpr); } while (consumeIf(Token::comma)); if (failed(parseToken(Token::r_paren, "expected `)` after operation result type list"))) return failure(); } return createOperationExpr(loc, *opNameDecl, operands, attributes, resultTypes); } FailureOr Parser::parseTupleExpr() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::l_paren); DenseMap usedNames; SmallVector elementNames; SmallVector elements; if (curToken.isNot(Token::r_paren)) { do { // Check for the optional element name assignment before the value. StringRef elementName; if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { Token elementNameTok = curToken; consumeToken(); // The element name is only present if followed by an `=`. if (consumeIf(Token::equal)) { elementName = elementNameTok.getSpelling(); // Check to see if this name is already used. 
auto elementNameIt = usedNames.try_emplace(elementName, elementNameTok.getLoc()); if (!elementNameIt.second) { return emitErrorAndNote( elementNameTok.getLoc(), llvm::formatv("duplicate tuple element label `{0}`", elementName), elementNameIt.first->getSecond(), "see previous label use here"); } } else { // Otherwise, we treat this as part of an expression so reset the // lexer. resetToken(elementNameTok.getLoc()); } } elementNames.push_back(elementName); // Parse the tuple element value. FailureOr element = parseExpr(); if (failed(element)) return failure(); elements.push_back(*element); } while (consumeIf(Token::comma)); } loc.End = curToken.getEndLoc(); if (failed( parseToken(Token::r_paren, "expected `)` after tuple element list"))) return failure(); return createTupleExpr(loc, elements, elementNames); } FailureOr Parser::parseTypeExpr() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_type); // If we aren't followed by a `<`, the `type` keyword is treated as a normal // identifier. if (!consumeIf(Token::less)) { resetToken(loc); return parseIdentifierExpr(); } if (!curToken.isString()) return emitError("expected string literal containing MLIR type"); std::string attrExpr = curToken.getStringValue(); consumeToken(); if (failed(parseToken(Token::greater, "expected `>` after type literal"))) return failure(); return ast::TypeExpr::create(ctx, loc, attrExpr); } FailureOr Parser::parseUnderscoreExpr() { StringRef name = curToken.getSpelling(); llvm::SMRange nameLoc = curToken.getLoc(); consumeToken(Token::underscore); // Underscore expressions require a constraint list. if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) return failure(); // Parse the constraints for the expression. 
SmallVector constraints; if (failed(parseVariableDeclConstraintList(constraints))) return failure(); ast::Type type; if (failed(validateVariableConstraints(constraints, type))) return failure(); return createInlineVariableExpr(type, name, nameLoc, constraints); } //===----------------------------------------------------------------------===// // Stmts FailureOr Parser::parseStmt(bool expectTerminalSemicolon) { FailureOr stmt; switch (curToken.getKind()) { case Token::kw_erase: stmt = parseEraseStmt(); break; case Token::kw_let: stmt = parseLetStmt(); break; case Token::kw_replace: stmt = parseReplaceStmt(); break; case Token::kw_rewrite: stmt = parseRewriteStmt(); break; default: stmt = parseExpr(); break; } if (failed(stmt) || (expectTerminalSemicolon && failed(parseToken(Token::semicolon, "expected `;` after statement")))) return failure(); return stmt; } FailureOr Parser::parseCompoundStmt() { llvm::SMLoc startLoc = curToken.getStartLoc(); consumeToken(Token::l_brace); // Push a new block scope and parse any nested statements. pushDeclScope(); SmallVector statements; while (curToken.isNot(Token::r_brace)) { FailureOr statement = parseStmt(); if (failed(statement)) return popDeclScope(), failure(); statements.push_back(*statement); } popDeclScope(); // Consume the end brace. llvm::SMRange location(startLoc, curToken.getEndLoc()); consumeToken(Token::r_brace); return ast::CompoundStmt::create(ctx, location, statements); } FailureOr Parser::parseEraseStmt() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_erase); // Parse the root operation expression. FailureOr rootOp = parseExpr(); if (failed(rootOp)) return failure(); return createEraseStmt(loc, *rootOp); } FailureOr Parser::parseLetStmt() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_let); // Parse the name of the new variable. llvm::SMRange varLoc = curToken.getLoc(); if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { // `_` is a reserved variable name. 
if (curToken.is(Token::underscore)) { return emitError(varLoc, "`_` may only be used to define \"inline\" variables"); } return emitError(varLoc, "expected identifier after `let` to name a new variable"); } StringRef varName = curToken.getSpelling(); consumeToken(); // Parse the optional set of constraints. SmallVector constraints; if (consumeIf(Token::colon) && failed(parseVariableDeclConstraintList(constraints))) return failure(); // Parse the optional initializer expression. ast::Expr *initializer = nullptr; if (consumeIf(Token::equal)) { FailureOr initOrFailure = parseExpr(); if (failed(initOrFailure)) return failure(); initializer = *initOrFailure; // Check that the constraints are compatible with having an initializer, // e.g. type constraints cannot be used with initializers. for (ast::ConstraintRef constraint : constraints) { LogicalResult result = TypeSwitch(constraint.constraint) .Case([&](const auto *cst) { if (auto *typeConstraintExpr = cst->getTypeExpr()) { return this->emitError( constraint.referenceLoc, "type constraints are not permitted on variables with " "initializers"); } return success(); }) .Default(success()); if (failed(result)) return failure(); } } FailureOr varDecl = createVariableDecl(varName, varLoc, initializer, constraints); if (failed(varDecl)) return failure(); return ast::LetStmt::create(ctx, loc, *varDecl); } FailureOr Parser::parseReplaceStmt() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_replace); // Parse the root operation expression. FailureOr rootOp = parseExpr(); if (failed(rootOp)) return failure(); if (failed( parseToken(Token::kw_with, "expected `with` after root operation"))) return failure(); // The replacement portion of this statement is within a rewrite context. llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); // Parse the replacement values. 
SmallVector replValues; if (consumeIf(Token::l_paren)) { if (consumeIf(Token::r_paren)) { return emitError( loc, "expected at least one replacement value, consider using " "`erase` if no replacement values are desired"); } do { FailureOr replExpr = parseExpr(); if (failed(replExpr)) return failure(); replValues.emplace_back(*replExpr); } while (consumeIf(Token::comma)); if (failed(parseToken(Token::r_paren, "expected `)` after replacement values"))) return failure(); } else { FailureOr replExpr = parseExpr(); if (failed(replExpr)) return failure(); replValues.emplace_back(*replExpr); } return createReplaceStmt(loc, *rootOp, replValues); } FailureOr Parser::parseRewriteStmt() { llvm::SMRange loc = curToken.getLoc(); consumeToken(Token::kw_rewrite); // Parse the root operation. FailureOr rootOp = parseExpr(); if (failed(rootOp)) return failure(); if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) return failure(); if (curToken.isNot(Token::l_brace)) return emitError("expected `{` to start rewrite body"); // The rewrite body of this statement is within a rewrite context. 
llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); FailureOr rewriteBody = parseCompoundStmt(); if (failed(rewriteBody)) return failure(); return createRewriteStmt(loc, *rootOp, *rewriteBody); } //===----------------------------------------------------------------------===// // Creation+Analysis //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Decls FailureOr Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name, const ParsedPatternMetadata &metadata, ast::CompoundStmt *body) { return ast::PatternDecl::create(ctx, loc, name, metadata.benefit, metadata.hasBoundedRecursion, body); } FailureOr Parser::createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer, ArrayRef constraints) { // The type of the variable, which is expected to be inferred by either a // constraint or an initializer expression. ast::Type type; if (failed(validateVariableConstraints(constraints, type))) return failure(); if (initializer) { // Update the variable type based on the initializer, or try to convert the // initializer to the existing type. if (!type) type = initializer->getType(); else if (ast::Type mergedType = type.refineWith(initializer->getType())) type = mergedType; else if (failed(convertExpressionTo(initializer, type))) return failure(); // Otherwise, if there is no initializer check that the type has already // been resolved from the constraint list. } else if (!type) { return emitErrorAndNote( loc, "unable to infer type for variable `" + name + "`", loc, "the type of a variable must be inferable from the constraint " "list or the initializer"); } // Try to define a variable with the given name. 
FailureOr varDecl = defineVariableDecl(name, loc, type, initializer, constraints); if (failed(varDecl)) return failure(); return *varDecl; } LogicalResult Parser::validateVariableConstraints(ArrayRef constraints, ast::Type &inferredType) { for (const ast::ConstraintRef &ref : constraints) if (failed(validateVariableConstraint(ref, inferredType))) return failure(); return success(); } LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, ast::Type &inferredType) { ast::Type constraintType; if (const auto *cst = dyn_cast(ref.constraint)) { if (const ast::Expr *typeExpr = cst->getTypeExpr()) { if (failed(validateTypeConstraintExpr(typeExpr))) return failure(); } constraintType = ast::AttributeType::get(ctx); } else if (const auto *cst = dyn_cast(ref.constraint)) { constraintType = ast::OperationType::get(ctx, cst->getName()); } else if (isa(ref.constraint)) { constraintType = typeTy; } else if (isa(ref.constraint)) { constraintType = typeRangeTy; } else if (const auto *cst = dyn_cast(ref.constraint)) { if (const ast::Expr *typeExpr = cst->getTypeExpr()) { if (failed(validateTypeConstraintExpr(typeExpr))) return failure(); } constraintType = valueTy; } else if (const auto *cst = dyn_cast(ref.constraint)) { if (const ast::Expr *typeExpr = cst->getTypeExpr()) { if (failed(validateTypeRangeConstraintExpr(typeExpr))) return failure(); } constraintType = valueRangeTy; } else { llvm_unreachable("unknown constraint type"); } // Check that the constraint type is compatible with the current inferred // type. 
if (!inferredType) { inferredType = constraintType; } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { inferredType = mergedTy; } else { return emitError(ref.referenceLoc, llvm::formatv("constraint type `{0}` is incompatible " "with the previously inferred type `{1}`", constraintType, inferredType)); } return success(); } LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { ast::Type typeExprType = typeExpr->getType(); if (typeExprType != typeTy) { return emitError(typeExpr->getLoc(), "expected expression of `Type` in type constraint"); } return success(); } LogicalResult Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { ast::Type typeExprType = typeExpr->getType(); if (typeExprType != typeRangeTy) { return emitError(typeExpr->getLoc(), "expected expression of `TypeRange` in type constraint"); } return success(); } //===----------------------------------------------------------------------===// // Exprs FailureOr Parser::createDeclRefExpr(llvm::SMRange loc, ast::Decl *decl) { // Check the type of decl being referenced. ast::Type declType; if (auto *varDecl = dyn_cast(decl)) declType = varDecl->getType(); else return emitError(loc, "invalid reference to `" + decl->getName()->getName() + "`"); return ast::DeclRefExpr::create(ctx, loc, decl, declType); } FailureOr Parser::createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc, ArrayRef constraints) { FailureOr decl = defineVariableDecl(name, loc, type, constraints); if (failed(decl)) return failure(); return ast::DeclRefExpr::create(ctx, loc, *decl, type); } FailureOr Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, llvm::SMRange loc) { // Validate the member name for the given parent expression. 
FailureOr memberType = validateMemberAccess(parentExpr, name, loc); if (failed(memberType)) return failure(); return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); } FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, StringRef name, llvm::SMRange loc) { ast::Type parentType = parentExpr->getType(); if (parentType.isa()) { if (name == ast::AllResultsMemberAccessExpr::getMemberName()) return valueRangeTy; } else if (auto tupleType = parentType.dyn_cast()) { // Handle indexed results. unsigned index = 0; if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && index < tupleType.size()) { return tupleType.getElementTypes()[index]; } // Handle named results. auto elementNames = tupleType.getElementNames(); const auto *it = llvm::find(elementNames, name); if (it != elementNames.end()) return tupleType.getElementTypes()[it - elementNames.begin()]; } return emitError( loc, llvm::formatv("invalid member access `{0}` on expression of type `{1}`", name, parentType)); } FailureOr Parser::createOperationExpr( llvm::SMRange loc, const ast::OpNameDecl *name, MutableArrayRef operands, MutableArrayRef attributes, MutableArrayRef results) { Optional opNameRef = name->getName(); // Verify the inputs operands. if (failed(validateOperationOperands(loc, opNameRef, operands))) return failure(); // Verify the attribute list. for (ast::NamedAttributeDecl *attr : attributes) { // Check for an attribute type, or a type awaiting resolution. ast::Type attrType = attr->getValue()->getType(); if (!attrType.isa()) { return emitError( attr->getValue()->getLoc(), llvm::formatv("expected `Attr` expression, but got `{0}`", attrType)); } } // Verify the result types. 
if (failed(validateOperationResults(loc, opNameRef, results))) return failure(); return ast::OperationExpr::create(ctx, loc, name, operands, results, attributes); } LogicalResult Parser::validateOperationOperands(llvm::SMRange loc, Optional name, MutableArrayRef operands) { return validateOperationOperandsOrResults(loc, name, operands, valueTy, valueRangeTy); } LogicalResult Parser::validateOperationResults(llvm::SMRange loc, Optional name, MutableArrayRef results) { return validateOperationOperandsOrResults(loc, name, results, typeTy, typeRangeTy); } LogicalResult Parser::validateOperationOperandsOrResults( llvm::SMRange loc, Optional name, MutableArrayRef values, ast::Type singleTy, ast::Type rangeTy) { // All operation types accept a single range parameter. if (values.size() == 1) { if (failed(convertExpressionTo(values[0], rangeTy))) return failure(); return success(); } // Otherwise, accept the value groups as they have been defined and just // ensure they are one of the expected types. for (ast::Expr *&valueExpr : values) { ast::Type valueExprType = valueExpr->getType(); // Check if this is one of the expected types. if (valueExprType == rangeTy || valueExprType == singleTy) continue; // If the operand is an Operation, allow converting to a Value or // ValueRange. 
This situations arises quite often with nested operation // expressions: `op(op)` if (singleTy == valueTy) { if (valueExprType.isa()) { valueExpr = convertOpToValue(valueExpr); continue; } } return emitError( valueExpr->getLoc(), llvm::formatv( "expected `{0}` or `{1}` convertible expression, but got `{2}`", singleTy, rangeTy, valueExprType)); } return success(); } FailureOr Parser::createTupleExpr(llvm::SMRange loc, ArrayRef elements, ArrayRef elementNames) { for (const ast::Expr *element : elements) { ast::Type eleTy = element->getType(); if (eleTy.isa()) { return emitError( element->getLoc(), llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); } } return ast::TupleExpr::create(ctx, loc, elements, elementNames); } //===----------------------------------------------------------------------===// // Stmts FailureOr Parser::createEraseStmt(llvm::SMRange loc, ast::Expr *rootOp) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); if (!rootType.isa()) return emitError(rootOp->getLoc(), "expected `Op` expression"); return ast::EraseStmt::create(ctx, loc, rootOp); } FailureOr Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp, MutableArrayRef replValues) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); if (!rootType.isa()) { return emitError( rootOp->getLoc(), llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); } // If there are multiple replacement values, we implicitly convert any Op // expressions to the value form. bool shouldConvertOpToValues = replValues.size() > 1; for (ast::Expr *&replExpr : replValues) { ast::Type replType = replExpr->getType(); // Check that replExpr is an Operation, Value, or ValueRange. 
if (replType.isa()) { if (shouldConvertOpToValues) replExpr = convertOpToValue(replExpr); continue; } if (replType != valueTy && replType != valueRangeTy) { return emitError(replExpr->getLoc(), llvm::formatv("expected `Op`, `Value` or `ValueRange` " "expression, but got `{0}`", replType)); } } return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); } FailureOr Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, ast::CompoundStmt *rewriteBody) { // Check that root is an Operation. ast::Type rootType = rootOp->getType(); if (!rootType.isa()) { return emitError( rootOp->getLoc(), llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); } return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); } //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// FailureOr mlir::pdll::parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr) { Parser parser(ctx, sourceMgr); return parser.parseModule(); }