1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Argument.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Constraint.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/Tools/PDLL/AST/Context.h"
19 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
20 #include "mlir/Tools/PDLL/AST/Nodes.h"
21 #include "mlir/Tools/PDLL/AST/Types.h"
22 #include "mlir/Tools/PDLL/ODS/Constraint.h"
23 #include "mlir/Tools/PDLL/ODS/Context.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
Parser(ast::Context & ctx,llvm::SourceMgr & sourceMgr,bool enableDocumentation,CodeCompleteContext * codeCompleteContext)46   Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47          bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48       : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49         curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50         valueTy(ast::ValueType::get(ctx)),
51         valueRangeTy(ast::ValueRangeType::get(ctx)),
52         typeTy(ast::TypeType::get(ctx)),
53         typeRangeTy(ast::TypeRangeType::get(ctx)),
54         attrTy(ast::AttributeType::get(ctx)),
55         codeCompleteContext(codeCompleteContext) {}
56 
57   /// Try to parse a new module. Returns nullptr in the case of failure.
58   FailureOr<ast::Module *> parseModule();
59 
60 private:
61   /// The current context of the parser. It allows for the parser to know a bit
62   /// about the construct it is nested within during parsing. This is used
63   /// specifically to provide additional verification during parsing, e.g. to
64   /// prevent using rewrites within a match context, matcher constraints within
65   /// a rewrite section, etc.
66   enum class ParserContext {
67     /// The parser is in the global context.
68     Global,
69     /// The parser is currently within a Constraint, which disallows all types
70     /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71     Constraint,
72     /// The parser is currently within the matcher portion of a Pattern, which
73     /// is allows a terminal operation rewrite statement but no other rewrite
74     /// transformations.
75     PatternMatch,
76     /// The parser is currently within a Rewrite, which disallows calls to
77     /// constraints, requires operation expressions to have names, etc.
78     Rewrite,
79   };
80 
81   /// The current specification context of an operations result type. This
82   /// indicates how the result types of an operation may be inferred.
83   enum class OpResultTypeContext {
84     /// The result types of the operation are not known to be inferred.
85     Explicit,
86     /// The result types of the operation are inferred from the root input of a
87     /// `replace` statement.
88     Replacement,
89     /// The result types of the operation are inferred by using the
90     /// `InferTypeOpInterface` interface provided by the operation.
91     Interface,
92   };
93 
94   //===--------------------------------------------------------------------===//
95   // Parsing
96   //===--------------------------------------------------------------------===//
97 
98   /// Push a new decl scope onto the lexer.
pushDeclScope()99   ast::DeclScope *pushDeclScope() {
100     ast::DeclScope *newScope =
101         new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102     return (curDeclScope = newScope);
103   }
pushDeclScope(ast::DeclScope * scope)104   void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106   /// Pop the last decl scope from the lexer.
popDeclScope()107   void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109   /// Parse the body of an AST module.
110   LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112   /// Try to convert the given expression to `type`. Returns failure and emits
113   /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114   /// invoked to attach notes to the emitted error diagnostic. On success,
115   /// `expr` is updated to the expression used to convert to `type`.
116   LogicalResult convertExpressionTo(
117       ast::Expr *&expr, ast::Type type,
118       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119 
120   /// Given an operation expression, convert it to a Value or ValueRange
121   /// typed expression.
122   ast::Expr *convertOpToValue(const ast::Expr *opExpr);
123 
124   /// Lookup ODS information for the given operation, returns nullptr if no
125   /// information is found.
lookupODSOperation(Optional<StringRef> opName)126   const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
127     return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
128   }
129 
130   /// Process the given documentation string, or return an empty string if
131   /// documentation isn't enabled.
processDoc(StringRef doc)132   StringRef processDoc(StringRef doc) {
133     return enableDocumentation ? doc : StringRef();
134   }
135 
136   /// Process the given documentation string and format it, or return an empty
137   /// string if documentation isn't enabled.
processAndFormatDoc(const Twine & doc)138   std::string processAndFormatDoc(const Twine &doc) {
139     if (!enableDocumentation)
140       return "";
141     std::string docStr;
142     {
143       llvm::raw_string_ostream docOS(docStr);
144       std::string tmpDocStr = doc.str();
145       raw_indented_ostream(docOS).printReindented(
146           StringRef(tmpDocStr).rtrim(" \t"));
147     }
148     return docStr;
149   }
150 
151   //===--------------------------------------------------------------------===//
152   // Directives
153 
154   LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
155   LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
156   LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
157                                SmallVectorImpl<ast::Decl *> &decls);
158 
159   /// Process the records of a parsed tablegen include file.
160   void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
161                                SmallVectorImpl<ast::Decl *> &decls);
162 
163   /// Create a user defined native constraint for a constraint imported from
164   /// ODS.
165   template <typename ConstraintT>
166   ast::Decl *
167   createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
168                                     SMRange loc, ast::Type type,
169                                     StringRef nativeType, StringRef docString);
170   template <typename ConstraintT>
171   ast::Decl *
172   createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
173                                     SMRange loc, ast::Type type,
174                                     StringRef nativeType);
175 
176   //===--------------------------------------------------------------------===//
177   // Decls
178 
179   /// This structure contains the set of pattern metadata that may be parsed.
180   struct ParsedPatternMetadata {
181     Optional<uint16_t> benefit;
182     bool hasBoundedRecursion = false;
183   };
184 
185   FailureOr<ast::Decl *> parseTopLevelDecl();
186   FailureOr<ast::NamedAttributeDecl *>
187   parseNamedAttributeDecl(Optional<StringRef> parentOpName);
188 
189   /// Parse an argument variable as part of the signature of a
190   /// UserConstraintDecl or UserRewriteDecl.
191   FailureOr<ast::VariableDecl *> parseArgumentDecl();
192 
193   /// Parse a result variable as part of the signature of a UserConstraintDecl
194   /// or UserRewriteDecl.
195   FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
196 
197   /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
198   /// defined in a non-global context.
199   FailureOr<ast::UserConstraintDecl *>
200   parseUserConstraintDecl(bool isInline = false);
201 
202   /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
203   /// non-global context, such as within a Pattern/Constraint/etc.
204   FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
205 
206   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
207   /// PDLL constructs.
208   FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
209       const ast::Name &name, bool isInline,
210       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
211       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
212 
213   /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
214   /// defined in a non-global context.
215   FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
216 
217   /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
218   /// non-global context, such as within a Pattern/Rewrite/etc.
219   FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
220 
221   /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
222   /// PDLL constructs.
223   FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
224       const ast::Name &name, bool isInline,
225       ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
226       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
227 
228   /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
229   /// effectively the same syntax, and only differ on slight semantics (given
230   /// the different parsing contexts).
231   template <typename T, typename ParseUserPDLLDeclFnT>
232   FailureOr<T *> parseUserConstraintOrRewriteDecl(
233       ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
234       StringRef anonymousNamePrefix, bool isInline);
235 
236   /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
237   /// These decls have effectively the same syntax.
238   template <typename T>
239   FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
240       const ast::Name &name, bool isInline,
241       ArrayRef<ast::VariableDecl *> arguments,
242       ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
243 
244   /// Parse the functional signature (i.e. the arguments and results) of a
245   /// UserConstraintDecl or UserRewriteDecl.
246   LogicalResult parseUserConstraintOrRewriteSignature(
247       SmallVectorImpl<ast::VariableDecl *> &arguments,
248       SmallVectorImpl<ast::VariableDecl *> &results,
249       ast::DeclScope *&argumentScope, ast::Type &resultType);
250 
251   /// Validate the return (which if present is specified by bodyIt) of a
252   /// UserConstraintDecl or UserRewriteDecl.
253   LogicalResult validateUserConstraintOrRewriteReturn(
254       StringRef declType, ast::CompoundStmt *body,
255       ArrayRef<ast::Stmt *>::iterator bodyIt,
256       ArrayRef<ast::Stmt *>::iterator bodyE,
257       ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
258 
259   FailureOr<ast::CompoundStmt *>
260   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
261                   bool expectTerminalSemicolon = true);
262   FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
263   FailureOr<ast::Decl *> parsePatternDecl();
264   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
265 
266   /// Check to see if a decl has already been defined with the given name, if
267   /// one has emit and error and return failure. Returns success otherwise.
268   LogicalResult checkDefineNamedDecl(const ast::Name &name);
269 
270   /// Try to define a variable decl with the given components, returns the
271   /// variable on success.
272   FailureOr<ast::VariableDecl *>
273   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
274                      ast::Expr *initExpr,
275                      ArrayRef<ast::ConstraintRef> constraints);
276   FailureOr<ast::VariableDecl *>
277   defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
278                      ArrayRef<ast::ConstraintRef> constraints);
279 
280   /// Parse the constraint reference list for a variable decl.
281   LogicalResult parseVariableDeclConstraintList(
282       SmallVectorImpl<ast::ConstraintRef> &constraints);
283 
284   /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
285   FailureOr<ast::Expr *> parseTypeConstraintExpr();
286 
287   /// Try to parse a single reference to a constraint. `typeConstraint` is the
288   /// location of a previously parsed type constraint for the entity that will
289   /// be constrained by the parsed constraint. `existingConstraints` are any
290   /// existing constraints that have already been parsed for the same entity
291   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
292   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
293   /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
294   /// constraints) may be used with the variable.
295   FailureOr<ast::ConstraintRef>
296   parseConstraint(Optional<SMRange> &typeConstraint,
297                   ArrayRef<ast::ConstraintRef> existingConstraints,
298                   bool allowInlineTypeConstraints,
299                   bool allowNonCoreConstraints);
300 
301   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
302   /// argument or result variable. The constraints for these variables do not
303   /// allow inline type constraints, and only permit a single constraint.
304   FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
305 
306   //===--------------------------------------------------------------------===//
307   // Exprs
308 
309   FailureOr<ast::Expr *> parseExpr();
310 
311   /// Identifier expressions.
312   FailureOr<ast::Expr *> parseAttributeExpr();
313   FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
314   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
315   FailureOr<ast::Expr *> parseIdentifierExpr();
316   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
317   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
318   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
319   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
320   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
321   FailureOr<ast::Expr *>
322   parseOperationExpr(OpResultTypeContext inputResultTypeContext =
323                          OpResultTypeContext::Explicit);
324   FailureOr<ast::Expr *> parseTupleExpr();
325   FailureOr<ast::Expr *> parseTypeExpr();
326   FailureOr<ast::Expr *> parseUnderscoreExpr();
327 
328   //===--------------------------------------------------------------------===//
329   // Stmts
330 
331   FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
332   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
333   FailureOr<ast::EraseStmt *> parseEraseStmt();
334   FailureOr<ast::LetStmt *> parseLetStmt();
335   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
336   FailureOr<ast::ReturnStmt *> parseReturnStmt();
337   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
338 
339   //===--------------------------------------------------------------------===//
340   // Creation+Analysis
341   //===--------------------------------------------------------------------===//
342 
343   //===--------------------------------------------------------------------===//
344   // Decls
345 
346   /// Try to extract a callable from the given AST node. Returns nullptr on
347   /// failure.
348   ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
349 
350   /// Try to create a pattern decl with the given components, returning the
351   /// Pattern on success.
352   FailureOr<ast::PatternDecl *>
353   createPatternDecl(SMRange loc, const ast::Name *name,
354                     const ParsedPatternMetadata &metadata,
355                     ast::CompoundStmt *body);
356 
357   /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
358   /// of results, defined as part of the signature.
359   ast::Type
360   createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
361 
362   /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
363   template <typename T>
364   FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
365       const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
366       ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
367       ast::CompoundStmt *body);
368 
369   /// Try to create a variable decl with the given components, returning the
370   /// Variable on success.
371   FailureOr<ast::VariableDecl *>
372   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
373                      ArrayRef<ast::ConstraintRef> constraints);
374 
375   /// Create a variable for an argument or result defined as part of the
376   /// signature of a UserConstraintDecl/UserRewriteDecl.
377   FailureOr<ast::VariableDecl *>
378   createArgOrResultVariableDecl(StringRef name, SMRange loc,
379                                 const ast::ConstraintRef &constraint);
380 
381   /// Validate the constraints used to constraint a variable decl.
382   /// `inferredType` is the type of the variable inferred by the constraints
383   /// within the list, and is updated to the most refined type as determined by
384   /// the constraints. Returns success if the constraint list is valid, failure
385   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
386   /// defined constraints) may be used with the variable.
387   LogicalResult
388   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
389                               ast::Type &inferredType,
390                               bool allowNonCoreConstraints = true);
391   /// Validate a single reference to a constraint. `inferredType` contains the
392   /// currently inferred variabled type and is refined within the type defined
393   /// by the constraint. Returns success if the constraint is valid, failure
394   /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
395   /// defined constraints) may be used with the variable.
396   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
397                                            ast::Type &inferredType,
398                                            bool allowNonCoreConstraints = true);
399   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
400   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
401 
402   //===--------------------------------------------------------------------===//
403   // Exprs
404 
405   FailureOr<ast::CallExpr *>
406   createCallExpr(SMRange loc, ast::Expr *parentExpr,
407                  MutableArrayRef<ast::Expr *> arguments);
408   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
409   FailureOr<ast::DeclRefExpr *>
410   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
411                            ArrayRef<ast::ConstraintRef> constraints);
412   FailureOr<ast::MemberAccessExpr *>
413   createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
414 
415   /// Validate the member access `name` into the given parent expression. On
416   /// success, this also returns the type of the member accessed.
417   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
418                                             StringRef name, SMRange loc);
419   FailureOr<ast::OperationExpr *>
420   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
421                       OpResultTypeContext resultTypeContext,
422                       MutableArrayRef<ast::Expr *> operands,
423                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
424                       MutableArrayRef<ast::Expr *> results);
425   LogicalResult
426   validateOperationOperands(SMRange loc, Optional<StringRef> name,
427                             const ods::Operation *odsOp,
428                             MutableArrayRef<ast::Expr *> operands);
429   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
430                                          const ods::Operation *odsOp,
431                                          MutableArrayRef<ast::Expr *> results);
432   void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
433                                           const ods::Operation *odsOp);
434   LogicalResult validateOperationOperandsOrResults(
435       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
436       Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
437       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
438       ast::Type rangeTy);
439   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
440                                               ArrayRef<ast::Expr *> elements,
441                                               ArrayRef<StringRef> elementNames);
442 
443   //===--------------------------------------------------------------------===//
444   // Stmts
445 
446   FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
447   FailureOr<ast::ReplaceStmt *>
448   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
449                     MutableArrayRef<ast::Expr *> replValues);
450   FailureOr<ast::RewriteStmt *>
451   createRewriteStmt(SMRange loc, ast::Expr *rootOp,
452                     ast::CompoundStmt *rewriteBody);
453 
454   //===--------------------------------------------------------------------===//
455   // Code Completion
456   //===--------------------------------------------------------------------===//
457 
458   /// The set of various code completion methods. Every completion method
459   /// returns `failure` to stop the parsing process after providing completion
460   /// results.
461 
462   LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
463   LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
464   LogicalResult codeCompleteConstraintName(ast::Type inferredType,
465                                            bool allowNonCoreConstraints,
466                                            bool allowInlineTypeConstraints);
467   LogicalResult codeCompleteDialectName();
468   LogicalResult codeCompleteOperationName(StringRef dialectName);
469   LogicalResult codeCompletePatternMetadata();
470   LogicalResult codeCompleteIncludeFilename(StringRef curPath);
471 
472   void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
473   void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
474                                               unsigned currentNumOperands);
475   void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
476                                              unsigned currentNumResults);
477 
478   //===--------------------------------------------------------------------===//
479   // Lexer Utilities
480   //===--------------------------------------------------------------------===//
481 
482   /// If the current token has the specified kind, consume it and return true.
483   /// If not, return false.
consumeIf(Token::Kind kind)484   bool consumeIf(Token::Kind kind) {
485     if (curToken.isNot(kind))
486       return false;
487     consumeToken(kind);
488     return true;
489   }
490 
491   /// Advance the current lexer onto the next token.
consumeToken()492   void consumeToken() {
493     assert(curToken.isNot(Token::eof, Token::error) &&
494            "shouldn't advance past EOF or errors");
495     curToken = lexer.lexToken();
496   }
497 
498   /// Advance the current lexer onto the next token, asserting what the expected
499   /// current token is. This is preferred to the above method because it leads
500   /// to more self-documenting code with better checking.
consumeToken(Token::Kind kind)501   void consumeToken(Token::Kind kind) {
502     assert(curToken.is(kind) && "consumed an unexpected token");
503     consumeToken();
504   }
505 
506   /// Reset the lexer to the location at the given position.
resetToken(SMRange tokLoc)507   void resetToken(SMRange tokLoc) {
508     lexer.resetPointer(tokLoc.Start.getPointer());
509     curToken = lexer.lexToken();
510   }
511 
512   /// Consume the specified token if present and return success. On failure,
513   /// output a diagnostic and return failure.
parseToken(Token::Kind kind,const Twine & msg)514   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
515     if (curToken.getKind() != kind)
516       return emitError(curToken.getLoc(), msg);
517     consumeToken();
518     return success();
519   }
emitError(SMRange loc,const Twine & msg)520   LogicalResult emitError(SMRange loc, const Twine &msg) {
521     lexer.emitError(loc, msg);
522     return failure();
523   }
emitError(const Twine & msg)524   LogicalResult emitError(const Twine &msg) {
525     return emitError(curToken.getLoc(), msg);
526   }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)527   LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
528                                  const Twine &note) {
529     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
530     return failure();
531   }
532 
533   //===--------------------------------------------------------------------===//
534   // Fields
535   //===--------------------------------------------------------------------===//
536 
537   /// The owning AST context.
538   ast::Context &ctx;
539 
540   /// The lexer of this parser.
541   Lexer lexer;
542 
543   /// The current token within the lexer.
544   Token curToken;
545 
546   /// A flag indicating if the parser should add documentation to AST nodes when
547   /// viable.
548   bool enableDocumentation;
549 
550   /// The most recently defined decl scope.
551   ast::DeclScope *curDeclScope = nullptr;
552   llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
553 
554   /// The current context of the parser.
555   ParserContext parserContext = ParserContext::Global;
556 
557   /// Cached types to simplify verification and expression creation.
558   ast::Type valueTy, valueRangeTy;
559   ast::Type typeTy, typeRangeTy;
560   ast::Type attrTy;
561 
562   /// A counter used when naming anonymous constraints and rewrites.
563   unsigned anonymousDeclNameCounter = 0;
564 
565   /// The optional code completion context.
566   CodeCompleteContext *codeCompleteContext;
567 };
568 } // namespace
569 
parseModule()570 FailureOr<ast::Module *> Parser::parseModule() {
571   SMLoc moduleLoc = curToken.getStartLoc();
572   pushDeclScope();
573 
574   // Parse the top-level decls of the module.
575   SmallVector<ast::Decl *> decls;
576   if (failed(parseModuleBody(decls)))
577     return popDeclScope(), failure();
578 
579   popDeclScope();
580   return ast::Module::create(ctx, moduleLoc, decls);
581 }
582 
parseModuleBody(SmallVectorImpl<ast::Decl * > & decls)583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
584   while (curToken.isNot(Token::eof)) {
585     if (curToken.is(Token::directive)) {
586       if (failed(parseDirective(decls)))
587         return failure();
588       continue;
589     }
590 
591     FailureOr<ast::Decl *> decl = parseTopLevelDecl();
592     if (failed(decl))
593       return failure();
594     decls.push_back(*decl);
595   }
596   return success();
597 }
598 
convertOpToValue(const ast::Expr * opExpr)599 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
600   return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
601                                                  valueRangeTy);
602 }
603 
convertExpressionTo(ast::Expr * & expr,ast::Type type,function_ref<void (ast::Diagnostic & diag)> noteAttachFn)604 LogicalResult Parser::convertExpressionTo(
605     ast::Expr *&expr, ast::Type type,
606     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
607   ast::Type exprType = expr->getType();
608   if (exprType == type)
609     return success();
610 
611   auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
612     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
613         expr->getLoc(), llvm::formatv("unable to convert expression of type "
614                                       "`{0}` to the expected type of "
615                                       "`{1}`",
616                                       exprType, type));
617     if (noteAttachFn)
618       noteAttachFn(*diag);
619     return diag;
620   };
621 
622   if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
623     // Two operation types are compatible if they have the same name, or if the
624     // expected type is more general.
625     if (auto opType = type.dyn_cast<ast::OperationType>()) {
626       if (opType.getName())
627         return emitConvertError();
628       return success();
629     }
630 
631     // An operation can always convert to a ValueRange.
632     if (type == valueRangeTy) {
633       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
634                                                      valueRangeTy);
635       return success();
636     }
637 
638     // Allow conversion to a single value by constraining the result range.
639     if (type == valueTy) {
640       // If the operation is registered, we can verify if it can ever have a
641       // single result.
642       if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
643         if (odsOp->getResults().empty()) {
644           return emitConvertError()->attachNote(
645               llvm::formatv("see the definition of `{0}`, which was defined "
646                             "with zero results",
647                             odsOp->getName()),
648               odsOp->getLoc());
649         }
650 
651         unsigned numSingleResults = llvm::count_if(
652             odsOp->getResults(), [](const ods::OperandOrResult &result) {
653               return result.getVariableLengthKind() ==
654                      ods::VariableLengthKind::Single;
655             });
656         if (numSingleResults > 1) {
657           return emitConvertError()->attachNote(
658               llvm::formatv("see the definition of `{0}`, which was defined "
659                             "with at least {1} results",
660                             odsOp->getName(), numSingleResults),
661               odsOp->getLoc());
662         }
663       }
664 
665       expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
666                                                      valueTy);
667       return success();
668     }
669     return emitConvertError();
670   }
671 
672   // FIXME: Decide how to allow/support converting a single result to multiple,
673   // and multiple to a single result. For now, we just allow Single->Range,
674   // but this isn't something really supported in the PDL dialect. We should
675   // figure out some way to support both.
676   if ((exprType == valueTy || exprType == valueRangeTy) &&
677       (type == valueTy || type == valueRangeTy))
678     return success();
679   if ((exprType == typeTy || exprType == typeRangeTy) &&
680       (type == typeTy || type == typeRangeTy))
681     return success();
682 
683   // Handle tuple types.
684   if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
685     auto tupleType = type.dyn_cast<ast::TupleType>();
686     if (!tupleType || tupleType.size() != exprTupleType.size())
687       return emitConvertError();
688 
689     // Build a new tuple expression using each of the elements of the current
690     // tuple.
691     SmallVector<ast::Expr *> newExprs;
692     for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
693       newExprs.push_back(ast::MemberAccessExpr::create(
694           ctx, expr->getLoc(), expr, llvm::to_string(i),
695           exprTupleType.getElementTypes()[i]));
696 
697       auto diagFn = [&](ast::Diagnostic &diag) {
698         diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
699                                       i, exprTupleType));
700         if (noteAttachFn)
701           noteAttachFn(diag);
702       };
703       if (failed(convertExpressionTo(newExprs.back(),
704                                      tupleType.getElementTypes()[i], diagFn)))
705         return failure();
706     }
707     expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
708                                   tupleType.getElementNames());
709     return success();
710   }
711 
712   return emitConvertError();
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Directives
717 
parseDirective(SmallVectorImpl<ast::Decl * > & decls)718 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
719   StringRef directive = curToken.getSpelling();
720   if (directive == "#include")
721     return parseInclude(decls);
722 
723   return emitError("unknown directive `" + directive + "`");
724 }
725 
parseInclude(SmallVectorImpl<ast::Decl * > & decls)726 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
727   SMRange loc = curToken.getLoc();
728   consumeToken(Token::directive);
729 
730   // Handle code completion of the include file path.
731   if (curToken.is(Token::code_complete_string))
732     return codeCompleteIncludeFilename(curToken.getStringValue());
733 
734   // Parse the file being included.
735   if (!curToken.isString())
736     return emitError(loc,
737                      "expected string file name after `include` directive");
738   SMRange fileLoc = curToken.getLoc();
739   std::string filenameStr = curToken.getStringValue();
740   StringRef filename = filenameStr;
741   consumeToken();
742 
743   // Check the type of include. If ending with `.pdll`, this is another pdl file
744   // to be parsed along with the current module.
745   if (filename.endswith(".pdll")) {
746     if (failed(lexer.pushInclude(filename, fileLoc)))
747       return emitError(fileLoc,
748                        "unable to open include file `" + filename + "`");
749 
750     // If we added the include successfully, parse it into the current module.
751     // Make sure to update to the next token after we finish parsing the nested
752     // file.
753     curToken = lexer.lexToken();
754     LogicalResult result = parseModuleBody(decls);
755     curToken = lexer.lexToken();
756     return result;
757   }
758 
759   // Otherwise, this must be a `.td` include.
760   if (filename.endswith(".td"))
761     return parseTdInclude(filename, fileLoc, decls);
762 
763   return emitError(fileLoc,
764                    "expected include filename to end with `.pdll` or `.td`");
765 }
766 
parseTdInclude(StringRef filename,llvm::SMRange fileLoc,SmallVectorImpl<ast::Decl * > & decls)767 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
768                                      SmallVectorImpl<ast::Decl *> &decls) {
769   llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
770 
771   // Use the source manager to open the file, but don't yet add it.
772   std::string includedFile;
773   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
774       parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
775   if (!includeBuffer)
776     return emitError(fileLoc, "unable to open include file `" + filename + "`");
777 
778   // Setup the source manager for parsing the tablegen file.
779   llvm::SourceMgr tdSrcMgr;
780   tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
781   tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
782 
783   // This class provides a context argument for the llvm::SourceMgr diagnostic
784   // handler.
785   struct DiagHandlerContext {
786     Parser &parser;
787     StringRef filename;
788     llvm::SMRange loc;
789   } handlerContext{*this, filename, fileLoc};
790 
791   // Set the diagnostic handler for the tablegen source manager.
792   tdSrcMgr.setDiagHandler(
793       [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
794         auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
795         (void)ctx->parser.emitError(
796             ctx->loc,
797             llvm::formatv("error while processing include file `{0}`: {1}",
798                           ctx->filename, diag.getMessage()));
799       },
800       &handlerContext);
801 
802   // Parse the tablegen file.
803   llvm::RecordKeeper tdRecords;
804   if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
805     return failure();
806 
807   // Process the parsed records.
808   processTdIncludeRecords(tdRecords, decls);
809 
810   // After we are done processing, move all of the tablegen source buffers to
811   // the main parser source mgr. This allows for directly using source locations
812   // from the .td files without needing to remap them.
813   parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
814   return success();
815 }
816 
processTdIncludeRecords(llvm::RecordKeeper & tdRecords,SmallVectorImpl<ast::Decl * > & decls)817 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
818                                      SmallVectorImpl<ast::Decl *> &decls) {
819   // Return the length kind of the given value.
820   auto getLengthKind = [](const auto &value) {
821     if (value.isOptional())
822       return ods::VariableLengthKind::Optional;
823     return value.isVariadic() ? ods::VariableLengthKind::Variadic
824                               : ods::VariableLengthKind::Single;
825   };
826 
827   // Insert a type constraint into the ODS context.
828   ods::Context &odsContext = ctx.getODSContext();
829   auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
830       -> const ods::TypeConstraint & {
831     return odsContext.insertTypeConstraint(
832         cst.constraint.getUniqueDefName(),
833         processDoc(cst.constraint.getSummary()),
834         cst.constraint.getCPPClassName());
835   };
836   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
837     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
838   };
839 
840   // Process the parsed tablegen records to build ODS information.
841   /// Operations.
842   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
843     tblgen::Operator op(def);
844 
845     // Check to see if this operation is known to support type inferrence.
846     bool supportsResultTypeInferrence =
847         op.getTrait("::mlir::InferTypeOpInterface::Trait");
848 
849     bool inserted = false;
850     ods::Operation *odsOp = nullptr;
851     std::tie(odsOp, inserted) = odsContext.insertOperation(
852         op.getOperationName(), processDoc(op.getSummary()),
853         processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
854         supportsResultTypeInferrence, op.getLoc().front());
855 
856     // Ignore operations that have already been added.
857     if (!inserted)
858       continue;
859 
860     for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
861       odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
862                              odsContext.insertAttributeConstraint(
863                                  attr.attr.getUniqueDefName(),
864                                  processDoc(attr.attr.getSummary()),
865                                  attr.attr.getStorageType()));
866     }
867     for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
868       odsOp->appendOperand(operand.name, getLengthKind(operand),
869                            addTypeConstraint(operand));
870     }
871     for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
872       odsOp->appendResult(result.name, getLengthKind(result),
873                           addTypeConstraint(result));
874     }
875   }
876 
877   auto shouldBeSkipped = [this](llvm::Record *def) {
878     return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
879            def->isSubClassOf("DeclareInterfaceMethods");
880   };
881 
882   /// Attr constraints.
883   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
884     if (shouldBeSkipped(def))
885       continue;
886 
887     tblgen::Attribute constraint(def);
888     decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
889         constraint, convertLocToRange(def->getLoc().front()), attrTy,
890         constraint.getStorageType()));
891   }
892   /// Type constraints.
893   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
894     if (shouldBeSkipped(def))
895       continue;
896 
897     tblgen::TypeConstraint constraint(def);
898     decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
899         constraint, convertLocToRange(def->getLoc().front()), typeTy,
900         constraint.getCPPClassName()));
901   }
902   /// OpInterfaces.
903   ast::Type opTy = ast::OperationType::get(ctx);
904   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
905     if (shouldBeSkipped(def))
906       continue;
907 
908     SMRange loc = convertLocToRange(def->getLoc().front());
909 
910     std::string cppClassName =
911         llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
912                       def->getValueAsString("cppInterfaceName"))
913             .str();
914     std::string codeBlock =
915         llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
916                       cppClassName)
917             .str();
918 
919     std::string desc =
920         processAndFormatDoc(def->getValueAsString("description"));
921     decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
922         def->getName(), codeBlock, loc, opTy, cppClassName, desc));
923   }
924 }
925 
926 template <typename ConstraintT>
createODSNativePDLLConstraintDecl(StringRef name,StringRef codeBlock,SMRange loc,ast::Type type,StringRef nativeType,StringRef docString)927 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
928     StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
929     StringRef nativeType, StringRef docString) {
930   // Build the single input parameter.
931   ast::DeclScope *argScope = pushDeclScope();
932   auto *paramVar = ast::VariableDecl::create(
933       ctx, ast::Name::create(ctx, "self", loc), type,
934       /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
935   argScope->add(paramVar);
936   popDeclScope();
937 
938   // Build the native constraint.
939   auto *constraintDecl = ast::UserConstraintDecl::createNative(
940       ctx, ast::Name::create(ctx, name, loc), paramVar,
941       /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
942   constraintDecl->setDocComment(ctx, docString);
943   curDeclScope->add(constraintDecl);
944   return constraintDecl;
945 }
946 
947 template <typename ConstraintT>
948 ast::Decl *
createODSNativePDLLConstraintDecl(const tblgen::Constraint & constraint,SMRange loc,ast::Type type,StringRef nativeType)949 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
950                                           SMRange loc, ast::Type type,
951                                           StringRef nativeType) {
952   // Format the condition template.
953   tblgen::FmtContext fmtContext;
954   fmtContext.withSelf("self");
955   std::string codeBlock = tblgen::tgfmt(
956       "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
957       &fmtContext);
958 
959   // If documentation was enabled, build the doc string for the generated
960   // constraint. It would be nice to do this lazily, but TableGen information is
961   // destroyed after we finish parsing the file.
962   std::string docString;
963   if (enableDocumentation) {
964     StringRef desc = constraint.getDescription();
965     docString = processAndFormatDoc(
966         constraint.getSummary() +
967         (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
968   }
969 
970   return createODSNativePDLLConstraintDecl<ConstraintT>(
971       constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
972       docString);
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // Decls
977 
parseTopLevelDecl()978 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
979   FailureOr<ast::Decl *> decl;
980   switch (curToken.getKind()) {
981   case Token::kw_Constraint:
982     decl = parseUserConstraintDecl();
983     break;
984   case Token::kw_Pattern:
985     decl = parsePatternDecl();
986     break;
987   case Token::kw_Rewrite:
988     decl = parseUserRewriteDecl();
989     break;
990   default:
991     return emitError("expected top-level declaration, such as a `Pattern`");
992   }
993   if (failed(decl))
994     return failure();
995 
996   // If the decl has a name, add it to the current scope.
997   if (const ast::Name *name = (*decl)->getName()) {
998     if (failed(checkDefineNamedDecl(*name)))
999       return failure();
1000     curDeclScope->add(*decl);
1001   }
1002   return decl;
1003 }
1004 
1005 FailureOr<ast::NamedAttributeDecl *>
parseNamedAttributeDecl(Optional<StringRef> parentOpName)1006 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
1007   // Check for name code completion.
1008   if (curToken.is(Token::code_complete))
1009     return codeCompleteAttributeName(parentOpName);
1010 
1011   std::string attrNameStr;
1012   if (curToken.isString())
1013     attrNameStr = curToken.getStringValue();
1014   else if (curToken.is(Token::identifier) || curToken.isKeyword())
1015     attrNameStr = curToken.getSpelling().str();
1016   else
1017     return emitError("expected identifier or string attribute name");
1018   const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1019   consumeToken();
1020 
1021   // Check for a value of the attribute.
1022   ast::Expr *attrValue = nullptr;
1023   if (consumeIf(Token::equal)) {
1024     FailureOr<ast::Expr *> attrExpr = parseExpr();
1025     if (failed(attrExpr))
1026       return failure();
1027     attrValue = *attrExpr;
1028   } else {
1029     // If there isn't a concrete value, create an expression representing a
1030     // UnitAttr.
1031     attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1032   }
1033 
1034   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1035 }
1036 
parseLambdaBody(function_ref<LogicalResult (ast::Stmt * &)> processStatementFn,bool expectTerminalSemicolon)1037 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1038     function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1039     bool expectTerminalSemicolon) {
1040   consumeToken(Token::equal_arrow);
1041 
1042   // Parse the single statement of the lambda body.
1043   SMLoc bodyStartLoc = curToken.getStartLoc();
1044   pushDeclScope();
1045   FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1046   bool failedToParse =
1047       failed(singleStatement) || failed(processStatementFn(*singleStatement));
1048   popDeclScope();
1049   if (failedToParse)
1050     return failure();
1051 
1052   SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1053   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1054 }
1055 
parseArgumentDecl()1056 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1057   // Ensure that the argument is named.
1058   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1059     return emitError("expected identifier argument name");
1060 
1061   // Parse the argument similarly to a normal variable.
1062   StringRef name = curToken.getSpelling();
1063   SMRange nameLoc = curToken.getLoc();
1064   consumeToken();
1065 
1066   if (failed(
1067           parseToken(Token::colon, "expected `:` before argument constraint")))
1068     return failure();
1069 
1070   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1071   if (failed(cst))
1072     return failure();
1073 
1074   return createArgOrResultVariableDecl(name, nameLoc, *cst);
1075 }
1076 
parseResultDecl(unsigned resultNum)1077 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1078   // Check to see if this result is named.
1079   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1080     // Check to see if this name actually refers to a Constraint.
1081     ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1082     if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1083       // If yes, and this is a Rewrite, give a nice error message as non-Core
1084       // constraints are not supported on Rewrite results.
1085       if (parserContext == ParserContext::Rewrite) {
1086         return emitError(
1087             "`Rewrite` results are only permitted to use core constraints, "
1088             "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1089       }
1090 
1091       // Otherwise, parse this as an unnamed result variable.
1092     } else {
1093       // If it wasn't a constraint, parse the result similarly to a variable. If
1094       // there is already an existing decl, we will emit an error when defining
1095       // this variable later.
1096       StringRef name = curToken.getSpelling();
1097       SMRange nameLoc = curToken.getLoc();
1098       consumeToken();
1099 
1100       if (failed(parseToken(Token::colon,
1101                             "expected `:` before result constraint")))
1102         return failure();
1103 
1104       FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1105       if (failed(cst))
1106         return failure();
1107 
1108       return createArgOrResultVariableDecl(name, nameLoc, *cst);
1109     }
1110   }
1111 
1112   // If it isn't named, we parse the constraint directly and create an unnamed
1113   // result variable.
1114   FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1115   if (failed(cst))
1116     return failure();
1117 
1118   return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1119 }
1120 
1121 FailureOr<ast::UserConstraintDecl *>
parseUserConstraintDecl(bool isInline)1122 Parser::parseUserConstraintDecl(bool isInline) {
1123   // Constraints and rewrites have very similar formats, dispatch to a shared
1124   // interface for parsing.
1125   return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1126       [&](auto &&...args) {
1127         return this->parseUserPDLLConstraintDecl(args...);
1128       },
1129       ParserContext::Constraint, "constraint", isInline);
1130 }
1131 
parseInlineUserConstraintDecl()1132 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1133   FailureOr<ast::UserConstraintDecl *> decl =
1134       parseUserConstraintDecl(/*isInline=*/true);
1135   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1136     return failure();
1137 
1138   curDeclScope->add(*decl);
1139   return decl;
1140 }
1141 
parseUserPDLLConstraintDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ast::DeclScope * argumentScope,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1142 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1143     const ast::Name &name, bool isInline,
1144     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1145     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1146   // Push the argument scope back onto the list, so that the body can
1147   // reference arguments.
1148   pushDeclScope(argumentScope);
1149 
1150   // Parse the body of the constraint. The body is either defined as a compound
1151   // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1152   ast::CompoundStmt *body;
1153   if (curToken.is(Token::equal_arrow)) {
1154     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1155         [&](ast::Stmt *&stmt) -> LogicalResult {
1156           ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1157           if (!stmtExpr) {
1158             return emitError(stmt->getLoc(),
1159                              "expected `Constraint` lambda body to contain a "
1160                              "single expression");
1161           }
1162           stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1163           return success();
1164         },
1165         /*expectTerminalSemicolon=*/!isInline);
1166     if (failed(bodyResult))
1167       return failure();
1168     body = *bodyResult;
1169   } else {
1170     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1171     if (failed(bodyResult))
1172       return failure();
1173     body = *bodyResult;
1174 
1175     // Verify the structure of the body.
1176     auto bodyIt = body->begin(), bodyE = body->end();
1177     for (; bodyIt != bodyE; ++bodyIt)
1178       if (isa<ast::ReturnStmt>(*bodyIt))
1179         break;
1180     if (failed(validateUserConstraintOrRewriteReturn(
1181             "Constraint", body, bodyIt, bodyE, results, resultType)))
1182       return failure();
1183   }
1184   popDeclScope();
1185 
1186   return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1187       name, arguments, results, resultType, body);
1188 }
1189 
parseUserRewriteDecl(bool isInline)1190 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1191   // Constraints and rewrites have very similar formats, dispatch to a shared
1192   // interface for parsing.
1193   return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1194       [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1195       ParserContext::Rewrite, "rewrite", isInline);
1196 }
1197 
parseInlineUserRewriteDecl()1198 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1199   FailureOr<ast::UserRewriteDecl *> decl =
1200       parseUserRewriteDecl(/*isInline=*/true);
1201   if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1202     return failure();
1203 
1204   curDeclScope->add(*decl);
1205   return decl;
1206 }
1207 
parseUserPDLLRewriteDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ast::DeclScope * argumentScope,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1208 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1209     const ast::Name &name, bool isInline,
1210     ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1211     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1212   // Push the argument scope back onto the list, so that the body can
1213   // reference arguments.
1214   curDeclScope = argumentScope;
1215   ast::CompoundStmt *body;
1216   if (curToken.is(Token::equal_arrow)) {
1217     FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1218         [&](ast::Stmt *&statement) -> LogicalResult {
1219           if (isa<ast::OpRewriteStmt>(statement))
1220             return success();
1221 
1222           ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1223           if (!statementExpr) {
1224             return emitError(
1225                 statement->getLoc(),
1226                 "expected `Rewrite` lambda body to contain a single expression "
1227                 "or an operation rewrite statement; such as `erase`, "
1228                 "`replace`, or `rewrite`");
1229           }
1230           statement =
1231               ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1232           return success();
1233         },
1234         /*expectTerminalSemicolon=*/!isInline);
1235     if (failed(bodyResult))
1236       return failure();
1237     body = *bodyResult;
1238   } else {
1239     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1240     if (failed(bodyResult))
1241       return failure();
1242     body = *bodyResult;
1243   }
1244   popDeclScope();
1245 
1246   // Verify the structure of the body.
1247   auto bodyIt = body->begin(), bodyE = body->end();
1248   for (; bodyIt != bodyE; ++bodyIt)
1249     if (isa<ast::ReturnStmt>(*bodyIt))
1250       break;
1251   if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1252                                                    bodyE, results, resultType)))
1253     return failure();
1254   return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1255       name, arguments, results, resultType, body);
1256 }
1257 
1258 template <typename T, typename ParseUserPDLLDeclFnT>
parseUserConstraintOrRewriteDecl(ParseUserPDLLDeclFnT && parseUserPDLLFn,ParserContext declContext,StringRef anonymousNamePrefix,bool isInline)1259 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1260     ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1261     StringRef anonymousNamePrefix, bool isInline) {
1262   SMRange loc = curToken.getLoc();
1263   consumeToken();
1264   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1265 
1266   // Parse the name of the decl.
1267   const ast::Name *name = nullptr;
1268   if (curToken.isNot(Token::identifier)) {
1269     // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1270     // in C++, so being unnamed is fine.
1271     if (!isInline)
1272       return emitError("expected identifier name");
1273 
1274     // Create a unique anonymous name to use, as the name for this decl is not
1275     // important.
1276     std::string anonName =
1277         llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1278                       anonymousDeclNameCounter++)
1279             .str();
1280     name = &ast::Name::create(ctx, anonName, loc);
1281   } else {
1282     // If a name was provided, we can use it directly.
1283     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1284     consumeToken(Token::identifier);
1285   }
1286 
1287   // Parse the functional signature of the decl.
1288   SmallVector<ast::VariableDecl *> arguments, results;
1289   ast::DeclScope *argumentScope;
1290   ast::Type resultType;
1291   if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1292                                                    argumentScope, resultType)))
1293     return failure();
1294 
1295   // Check to see which type of constraint this is. If the constraint contains a
1296   // compound body, this is a PDLL decl.
1297   if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1298     return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1299                            resultType);
1300 
1301   // Otherwise, this is a native decl.
1302   return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1303                                                    results, resultType);
1304 }
1305 
1306 template <typename T>
parseUserNativeConstraintOrRewriteDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1307 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1308     const ast::Name &name, bool isInline,
1309     ArrayRef<ast::VariableDecl *> arguments,
1310     ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1311   // If followed by a string, the native code body has also been specified.
1312   std::string codeStrStorage;
1313   Optional<StringRef> optCodeStr;
1314   if (curToken.isString()) {
1315     codeStrStorage = curToken.getStringValue();
1316     optCodeStr = codeStrStorage;
1317     consumeToken();
1318   } else if (isInline) {
1319     return emitError(name.getLoc(),
1320                      "external declarations must be declared in global scope");
1321   } else if (curToken.is(Token::error)) {
1322     return failure();
1323   }
1324   if (failed(parseToken(Token::semicolon,
1325                         "expected `;` after native declaration")))
1326     return failure();
1327   // TODO: PDL should be able to support constraint results in certain
1328   // situations, we should revise this.
1329   if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1330     return emitError(
1331         "native Constraints currently do not support returning results");
1332   }
1333   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1334 }
1335 
parseUserConstraintOrRewriteSignature(SmallVectorImpl<ast::VariableDecl * > & arguments,SmallVectorImpl<ast::VariableDecl * > & results,ast::DeclScope * & argumentScope,ast::Type & resultType)1336 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1337     SmallVectorImpl<ast::VariableDecl *> &arguments,
1338     SmallVectorImpl<ast::VariableDecl *> &results,
1339     ast::DeclScope *&argumentScope, ast::Type &resultType) {
1340   // Parse the argument list of the decl.
1341   if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1342     return failure();
1343 
1344   argumentScope = pushDeclScope();
1345   if (curToken.isNot(Token::r_paren)) {
1346     do {
1347       FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1348       if (failed(argument))
1349         return failure();
1350       arguments.emplace_back(*argument);
1351     } while (consumeIf(Token::comma));
1352   }
1353   popDeclScope();
1354   if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1355     return failure();
1356 
1357   // Parse the results of the decl.
1358   pushDeclScope();
1359   if (consumeIf(Token::arrow)) {
1360     auto parseResultFn = [&]() -> LogicalResult {
1361       FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1362       if (failed(result))
1363         return failure();
1364       results.emplace_back(*result);
1365       return success();
1366     };
1367 
1368     // Check for a list of results.
1369     if (consumeIf(Token::l_paren)) {
1370       do {
1371         if (failed(parseResultFn()))
1372           return failure();
1373       } while (consumeIf(Token::comma));
1374       if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1375         return failure();
1376 
1377       // Otherwise, there is only one result.
1378     } else if (failed(parseResultFn())) {
1379       return failure();
1380     }
1381   }
1382   popDeclScope();
1383 
1384   // Compute the result type of the decl.
1385   resultType = createUserConstraintRewriteResultType(results);
1386 
1387   // Verify that results are only named if there are more than one.
1388   if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1389     return emitError(
1390         results.front()->getLoc(),
1391         "cannot create a single-element tuple with an element label");
1392   }
1393   return success();
1394 }
1395 
validateUserConstraintOrRewriteReturn(StringRef declType,ast::CompoundStmt * body,ArrayRef<ast::Stmt * >::iterator bodyIt,ArrayRef<ast::Stmt * >::iterator bodyE,ArrayRef<ast::VariableDecl * > results,ast::Type & resultType)1396 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1397     StringRef declType, ast::CompoundStmt *body,
1398     ArrayRef<ast::Stmt *>::iterator bodyIt,
1399     ArrayRef<ast::Stmt *>::iterator bodyE,
1400     ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1401   // Handle if a `return` was provided.
1402   if (bodyIt != bodyE) {
1403     // Emit an error if we have trailing statements after the return.
1404     if (std::next(bodyIt) != bodyE) {
1405       return emitError(
1406           (*std::next(bodyIt))->getLoc(),
1407           llvm::formatv("`return` terminated the `{0}` body, but found "
1408                         "trailing statements afterwards",
1409                         declType));
1410     }
1411 
1412     // Otherwise if a return wasn't provided, check that no results are
1413     // expected.
1414   } else if (!results.empty()) {
1415     return emitError(
1416         {body->getLoc().End, body->getLoc().End},
1417         llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1418                       declType, resultType));
1419   }
1420   return success();
1421 }
1422 
parsePatternLambdaBody()1423 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1424   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1425     if (isa<ast::OpRewriteStmt>(statement))
1426       return success();
1427     return emitError(
1428         statement->getLoc(),
1429         "expected Pattern lambda body to contain a single operation "
1430         "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1431   });
1432 }
1433 
parsePatternDecl()1434 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1435   SMRange loc = curToken.getLoc();
1436   consumeToken(Token::kw_Pattern);
1437   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1438                                               ParserContext::PatternMatch);
1439 
1440   // Check for an optional identifier for the pattern name.
1441   const ast::Name *name = nullptr;
1442   if (curToken.is(Token::identifier)) {
1443     name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1444     consumeToken(Token::identifier);
1445   }
1446 
1447   // Parse any pattern metadata.
1448   ParsedPatternMetadata metadata;
1449   if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1450     return failure();
1451 
1452   // Parse the pattern body.
1453   ast::CompoundStmt *body;
1454 
1455   // Handle a lambda body.
1456   if (curToken.is(Token::equal_arrow)) {
1457     FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1458     if (failed(bodyResult))
1459       return failure();
1460     body = *bodyResult;
1461   } else {
1462     if (curToken.isNot(Token::l_brace))
1463       return emitError("expected `{` or `=>` to start pattern body");
1464     FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1465     if (failed(bodyResult))
1466       return failure();
1467     body = *bodyResult;
1468 
1469     // Verify the body of the pattern.
1470     auto bodyIt = body->begin(), bodyE = body->end();
1471     for (; bodyIt != bodyE; ++bodyIt) {
1472       if (isa<ast::ReturnStmt>(*bodyIt)) {
1473         return emitError((*bodyIt)->getLoc(),
1474                          "`return` statements are only permitted within a "
1475                          "`Constraint` or `Rewrite` body");
1476       }
1477       // Break when we've found the rewrite statement.
1478       if (isa<ast::OpRewriteStmt>(*bodyIt))
1479         break;
1480     }
1481     if (bodyIt == bodyE) {
1482       return emitError(loc,
1483                        "expected Pattern body to terminate with an operation "
1484                        "rewrite statement, such as `erase`");
1485     }
1486     if (std::next(bodyIt) != bodyE) {
1487       return emitError((*std::next(bodyIt))->getLoc(),
1488                        "Pattern body was terminated by an operation "
1489                        "rewrite statement, but found trailing statements");
1490     }
1491   }
1492 
1493   return createPatternDecl(loc, name, metadata, body);
1494 }
1495 
1496 LogicalResult
parsePatternDeclMetadata(ParsedPatternMetadata & metadata)1497 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1498   Optional<SMRange> benefitLoc;
1499   Optional<SMRange> hasBoundedRecursionLoc;
1500 
1501   do {
1502     // Handle metadata code completion.
1503     if (curToken.is(Token::code_complete))
1504       return codeCompletePatternMetadata();
1505 
1506     if (curToken.isNot(Token::identifier))
1507       return emitError("expected pattern metadata identifier");
1508     StringRef metadataStr = curToken.getSpelling();
1509     SMRange metadataLoc = curToken.getLoc();
1510     consumeToken(Token::identifier);
1511 
1512     // Parse the benefit metadata: benefit(<integer-value>)
1513     if (metadataStr == "benefit") {
1514       if (benefitLoc) {
1515         return emitErrorAndNote(metadataLoc,
1516                                 "pattern benefit has already been specified",
1517                                 *benefitLoc, "see previous definition here");
1518       }
1519       if (failed(parseToken(Token::l_paren,
1520                             "expected `(` before pattern benefit")))
1521         return failure();
1522 
1523       uint16_t benefitValue = 0;
1524       if (curToken.isNot(Token::integer))
1525         return emitError("expected integral pattern benefit");
1526       if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1527         return emitError(
1528             "expected pattern benefit to fit within a 16-bit integer");
1529       consumeToken(Token::integer);
1530 
1531       metadata.benefit = benefitValue;
1532       benefitLoc = metadataLoc;
1533 
1534       if (failed(
1535               parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1536         return failure();
1537       continue;
1538     }
1539 
1540     // Parse the bounded recursion metadata: recursion
1541     if (metadataStr == "recursion") {
1542       if (hasBoundedRecursionLoc) {
1543         return emitErrorAndNote(
1544             metadataLoc,
1545             "pattern recursion metadata has already been specified",
1546             *hasBoundedRecursionLoc, "see previous definition here");
1547       }
1548       metadata.hasBoundedRecursion = true;
1549       hasBoundedRecursionLoc = metadataLoc;
1550       continue;
1551     }
1552 
1553     return emitError(metadataLoc, "unknown pattern metadata");
1554   } while (consumeIf(Token::comma));
1555 
1556   return success();
1557 }
1558 
parseTypeConstraintExpr()1559 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1560   consumeToken(Token::less);
1561 
1562   FailureOr<ast::Expr *> typeExpr = parseExpr();
1563   if (failed(typeExpr) ||
1564       failed(parseToken(Token::greater,
1565                         "expected `>` after variable type constraint")))
1566     return failure();
1567   return typeExpr;
1568 }
1569 
checkDefineNamedDecl(const ast::Name & name)1570 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1571   assert(curDeclScope && "defining decl outside of a decl scope");
1572   if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1573     return emitErrorAndNote(
1574         name.getLoc(), "`" + name.getName() + "` has already been defined",
1575         lastDecl->getName()->getLoc(), "see previous definition here");
1576   }
1577   return success();
1578 }
1579 
1580 FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name,SMRange nameLoc,ast::Type type,ast::Expr * initExpr,ArrayRef<ast::ConstraintRef> constraints)1581 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1582                            ast::Expr *initExpr,
1583                            ArrayRef<ast::ConstraintRef> constraints) {
1584   assert(curDeclScope && "defining variable outside of decl scope");
1585   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1586 
1587   // If the name of the variable indicates a special variable, we don't add it
1588   // to the scope. This variable is local to the definition point.
1589   if (name.empty() || name == "_") {
1590     return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1591                                      constraints);
1592   }
1593   if (failed(checkDefineNamedDecl(nameDecl)))
1594     return failure();
1595 
1596   auto *varDecl =
1597       ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1598   curDeclScope->add(varDecl);
1599   return varDecl;
1600 }
1601 
1602 FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name,SMRange nameLoc,ast::Type type,ArrayRef<ast::ConstraintRef> constraints)1603 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1604                            ArrayRef<ast::ConstraintRef> constraints) {
1605   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1606                             constraints);
1607 }
1608 
parseVariableDeclConstraintList(SmallVectorImpl<ast::ConstraintRef> & constraints)1609 LogicalResult Parser::parseVariableDeclConstraintList(
1610     SmallVectorImpl<ast::ConstraintRef> &constraints) {
1611   Optional<SMRange> typeConstraint;
1612   auto parseSingleConstraint = [&] {
1613     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1614         typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1615         /*allowNonCoreConstraints=*/true);
1616     if (failed(constraint))
1617       return failure();
1618     constraints.push_back(*constraint);
1619     return success();
1620   };
1621 
1622   // Check to see if this is a single constraint, or a list.
1623   if (!consumeIf(Token::l_square))
1624     return parseSingleConstraint();
1625 
1626   do {
1627     if (failed(parseSingleConstraint()))
1628       return failure();
1629   } while (consumeIf(Token::comma));
1630   return parseToken(Token::r_square, "expected `]` after constraint list");
1631 }
1632 
1633 FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> & typeConstraint,ArrayRef<ast::ConstraintRef> existingConstraints,bool allowInlineTypeConstraints,bool allowNonCoreConstraints)1634 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1635                         ArrayRef<ast::ConstraintRef> existingConstraints,
1636                         bool allowInlineTypeConstraints,
1637                         bool allowNonCoreConstraints) {
1638   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1639     if (!allowInlineTypeConstraints) {
1640       return emitError(
1641           curToken.getLoc(),
1642           "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1643           "permitted on arguments or results");
1644     }
1645     if (typeConstraint)
1646       return emitErrorAndNote(
1647           curToken.getLoc(),
1648           "the type of this variable has already been constrained",
1649           *typeConstraint, "see previous constraint location here");
1650     FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1651     if (failed(constraintExpr))
1652       return failure();
1653     typeExpr = *constraintExpr;
1654     typeConstraint = typeExpr->getLoc();
1655     return success();
1656   };
1657 
1658   SMRange loc = curToken.getLoc();
1659   switch (curToken.getKind()) {
1660   case Token::kw_Attr: {
1661     consumeToken(Token::kw_Attr);
1662 
1663     // Check for a type constraint.
1664     ast::Expr *typeExpr = nullptr;
1665     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1666       return failure();
1667     return ast::ConstraintRef(
1668         ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1669   }
1670   case Token::kw_Op: {
1671     consumeToken(Token::kw_Op);
1672 
1673     // Parse an optional operation name. If the name isn't provided, this refers
1674     // to "any" operation.
1675     FailureOr<ast::OpNameDecl *> opName =
1676         parseWrappedOperationName(/*allowEmptyName=*/true);
1677     if (failed(opName))
1678       return failure();
1679 
1680     return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1681                               loc);
1682   }
1683   case Token::kw_Type:
1684     consumeToken(Token::kw_Type);
1685     return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1686   case Token::kw_TypeRange:
1687     consumeToken(Token::kw_TypeRange);
1688     return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1689                               loc);
1690   case Token::kw_Value: {
1691     consumeToken(Token::kw_Value);
1692 
1693     // Check for a type constraint.
1694     ast::Expr *typeExpr = nullptr;
1695     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1696       return failure();
1697 
1698     return ast::ConstraintRef(
1699         ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1700   }
1701   case Token::kw_ValueRange: {
1702     consumeToken(Token::kw_ValueRange);
1703 
1704     // Check for a type constraint.
1705     ast::Expr *typeExpr = nullptr;
1706     if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1707       return failure();
1708 
1709     return ast::ConstraintRef(
1710         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1711   }
1712 
1713   case Token::kw_Constraint: {
1714     // Handle an inline constraint.
1715     FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1716     if (failed(decl))
1717       return failure();
1718     return ast::ConstraintRef(*decl, loc);
1719   }
1720   case Token::identifier: {
1721     StringRef constraintName = curToken.getSpelling();
1722     consumeToken(Token::identifier);
1723 
1724     // Lookup the referenced constraint.
1725     ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1726     if (!cstDecl) {
1727       return emitError(loc, "unknown reference to constraint `" +
1728                                 constraintName + "`");
1729     }
1730 
1731     // Handle a reference to a proper constraint.
1732     if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1733       return ast::ConstraintRef(cst, loc);
1734 
1735     return emitErrorAndNote(
1736         loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1737         "see the definition of `" + constraintName + "` here");
1738   }
1739     // Handle single entity constraint code completion.
1740   case Token::code_complete: {
1741     // Try to infer the current type for use by code completion.
1742     ast::Type inferredType;
1743     if (failed(validateVariableConstraints(existingConstraints, inferredType,
1744                                            allowNonCoreConstraints)))
1745       return failure();
1746 
1747     return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
1748                                       allowInlineTypeConstraints);
1749   }
1750   default:
1751     break;
1752   }
1753   return emitError(loc, "expected identifier constraint");
1754 }
1755 
parseArgOrResultConstraint()1756 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1757   // Constraint arguments may apply more complex constraints via the arguments.
1758   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1759 
1760   Optional<SMRange> typeConstraint;
1761   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1762                          /*allowInlineTypeConstraints=*/false,
1763                          allowNonCoreConstraints);
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // Exprs
1768 
parseExpr()1769 FailureOr<ast::Expr *> Parser::parseExpr() {
1770   if (curToken.is(Token::underscore))
1771     return parseUnderscoreExpr();
1772 
1773   // Parse the LHS expression.
1774   FailureOr<ast::Expr *> lhsExpr;
1775   switch (curToken.getKind()) {
1776   case Token::kw_attr:
1777     lhsExpr = parseAttributeExpr();
1778     break;
1779   case Token::kw_Constraint:
1780     lhsExpr = parseInlineConstraintLambdaExpr();
1781     break;
1782   case Token::identifier:
1783     lhsExpr = parseIdentifierExpr();
1784     break;
1785   case Token::kw_op:
1786     lhsExpr = parseOperationExpr();
1787     break;
1788   case Token::kw_Rewrite:
1789     lhsExpr = parseInlineRewriteLambdaExpr();
1790     break;
1791   case Token::kw_type:
1792     lhsExpr = parseTypeExpr();
1793     break;
1794   case Token::l_paren:
1795     lhsExpr = parseTupleExpr();
1796     break;
1797   default:
1798     return emitError("expected expression");
1799   }
1800   if (failed(lhsExpr))
1801     return failure();
1802 
1803   // Check for an operator expression.
1804   while (true) {
1805     switch (curToken.getKind()) {
1806     case Token::dot:
1807       lhsExpr = parseMemberAccessExpr(*lhsExpr);
1808       break;
1809     case Token::l_paren:
1810       lhsExpr = parseCallExpr(*lhsExpr);
1811       break;
1812     default:
1813       return lhsExpr;
1814     }
1815     if (failed(lhsExpr))
1816       return failure();
1817   }
1818 }
1819 
parseAttributeExpr()1820 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1821   SMRange loc = curToken.getLoc();
1822   consumeToken(Token::kw_attr);
1823 
1824   // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1825   // identifier.
1826   if (!consumeIf(Token::less)) {
1827     resetToken(loc);
1828     return parseIdentifierExpr();
1829   }
1830 
1831   if (!curToken.isString())
1832     return emitError("expected string literal containing MLIR attribute");
1833   std::string attrExpr = curToken.getStringValue();
1834   consumeToken();
1835 
1836   loc.End = curToken.getEndLoc();
1837   if (failed(
1838           parseToken(Token::greater, "expected `>` after attribute literal")))
1839     return failure();
1840   return ast::AttributeExpr::create(ctx, loc, attrExpr);
1841 }
1842 
parseCallExpr(ast::Expr * parentExpr)1843 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1844   consumeToken(Token::l_paren);
1845 
1846   // Parse the arguments of the call.
1847   SmallVector<ast::Expr *> arguments;
1848   if (curToken.isNot(Token::r_paren)) {
1849     do {
1850       // Handle code completion for the call arguments.
1851       if (curToken.is(Token::code_complete)) {
1852         codeCompleteCallSignature(parentExpr, arguments.size());
1853         return failure();
1854       }
1855 
1856       FailureOr<ast::Expr *> argument = parseExpr();
1857       if (failed(argument))
1858         return failure();
1859       arguments.push_back(*argument);
1860     } while (consumeIf(Token::comma));
1861   }
1862 
1863   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1864   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1865     return failure();
1866 
1867   return createCallExpr(loc, parentExpr, arguments);
1868 }
1869 
parseDeclRefExpr(StringRef name,SMRange loc)1870 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1871   ast::Decl *decl = curDeclScope->lookup(name);
1872   if (!decl)
1873     return emitError(loc, "undefined reference to `" + name + "`");
1874 
1875   return createDeclRefExpr(loc, decl);
1876 }
1877 
parseIdentifierExpr()1878 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1879   StringRef name = curToken.getSpelling();
1880   SMRange nameLoc = curToken.getLoc();
1881   consumeToken();
1882 
1883   // Check to see if this is a decl ref expression that defines a variable
1884   // inline.
1885   if (consumeIf(Token::colon)) {
1886     SmallVector<ast::ConstraintRef> constraints;
1887     if (failed(parseVariableDeclConstraintList(constraints)))
1888       return failure();
1889     ast::Type type;
1890     if (failed(validateVariableConstraints(constraints, type)))
1891       return failure();
1892     return createInlineVariableExpr(type, name, nameLoc, constraints);
1893   }
1894 
1895   return parseDeclRefExpr(name, nameLoc);
1896 }
1897 
parseInlineConstraintLambdaExpr()1898 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1899   FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1900   if (failed(decl))
1901     return failure();
1902 
1903   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1904                                   ast::ConstraintType::get(ctx));
1905 }
1906 
parseInlineRewriteLambdaExpr()1907 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1908   FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1909   if (failed(decl))
1910     return failure();
1911 
1912   return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1913                                   ast::RewriteType::get(ctx));
1914 }
1915 
parseMemberAccessExpr(ast::Expr * parentExpr)1916 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1917   SMRange dotLoc = curToken.getLoc();
1918   consumeToken(Token::dot);
1919 
1920   // Check for code completion of the member name.
1921   if (curToken.is(Token::code_complete))
1922     return codeCompleteMemberAccess(parentExpr);
1923 
1924   // Parse the member name.
1925   Token memberNameTok = curToken;
1926   if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1927       !memberNameTok.isKeyword())
1928     return emitError(dotLoc, "expected identifier or numeric member name");
1929   StringRef memberName = memberNameTok.getSpelling();
1930   SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1931   consumeToken();
1932 
1933   return createMemberAccessExpr(parentExpr, memberName, loc);
1934 }
1935 
parseOperationName(bool allowEmptyName)1936 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1937   SMRange loc = curToken.getLoc();
1938 
1939   // Check for code completion for the dialect name.
1940   if (curToken.is(Token::code_complete))
1941     return codeCompleteDialectName();
1942 
1943   // Handle the case of an no operation name.
1944   if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1945     if (allowEmptyName)
1946       return ast::OpNameDecl::create(ctx, SMRange());
1947     return emitError("expected dialect namespace");
1948   }
1949   StringRef name = curToken.getSpelling();
1950   consumeToken();
1951 
1952   // Otherwise, this is a literal operation name.
1953   if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1954     return failure();
1955 
1956   // Check for code completion for the operation name.
1957   if (curToken.is(Token::code_complete))
1958     return codeCompleteOperationName(name);
1959 
1960   if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1961     return emitError("expected operation name after dialect namespace");
1962 
1963   name = StringRef(name.data(), name.size() + 1);
1964   do {
1965     name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1966     loc.End = curToken.getEndLoc();
1967     consumeToken();
1968   } while (curToken.isAny(Token::identifier, Token::dot) ||
1969            curToken.isKeyword());
1970   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1971 }
1972 
1973 FailureOr<ast::OpNameDecl *>
parseWrappedOperationName(bool allowEmptyName)1974 Parser::parseWrappedOperationName(bool allowEmptyName) {
1975   if (!consumeIf(Token::less))
1976     return ast::OpNameDecl::create(ctx, SMRange());
1977 
1978   FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1979   if (failed(opNameDecl))
1980     return failure();
1981 
1982   if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1983     return failure();
1984   return opNameDecl;
1985 }
1986 
1987 FailureOr<ast::Expr *>
parseOperationExpr(OpResultTypeContext inputResultTypeContext)1988 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
1989   SMRange loc = curToken.getLoc();
1990   consumeToken(Token::kw_op);
1991 
1992   // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1993   // identifier.
1994   if (curToken.isNot(Token::less)) {
1995     resetToken(loc);
1996     return parseIdentifierExpr();
1997   }
1998 
1999   // Parse the operation name. The name may be elided, in which case the
2000   // operation refers to "any" operation(i.e. a difference between `MyOp` and
2001   // `Operation*`). Operation names within a rewrite context must be named.
2002   bool allowEmptyName = parserContext != ParserContext::Rewrite;
2003   FailureOr<ast::OpNameDecl *> opNameDecl =
2004       parseWrappedOperationName(allowEmptyName);
2005   if (failed(opNameDecl))
2006     return failure();
2007   Optional<StringRef> opName = (*opNameDecl)->getName();
2008 
2009   // Functor used to create an implicit range variable, used for implicit "all"
2010   // operand or results variables.
2011   auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2012     FailureOr<ast::VariableDecl *> rangeVar =
2013         defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2014     assert(succeeded(rangeVar) && "expected range variable to be valid");
2015     return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2016   };
2017 
2018   // Check for the optional list of operands.
2019   SmallVector<ast::Expr *> operands;
2020   if (!consumeIf(Token::l_paren)) {
2021     // If the operand list isn't specified and we are in a match context, define
2022     // an inplace unconstrained operand range corresponding to all of the
2023     // operands of the operation. This avoids treating zero operands the same
2024     // way as "unconstrained operands".
2025     if (parserContext != ParserContext::Rewrite) {
2026       operands.push_back(createImplicitRangeVar(
2027           ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2028     }
2029   } else if (!consumeIf(Token::r_paren)) {
2030     // If the operand list was specified and non-empty, parse the operands.
2031     do {
2032       // Check for operand signature code completion.
2033       if (curToken.is(Token::code_complete)) {
2034         codeCompleteOperationOperandsSignature(opName, operands.size());
2035         return failure();
2036       }
2037 
2038       FailureOr<ast::Expr *> operand = parseExpr();
2039       if (failed(operand))
2040         return failure();
2041       operands.push_back(*operand);
2042     } while (consumeIf(Token::comma));
2043 
2044     if (failed(parseToken(Token::r_paren,
2045                           "expected `)` after operation operand list")))
2046       return failure();
2047   }
2048 
2049   // Check for the optional list of attributes.
2050   SmallVector<ast::NamedAttributeDecl *> attributes;
2051   if (consumeIf(Token::l_brace)) {
2052     do {
2053       FailureOr<ast::NamedAttributeDecl *> decl =
2054           parseNamedAttributeDecl(opName);
2055       if (failed(decl))
2056         return failure();
2057       attributes.emplace_back(*decl);
2058     } while (consumeIf(Token::comma));
2059 
2060     if (failed(parseToken(Token::r_brace,
2061                           "expected `}` after operation attribute list")))
2062       return failure();
2063   }
2064 
2065   // Handle the result types of the operation.
2066   SmallVector<ast::Expr *> resultTypes;
2067   OpResultTypeContext resultTypeContext = inputResultTypeContext;
2068 
2069   // Check for an explicit list of result types.
2070   if (consumeIf(Token::arrow)) {
2071     if (failed(parseToken(Token::l_paren,
2072                           "expected `(` before operation result type list")))
2073       return failure();
2074 
2075     // If result types are provided, initially assume that the operation does
2076     // not rely on type inferrence. We don't assert that it isn't, because we
2077     // may be inferring the value of some type/type range variables, but given
2078     // that these variables may be defined in calls we can't always discern when
2079     // this is the case.
2080     resultTypeContext = OpResultTypeContext::Explicit;
2081 
2082     // Handle the case of an empty result list.
2083     if (!consumeIf(Token::r_paren)) {
2084       do {
2085         // Check for result signature code completion.
2086         if (curToken.is(Token::code_complete)) {
2087           codeCompleteOperationResultsSignature(opName, resultTypes.size());
2088           return failure();
2089         }
2090 
2091         FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2092         if (failed(resultTypeExpr))
2093           return failure();
2094         resultTypes.push_back(*resultTypeExpr);
2095       } while (consumeIf(Token::comma));
2096 
2097       if (failed(parseToken(Token::r_paren,
2098                             "expected `)` after operation result type list")))
2099         return failure();
2100     }
2101   } else if (parserContext != ParserContext::Rewrite) {
2102     // If the result list isn't specified and we are in a match context, define
2103     // an inplace unconstrained result range corresponding to all of the results
2104     // of the operation. This avoids treating zero results the same way as
2105     // "unconstrained results".
2106     resultTypes.push_back(createImplicitRangeVar(
2107         ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2108   } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2109     // If the result list isn't specified and we are in a rewrite, try to infer
2110     // them at runtime instead.
2111     resultTypeContext = OpResultTypeContext::Interface;
2112   }
2113 
2114   return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2115                              attributes, resultTypes);
2116 }
2117 
parseTupleExpr()2118 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2119   SMRange loc = curToken.getLoc();
2120   consumeToken(Token::l_paren);
2121 
2122   DenseMap<StringRef, SMRange> usedNames;
2123   SmallVector<StringRef> elementNames;
2124   SmallVector<ast::Expr *> elements;
2125   if (curToken.isNot(Token::r_paren)) {
2126     do {
2127       // Check for the optional element name assignment before the value.
2128       StringRef elementName;
2129       if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2130         Token elementNameTok = curToken;
2131         consumeToken();
2132 
2133         // The element name is only present if followed by an `=`.
2134         if (consumeIf(Token::equal)) {
2135           elementName = elementNameTok.getSpelling();
2136 
2137           // Check to see if this name is already used.
2138           auto elementNameIt =
2139               usedNames.try_emplace(elementName, elementNameTok.getLoc());
2140           if (!elementNameIt.second) {
2141             return emitErrorAndNote(
2142                 elementNameTok.getLoc(),
2143                 llvm::formatv("duplicate tuple element label `{0}`",
2144                               elementName),
2145                 elementNameIt.first->getSecond(),
2146                 "see previous label use here");
2147           }
2148         } else {
2149           // Otherwise, we treat this as part of an expression so reset the
2150           // lexer.
2151           resetToken(elementNameTok.getLoc());
2152         }
2153       }
2154       elementNames.push_back(elementName);
2155 
2156       // Parse the tuple element value.
2157       FailureOr<ast::Expr *> element = parseExpr();
2158       if (failed(element))
2159         return failure();
2160       elements.push_back(*element);
2161     } while (consumeIf(Token::comma));
2162   }
2163   loc.End = curToken.getEndLoc();
2164   if (failed(
2165           parseToken(Token::r_paren, "expected `)` after tuple element list")))
2166     return failure();
2167   return createTupleExpr(loc, elements, elementNames);
2168 }
2169 
parseTypeExpr()2170 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2171   SMRange loc = curToken.getLoc();
2172   consumeToken(Token::kw_type);
2173 
2174   // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2175   // identifier.
2176   if (!consumeIf(Token::less)) {
2177     resetToken(loc);
2178     return parseIdentifierExpr();
2179   }
2180 
2181   if (!curToken.isString())
2182     return emitError("expected string literal containing MLIR type");
2183   std::string attrExpr = curToken.getStringValue();
2184   consumeToken();
2185 
2186   loc.End = curToken.getEndLoc();
2187   if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2188     return failure();
2189   return ast::TypeExpr::create(ctx, loc, attrExpr);
2190 }
2191 
parseUnderscoreExpr()2192 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2193   StringRef name = curToken.getSpelling();
2194   SMRange nameLoc = curToken.getLoc();
2195   consumeToken(Token::underscore);
2196 
2197   // Underscore expressions require a constraint list.
2198   if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2199     return failure();
2200 
2201   // Parse the constraints for the expression.
2202   SmallVector<ast::ConstraintRef> constraints;
2203   if (failed(parseVariableDeclConstraintList(constraints)))
2204     return failure();
2205 
2206   ast::Type type;
2207   if (failed(validateVariableConstraints(constraints, type)))
2208     return failure();
2209   return createInlineVariableExpr(type, name, nameLoc, constraints);
2210 }
2211 
2212 //===----------------------------------------------------------------------===//
2213 // Stmts
2214 
parseStmt(bool expectTerminalSemicolon)2215 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2216   FailureOr<ast::Stmt *> stmt;
2217   switch (curToken.getKind()) {
2218   case Token::kw_erase:
2219     stmt = parseEraseStmt();
2220     break;
2221   case Token::kw_let:
2222     stmt = parseLetStmt();
2223     break;
2224   case Token::kw_replace:
2225     stmt = parseReplaceStmt();
2226     break;
2227   case Token::kw_return:
2228     stmt = parseReturnStmt();
2229     break;
2230   case Token::kw_rewrite:
2231     stmt = parseRewriteStmt();
2232     break;
2233   default:
2234     stmt = parseExpr();
2235     break;
2236   }
2237   if (failed(stmt) ||
2238       (expectTerminalSemicolon &&
2239        failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2240     return failure();
2241   return stmt;
2242 }
2243 
parseCompoundStmt()2244 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2245   SMLoc startLoc = curToken.getStartLoc();
2246   consumeToken(Token::l_brace);
2247 
2248   // Push a new block scope and parse any nested statements.
2249   pushDeclScope();
2250   SmallVector<ast::Stmt *> statements;
2251   while (curToken.isNot(Token::r_brace)) {
2252     FailureOr<ast::Stmt *> statement = parseStmt();
2253     if (failed(statement))
2254       return popDeclScope(), failure();
2255     statements.push_back(*statement);
2256   }
2257   popDeclScope();
2258 
2259   // Consume the end brace.
2260   SMRange location(startLoc, curToken.getEndLoc());
2261   consumeToken(Token::r_brace);
2262 
2263   return ast::CompoundStmt::create(ctx, location, statements);
2264 }
2265 
parseEraseStmt()2266 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2267   if (parserContext == ParserContext::Constraint)
2268     return emitError("`erase` cannot be used within a Constraint");
2269   SMRange loc = curToken.getLoc();
2270   consumeToken(Token::kw_erase);
2271 
2272   // Parse the root operation expression.
2273   FailureOr<ast::Expr *> rootOp = parseExpr();
2274   if (failed(rootOp))
2275     return failure();
2276 
2277   return createEraseStmt(loc, *rootOp);
2278 }
2279 
parseLetStmt()2280 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2281   SMRange loc = curToken.getLoc();
2282   consumeToken(Token::kw_let);
2283 
2284   // Parse the name of the new variable.
2285   SMRange varLoc = curToken.getLoc();
2286   if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2287     // `_` is a reserved variable name.
2288     if (curToken.is(Token::underscore)) {
2289       return emitError(varLoc,
2290                        "`_` may only be used to define \"inline\" variables");
2291     }
2292     return emitError(varLoc,
2293                      "expected identifier after `let` to name a new variable");
2294   }
2295   StringRef varName = curToken.getSpelling();
2296   consumeToken();
2297 
2298   // Parse the optional set of constraints.
2299   SmallVector<ast::ConstraintRef> constraints;
2300   if (consumeIf(Token::colon) &&
2301       failed(parseVariableDeclConstraintList(constraints)))
2302     return failure();
2303 
2304   // Parse the optional initializer expression.
2305   ast::Expr *initializer = nullptr;
2306   if (consumeIf(Token::equal)) {
2307     FailureOr<ast::Expr *> initOrFailure = parseExpr();
2308     if (failed(initOrFailure))
2309       return failure();
2310     initializer = *initOrFailure;
2311 
2312     // Check that the constraints are compatible with having an initializer,
2313     // e.g. type constraints cannot be used with initializers.
2314     for (ast::ConstraintRef constraint : constraints) {
2315       LogicalResult result =
2316           TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2317               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2318                     ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2319                 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2320                   return this->emitError(
2321                       constraint.referenceLoc,
2322                       "type constraints are not permitted on variables with "
2323                       "initializers");
2324                 }
2325                 return success();
2326               })
2327               .Default(success());
2328       if (failed(result))
2329         return failure();
2330     }
2331   }
2332 
2333   FailureOr<ast::VariableDecl *> varDecl =
2334       createVariableDecl(varName, varLoc, initializer, constraints);
2335   if (failed(varDecl))
2336     return failure();
2337   return ast::LetStmt::create(ctx, loc, *varDecl);
2338 }
2339 
parseReplaceStmt()2340 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2341   if (parserContext == ParserContext::Constraint)
2342     return emitError("`replace` cannot be used within a Constraint");
2343   SMRange loc = curToken.getLoc();
2344   consumeToken(Token::kw_replace);
2345 
2346   // Parse the root operation expression.
2347   FailureOr<ast::Expr *> rootOp = parseExpr();
2348   if (failed(rootOp))
2349     return failure();
2350 
2351   if (failed(
2352           parseToken(Token::kw_with, "expected `with` after root operation")))
2353     return failure();
2354 
2355   // The replacement portion of this statement is within a rewrite context.
2356   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2357                                               ParserContext::Rewrite);
2358 
2359   // Parse the replacement values.
2360   SmallVector<ast::Expr *> replValues;
2361   if (consumeIf(Token::l_paren)) {
2362     if (consumeIf(Token::r_paren)) {
2363       return emitError(
2364           loc, "expected at least one replacement value, consider using "
2365                "`erase` if no replacement values are desired");
2366     }
2367 
2368     do {
2369       FailureOr<ast::Expr *> replExpr = parseExpr();
2370       if (failed(replExpr))
2371         return failure();
2372       replValues.emplace_back(*replExpr);
2373     } while (consumeIf(Token::comma));
2374 
2375     if (failed(parseToken(Token::r_paren,
2376                           "expected `)` after replacement values")))
2377       return failure();
2378   } else {
2379     // Handle replacement with an operation uniquely, as the replacement
2380     // operation supports type inferrence from the root operation.
2381     FailureOr<ast::Expr *> replExpr;
2382     if (curToken.is(Token::kw_op))
2383       replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2384     else
2385       replExpr = parseExpr();
2386     if (failed(replExpr))
2387       return failure();
2388     replValues.emplace_back(*replExpr);
2389   }
2390 
2391   return createReplaceStmt(loc, *rootOp, replValues);
2392 }
2393 
parseReturnStmt()2394 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2395   SMRange loc = curToken.getLoc();
2396   consumeToken(Token::kw_return);
2397 
2398   // Parse the result value.
2399   FailureOr<ast::Expr *> resultExpr = parseExpr();
2400   if (failed(resultExpr))
2401     return failure();
2402 
2403   return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2404 }
2405 
parseRewriteStmt()2406 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2407   if (parserContext == ParserContext::Constraint)
2408     return emitError("`rewrite` cannot be used within a Constraint");
2409   SMRange loc = curToken.getLoc();
2410   consumeToken(Token::kw_rewrite);
2411 
2412   // Parse the root operation.
2413   FailureOr<ast::Expr *> rootOp = parseExpr();
2414   if (failed(rootOp))
2415     return failure();
2416 
2417   if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2418     return failure();
2419 
2420   if (curToken.isNot(Token::l_brace))
2421     return emitError("expected `{` to start rewrite body");
2422 
2423   // The rewrite body of this statement is within a rewrite context.
2424   llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2425                                               ParserContext::Rewrite);
2426 
2427   FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2428   if (failed(rewriteBody))
2429     return failure();
2430 
2431   // Verify the rewrite body.
2432   for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2433     if (isa<ast::ReturnStmt>(stmt)) {
2434       return emitError(stmt->getLoc(),
2435                        "`return` statements are only permitted within a "
2436                        "`Constraint` or `Rewrite` body");
2437     }
2438   }
2439 
2440   return createRewriteStmt(loc, *rootOp, *rewriteBody);
2441 }
2442 
2443 //===----------------------------------------------------------------------===//
2444 // Creation+Analysis
2445 //===----------------------------------------------------------------------===//
2446 
2447 //===----------------------------------------------------------------------===//
2448 // Decls
2449 
tryExtractCallableDecl(ast::Node * node)2450 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2451   // Unwrap reference expressions.
2452   if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2453     node = init->getDecl();
2454   return dyn_cast<ast::CallableDecl>(node);
2455 }
2456 
2457 FailureOr<ast::PatternDecl *>
createPatternDecl(SMRange loc,const ast::Name * name,const ParsedPatternMetadata & metadata,ast::CompoundStmt * body)2458 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2459                           const ParsedPatternMetadata &metadata,
2460                           ast::CompoundStmt *body) {
2461   return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2462                                   metadata.hasBoundedRecursion, body);
2463 }
2464 
createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl * > results)2465 ast::Type Parser::createUserConstraintRewriteResultType(
2466     ArrayRef<ast::VariableDecl *> results) {
2467   // Single result decls use the type of the single result.
2468   if (results.size() == 1)
2469     return results[0]->getType();
2470 
2471   // Multiple results use a tuple type, with the types and names grabbed from
2472   // the result variable decls.
2473   auto resultTypes = llvm::map_range(
2474       results, [&](const auto *result) { return result->getType(); });
2475   auto resultNames = llvm::map_range(
2476       results, [&](const auto *result) { return result->getName().getName(); });
2477   return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2478                              llvm::to_vector(resultNames));
2479 }
2480 
2481 template <typename T>
createUserPDLLConstraintOrRewriteDecl(const ast::Name & name,ArrayRef<ast::VariableDecl * > arguments,ArrayRef<ast::VariableDecl * > results,ast::Type resultType,ast::CompoundStmt * body)2482 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2483     const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2484     ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2485     ast::CompoundStmt *body) {
2486   if (!body->getChildren().empty()) {
2487     if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2488       ast::Expr *resultExpr = retStmt->getResultExpr();
2489 
2490       // Process the result of the decl. If no explicit signature results
2491       // were provided, check for return type inference. Otherwise, check that
2492       // the return expression can be converted to the expected type.
2493       if (results.empty())
2494         resultType = resultExpr->getType();
2495       else if (failed(convertExpressionTo(resultExpr, resultType)))
2496         return failure();
2497       else
2498         retStmt->setResultExpr(resultExpr);
2499     }
2500   }
2501   return T::createPDLL(ctx, name, arguments, results, body, resultType);
2502 }
2503 
2504 FailureOr<ast::VariableDecl *>
createVariableDecl(StringRef name,SMRange loc,ast::Expr * initializer,ArrayRef<ast::ConstraintRef> constraints)2505 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2506                            ArrayRef<ast::ConstraintRef> constraints) {
2507   // The type of the variable, which is expected to be inferred by either a
2508   // constraint or an initializer expression.
2509   ast::Type type;
2510   if (failed(validateVariableConstraints(constraints, type)))
2511     return failure();
2512 
2513   if (initializer) {
2514     // Update the variable type based on the initializer, or try to convert the
2515     // initializer to the existing type.
2516     if (!type)
2517       type = initializer->getType();
2518     else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2519       type = mergedType;
2520     else if (failed(convertExpressionTo(initializer, type)))
2521       return failure();
2522 
2523     // Otherwise, if there is no initializer check that the type has already
2524     // been resolved from the constraint list.
2525   } else if (!type) {
2526     return emitErrorAndNote(
2527         loc, "unable to infer type for variable `" + name + "`", loc,
2528         "the type of a variable must be inferable from the constraint "
2529         "list or the initializer");
2530   }
2531 
2532   // Constraint types cannot be used when defining variables.
2533   if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2534     return emitError(
2535         loc, llvm::formatv("unable to define variable of `{0}` type", type));
2536   }
2537 
2538   // Try to define a variable with the given name.
2539   FailureOr<ast::VariableDecl *> varDecl =
2540       defineVariableDecl(name, loc, type, initializer, constraints);
2541   if (failed(varDecl))
2542     return failure();
2543 
2544   return *varDecl;
2545 }
2546 
2547 FailureOr<ast::VariableDecl *>
createArgOrResultVariableDecl(StringRef name,SMRange loc,const ast::ConstraintRef & constraint)2548 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2549                                       const ast::ConstraintRef &constraint) {
2550   // Constraint arguments may apply more complex constraints via the arguments.
2551   bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2552   ast::Type argType;
2553   if (failed(validateVariableConstraint(constraint, argType,
2554                                         allowNonCoreConstraints)))
2555     return failure();
2556   return defineVariableDecl(name, loc, argType, constraint);
2557 }
2558 
2559 LogicalResult
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,ast::Type & inferredType,bool allowNonCoreConstraints)2560 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2561                                     ast::Type &inferredType,
2562                                     bool allowNonCoreConstraints) {
2563   for (const ast::ConstraintRef &ref : constraints)
2564     if (failed(validateVariableConstraint(ref, inferredType,
2565                                           allowNonCoreConstraints)))
2566       return failure();
2567   return success();
2568 }
2569 
validateVariableConstraint(const ast::ConstraintRef & ref,ast::Type & inferredType,bool allowNonCoreConstraints)2570 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2571                                                  ast::Type &inferredType,
2572                                                  bool allowNonCoreConstraints) {
2573   ast::Type constraintType;
2574   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2575     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2576       if (failed(validateTypeConstraintExpr(typeExpr)))
2577         return failure();
2578     }
2579     constraintType = ast::AttributeType::get(ctx);
2580   } else if (const auto *cst =
2581                  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2582     constraintType = ast::OperationType::get(
2583         ctx, cst->getName(), lookupODSOperation(cst->getName()));
2584   } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2585     constraintType = typeTy;
2586   } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2587     constraintType = typeRangeTy;
2588   } else if (const auto *cst =
2589                  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2590     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2591       if (failed(validateTypeConstraintExpr(typeExpr)))
2592         return failure();
2593     }
2594     constraintType = valueTy;
2595   } else if (const auto *cst =
2596                  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2597     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2598       if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2599         return failure();
2600     }
2601     constraintType = valueRangeTy;
2602   } else if (const auto *cst =
2603                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2604     if (!allowNonCoreConstraints) {
2605       return emitError(ref.referenceLoc,
2606                        "`Rewrite` arguments and results are only permitted to "
2607                        "use core constraints, such as `Attr`, `Op`, `Type`, "
2608                        "`TypeRange`, `Value`, `ValueRange`");
2609     }
2610 
2611     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2612     if (inputs.size() != 1) {
2613       return emitErrorAndNote(ref.referenceLoc,
2614                               "`Constraint`s applied via a variable constraint "
2615                               "list must take a single input, but got " +
2616                                   Twine(inputs.size()),
2617                               cst->getLoc(),
2618                               "see definition of constraint here");
2619     }
2620     constraintType = inputs.front()->getType();
2621   } else {
2622     llvm_unreachable("unknown constraint type");
2623   }
2624 
2625   // Check that the constraint type is compatible with the current inferred
2626   // type.
2627   if (!inferredType) {
2628     inferredType = constraintType;
2629   } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2630     inferredType = mergedTy;
2631   } else {
2632     return emitError(ref.referenceLoc,
2633                      llvm::formatv("constraint type `{0}` is incompatible "
2634                                    "with the previously inferred type `{1}`",
2635                                    constraintType, inferredType));
2636   }
2637   return success();
2638 }
2639 
validateTypeConstraintExpr(const ast::Expr * typeExpr)2640 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2641   ast::Type typeExprType = typeExpr->getType();
2642   if (typeExprType != typeTy) {
2643     return emitError(typeExpr->getLoc(),
2644                      "expected expression of `Type` in type constraint");
2645   }
2646   return success();
2647 }
2648 
2649 LogicalResult
validateTypeRangeConstraintExpr(const ast::Expr * typeExpr)2650 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2651   ast::Type typeExprType = typeExpr->getType();
2652   if (typeExprType != typeRangeTy) {
2653     return emitError(typeExpr->getLoc(),
2654                      "expected expression of `TypeRange` in type constraint");
2655   }
2656   return success();
2657 }
2658 
2659 //===----------------------------------------------------------------------===//
2660 // Exprs
2661 
2662 FailureOr<ast::CallExpr *>
createCallExpr(SMRange loc,ast::Expr * parentExpr,MutableArrayRef<ast::Expr * > arguments)2663 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2664                        MutableArrayRef<ast::Expr *> arguments) {
2665   ast::Type parentType = parentExpr->getType();
2666 
2667   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2668   if (!callableDecl) {
2669     return emitError(loc,
2670                      llvm::formatv("expected a reference to a callable "
2671                                    "`Constraint` or `Rewrite`, but got: `{0}`",
2672                                    parentType));
2673   }
2674   if (parserContext == ParserContext::Rewrite) {
2675     if (isa<ast::UserConstraintDecl>(callableDecl))
2676       return emitError(
2677           loc, "unable to invoke `Constraint` within a rewrite section");
2678   } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2679     return emitError(loc, "unable to invoke `Rewrite` within a match section");
2680   }
2681 
2682   // Verify the arguments of the call.
2683   /// Handle size mismatch.
2684   ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2685   if (callArgs.size() != arguments.size()) {
2686     return emitErrorAndNote(
2687         loc,
2688         llvm::formatv("invalid number of arguments for {0} call; expected "
2689                       "{1}, but got {2}",
2690                       callableDecl->getCallableType(), callArgs.size(),
2691                       arguments.size()),
2692         callableDecl->getLoc(),
2693         llvm::formatv("see the definition of {0} here",
2694                       callableDecl->getName()->getName()));
2695   }
2696 
2697   /// Handle argument type mismatch.
2698   auto attachDiagFn = [&](ast::Diagnostic &diag) {
2699     diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2700                                   callableDecl->getName()->getName()),
2701                     callableDecl->getLoc());
2702   };
2703   for (auto it : llvm::zip(callArgs, arguments)) {
2704     if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2705                                    attachDiagFn)))
2706       return failure();
2707   }
2708 
2709   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2710                                callableDecl->getResultType());
2711 }
2712 
createDeclRefExpr(SMRange loc,ast::Decl * decl)2713 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2714                                                         ast::Decl *decl) {
2715   // Check the type of decl being referenced.
2716   ast::Type declType;
2717   if (isa<ast::ConstraintDecl>(decl))
2718     declType = ast::ConstraintType::get(ctx);
2719   else if (isa<ast::UserRewriteDecl>(decl))
2720     declType = ast::RewriteType::get(ctx);
2721   else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2722     declType = varDecl->getType();
2723   else
2724     return emitError(loc, "invalid reference to `" +
2725                               decl->getName()->getName() + "`");
2726 
2727   return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2728 }
2729 
2730 FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type,StringRef name,SMRange loc,ArrayRef<ast::ConstraintRef> constraints)2731 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2732                                  ArrayRef<ast::ConstraintRef> constraints) {
2733   FailureOr<ast::VariableDecl *> decl =
2734       defineVariableDecl(name, loc, type, constraints);
2735   if (failed(decl))
2736     return failure();
2737   return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2738 }
2739 
2740 FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr * parentExpr,StringRef name,SMRange loc)2741 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2742                                SMRange loc) {
2743   // Validate the member name for the given parent expression.
2744   FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2745   if (failed(memberType))
2746     return failure();
2747 
2748   return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2749 }
2750 
validateMemberAccess(ast::Expr * parentExpr,StringRef name,SMRange loc)2751 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2752                                                   StringRef name, SMRange loc) {
2753   ast::Type parentType = parentExpr->getType();
2754   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2755     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2756       return valueRangeTy;
2757 
2758     // Verify member access based on the operation type.
2759     if (const ods::Operation *odsOp = opType.getODSOperation()) {
2760       auto results = odsOp->getResults();
2761 
2762       // Handle indexed results.
2763       unsigned index = 0;
2764       if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2765           index < results.size()) {
2766         return results[index].isVariadic() ? valueRangeTy : valueTy;
2767       }
2768 
2769       // Handle named results.
2770       const auto *it = llvm::find_if(results, [&](const auto &result) {
2771         return result.getName() == name;
2772       });
2773       if (it != results.end())
2774         return it->isVariadic() ? valueRangeTy : valueTy;
2775     } else if (llvm::isDigit(name[0])) {
2776       // Allow unchecked numeric indexing of the results of unregistered
2777       // operations. It returns a single value.
2778       return valueTy;
2779     }
2780   } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2781     // Handle indexed results.
2782     unsigned index = 0;
2783     if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2784         index < tupleType.size()) {
2785       return tupleType.getElementTypes()[index];
2786     }
2787 
2788     // Handle named results.
2789     auto elementNames = tupleType.getElementNames();
2790     const auto *it = llvm::find(elementNames, name);
2791     if (it != elementNames.end())
2792       return tupleType.getElementTypes()[it - elementNames.begin()];
2793   }
2794   return emitError(
2795       loc,
2796       llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2797                     name, parentType));
2798 }
2799 
createOperationExpr(SMRange loc,const ast::OpNameDecl * name,OpResultTypeContext resultTypeContext,MutableArrayRef<ast::Expr * > operands,MutableArrayRef<ast::NamedAttributeDecl * > attributes,MutableArrayRef<ast::Expr * > results)2800 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2801     SMRange loc, const ast::OpNameDecl *name,
2802     OpResultTypeContext resultTypeContext,
2803     MutableArrayRef<ast::Expr *> operands,
2804     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2805     MutableArrayRef<ast::Expr *> results) {
2806   Optional<StringRef> opNameRef = name->getName();
2807   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2808 
2809   // Verify the inputs operands.
2810   if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2811     return failure();
2812 
2813   // Verify the attribute list.
2814   for (ast::NamedAttributeDecl *attr : attributes) {
2815     // Check for an attribute type, or a type awaiting resolution.
2816     ast::Type attrType = attr->getValue()->getType();
2817     if (!attrType.isa<ast::AttributeType>()) {
2818       return emitError(
2819           attr->getValue()->getLoc(),
2820           llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2821     }
2822   }
2823 
2824   assert(
2825       (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2826       "unexpected inferrence when results were explicitly specified");
2827 
2828   // If we aren't relying on type inferrence, or explicit results were provided,
2829   // validate them.
2830   if (resultTypeContext == OpResultTypeContext::Explicit) {
2831     if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2832       return failure();
2833 
2834     // Validate the use of interface based type inferrence for this operation.
2835   } else if (resultTypeContext == OpResultTypeContext::Interface) {
2836     assert(opNameRef &&
2837            "expected valid operation name when inferring operation results");
2838     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2839   }
2840 
2841   return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2842                                     attributes);
2843 }
2844 
2845 LogicalResult
validateOperationOperands(SMRange loc,Optional<StringRef> name,const ods::Operation * odsOp,MutableArrayRef<ast::Expr * > operands)2846 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2847                                   const ods::Operation *odsOp,
2848                                   MutableArrayRef<ast::Expr *> operands) {
2849   return validateOperationOperandsOrResults(
2850       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2851       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2852       valueRangeTy);
2853 }
2854 
2855 LogicalResult
validateOperationResults(SMRange loc,Optional<StringRef> name,const ods::Operation * odsOp,MutableArrayRef<ast::Expr * > results)2856 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2857                                  const ods::Operation *odsOp,
2858                                  MutableArrayRef<ast::Expr *> results) {
2859   return validateOperationOperandsOrResults(
2860       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2861       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2862 }
2863 
checkOperationResultTypeInferrence(SMRange loc,StringRef opName,const ods::Operation * odsOp)2864 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2865                                                 const ods::Operation *odsOp) {
2866   // If the operation might not have inferrence support, emit a warning to the
2867   // user. We don't emit an error because the interface might be added to the
2868   // operation at runtime. It's rare, but it could still happen. We emit a
2869   // warning here instead.
2870 
2871   // Handle inferrence warnings for unknown operations.
2872   if (!odsOp) {
2873     ctx.getDiagEngine().emitWarning(
2874         loc, llvm::formatv(
2875                  "operation result types are marked to be inferred, but "
2876                  "`{0}` is unknown. Ensure that `{0}` supports zero "
2877                  "results or implements `InferTypeOpInterface`. Include "
2878                  "the ODS definition of this operation to remove this warning.",
2879                  opName));
2880     return;
2881   }
2882 
2883   // Handle inferrence warnings for known operations that expected at least one
2884   // result, but don't have inference support. An elided results list can mean
2885   // "zero-results", and we don't want to warn when that is the expected
2886   // behavior.
2887   bool requiresInferrence =
2888       llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2889         return !result.isVariableLength();
2890       });
2891   if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2892     ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2893         loc,
2894         llvm::formatv("operation result types are marked to be inferred, but "
2895                       "`{0}` does not provide an implementation of "
2896                       "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2897                       "`InferTypeOpInterface` at runtime, or add support to "
2898                       "the ODS definition to remove this warning.",
2899                       opName));
2900     diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2901                      odsOp->getLoc());
2902     return;
2903   }
2904 }
2905 
validateOperationOperandsOrResults(StringRef groupName,SMRange loc,Optional<SMRange> odsOpLoc,Optional<StringRef> name,MutableArrayRef<ast::Expr * > values,ArrayRef<ods::OperandOrResult> odsValues,ast::Type singleTy,ast::Type rangeTy)2906 LogicalResult Parser::validateOperationOperandsOrResults(
2907     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2908     Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2909     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2910     ast::Type rangeTy) {
2911   // All operation types accept a single range parameter.
2912   if (values.size() == 1) {
2913     if (failed(convertExpressionTo(values[0], rangeTy)))
2914       return failure();
2915     return success();
2916   }
2917 
2918   /// If the operation has ODS information, we can more accurately verify the
2919   /// values.
2920   if (odsOpLoc) {
2921     if (odsValues.size() != values.size()) {
2922       return emitErrorAndNote(
2923           loc,
2924           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2925                         "{2}, but got {3}",
2926                         groupName, *name, odsValues.size(), values.size()),
2927           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2928     }
2929     auto diagFn = [&](ast::Diagnostic &diag) {
2930       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2931                       *odsOpLoc);
2932     };
2933     for (unsigned i = 0, e = values.size(); i < e; ++i) {
2934       ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2935       if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2936         return failure();
2937     }
2938     return success();
2939   }
2940 
2941   // Otherwise, accept the value groups as they have been defined and just
2942   // ensure they are one of the expected types.
2943   for (ast::Expr *&valueExpr : values) {
2944     ast::Type valueExprType = valueExpr->getType();
2945 
2946     // Check if this is one of the expected types.
2947     if (valueExprType == rangeTy || valueExprType == singleTy)
2948       continue;
2949 
2950     // If the operand is an Operation, allow converting to a Value or
2951     // ValueRange. This situations arises quite often with nested operation
2952     // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2953     if (singleTy == valueTy) {
2954       if (valueExprType.isa<ast::OperationType>()) {
2955         valueExpr = convertOpToValue(valueExpr);
2956         continue;
2957       }
2958     }
2959 
2960     return emitError(
2961         valueExpr->getLoc(),
2962         llvm::formatv(
2963             "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2964             singleTy, rangeTy, valueExprType));
2965   }
2966   return success();
2967 }
2968 
2969 FailureOr<ast::TupleExpr *>
createTupleExpr(SMRange loc,ArrayRef<ast::Expr * > elements,ArrayRef<StringRef> elementNames)2970 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2971                         ArrayRef<StringRef> elementNames) {
2972   for (const ast::Expr *element : elements) {
2973     ast::Type eleTy = element->getType();
2974     if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2975       return emitError(
2976           element->getLoc(),
2977           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2978     }
2979   }
2980   return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2981 }
2982 
//===----------------------------------------------------------------------===//
// Stmts
//===----------------------------------------------------------------------===//
2985 
createEraseStmt(SMRange loc,ast::Expr * rootOp)2986 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2987                                                     ast::Expr *rootOp) {
2988   // Check that root is an Operation.
2989   ast::Type rootType = rootOp->getType();
2990   if (!rootType.isa<ast::OperationType>())
2991     return emitError(rootOp->getLoc(), "expected `Op` expression");
2992 
2993   return ast::EraseStmt::create(ctx, loc, rootOp);
2994 }
2995 
2996 FailureOr<ast::ReplaceStmt *>
createReplaceStmt(SMRange loc,ast::Expr * rootOp,MutableArrayRef<ast::Expr * > replValues)2997 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2998                           MutableArrayRef<ast::Expr *> replValues) {
2999   // Check that root is an Operation.
3000   ast::Type rootType = rootOp->getType();
3001   if (!rootType.isa<ast::OperationType>()) {
3002     return emitError(
3003         rootOp->getLoc(),
3004         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3005   }
3006 
3007   // If there are multiple replacement values, we implicitly convert any Op
3008   // expressions to the value form.
3009   bool shouldConvertOpToValues = replValues.size() > 1;
3010   for (ast::Expr *&replExpr : replValues) {
3011     ast::Type replType = replExpr->getType();
3012 
3013     // Check that replExpr is an Operation, Value, or ValueRange.
3014     if (replType.isa<ast::OperationType>()) {
3015       if (shouldConvertOpToValues)
3016         replExpr = convertOpToValue(replExpr);
3017       continue;
3018     }
3019 
3020     if (replType != valueTy && replType != valueRangeTy) {
3021       return emitError(replExpr->getLoc(),
3022                        llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3023                                      "expression, but got `{0}`",
3024                                      replType));
3025     }
3026   }
3027 
3028   return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3029 }
3030 
3031 FailureOr<ast::RewriteStmt *>
createRewriteStmt(SMRange loc,ast::Expr * rootOp,ast::CompoundStmt * rewriteBody)3032 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3033                           ast::CompoundStmt *rewriteBody) {
3034   // Check that root is an Operation.
3035   ast::Type rootType = rootOp->getType();
3036   if (!rootType.isa<ast::OperationType>()) {
3037     return emitError(
3038         rootOp->getLoc(),
3039         llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3040   }
3041 
3042   return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3043 }
3044 
3045 //===----------------------------------------------------------------------===//
3046 // Code Completion
3047 //===----------------------------------------------------------------------===//
3048 
codeCompleteMemberAccess(ast::Expr * parentExpr)3049 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3050   ast::Type parentType = parentExpr->getType();
3051   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3052     codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3053   else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3054     codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3055   return failure();
3056 }
3057 
codeCompleteAttributeName(Optional<StringRef> opName)3058 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3059   if (opName)
3060     codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3061   return failure();
3062 }
3063 
3064 LogicalResult
codeCompleteConstraintName(ast::Type inferredType,bool allowNonCoreConstraints,bool allowInlineTypeConstraints)3065 Parser::codeCompleteConstraintName(ast::Type inferredType,
3066                                    bool allowNonCoreConstraints,
3067                                    bool allowInlineTypeConstraints) {
3068   codeCompleteContext->codeCompleteConstraintName(
3069       inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3070       curDeclScope);
3071   return failure();
3072 }
3073 
codeCompleteDialectName()3074 LogicalResult Parser::codeCompleteDialectName() {
3075   codeCompleteContext->codeCompleteDialectName();
3076   return failure();
3077 }
3078 
codeCompleteOperationName(StringRef dialectName)3079 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3080   codeCompleteContext->codeCompleteOperationName(dialectName);
3081   return failure();
3082 }
3083 
codeCompletePatternMetadata()3084 LogicalResult Parser::codeCompletePatternMetadata() {
3085   codeCompleteContext->codeCompletePatternMetadata();
3086   return failure();
3087 }
3088 
codeCompleteIncludeFilename(StringRef curPath)3089 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3090   codeCompleteContext->codeCompleteIncludeFilename(curPath);
3091   return failure();
3092 }
3093 
codeCompleteCallSignature(ast::Node * parent,unsigned currentNumArgs)3094 void Parser::codeCompleteCallSignature(ast::Node *parent,
3095                                        unsigned currentNumArgs) {
3096   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3097   if (!callableDecl)
3098     return;
3099 
3100   codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3101 }
3102 
codeCompleteOperationOperandsSignature(Optional<StringRef> opName,unsigned currentNumOperands)3103 void Parser::codeCompleteOperationOperandsSignature(
3104     Optional<StringRef> opName, unsigned currentNumOperands) {
3105   codeCompleteContext->codeCompleteOperationOperandsSignature(
3106       opName, currentNumOperands);
3107 }
3108 
codeCompleteOperationResultsSignature(Optional<StringRef> opName,unsigned currentNumResults)3109 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3110                                                    unsigned currentNumResults) {
3111   codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3112                                                              currentNumResults);
3113 }
3114 
3115 //===----------------------------------------------------------------------===//
3116 // Parser
3117 //===----------------------------------------------------------------------===//
3118 
3119 FailureOr<ast::Module *>
parsePDLLAST(ast::Context & ctx,llvm::SourceMgr & sourceMgr,bool enableDocumentation,CodeCompleteContext * codeCompleteContext)3120 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3121                          bool enableDocumentation,
3122                          CodeCompleteContext *codeCompleteContext) {
3123   Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3124   return parser.parseModule();
3125 }
3126