1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Argument.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Constraint.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/Tools/PDLL/AST/Context.h"
19 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
20 #include "mlir/Tools/PDLL/AST/Nodes.h"
21 #include "mlir/Tools/PDLL/AST/Types.h"
22 #include "mlir/Tools/PDLL/ODS/Constraint.h"
23 #include "mlir/Tools/PDLL/ODS/Context.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35
36 using namespace mlir;
37 using namespace mlir::pdll;
38
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42
43 namespace {
44 class Parser {
45 public:
/// Construct a parser that lexes from `sourceMgr` and reports diagnostics
/// through the diagnostic engine of `ctx`. `enableDocumentation` is stored and
/// consulted when processing documentation strings. `codeCompleteContext` is
/// forwarded to the lexer and also kept by the parser itself.
Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
       bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
    : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
      // Prime the parser with the first token produced by the lexer.
      curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
      // Cache the commonly used AST types up front.
      valueTy(ast::ValueType::get(ctx)),
      valueRangeTy(ast::ValueRangeType::get(ctx)),
      typeTy(ast::TypeType::get(ctx)),
      typeRangeTy(ast::TypeRangeType::get(ctx)),
      attrTy(ast::AttributeType::get(ctx)),
      codeCompleteContext(codeCompleteContext) {}
56
/// Try to parse a new module. Returns failure if the module could not be parsed.
58 FailureOr<ast::Module *> parseModule();
59
60 private:
/// The current context of the parser. It allows for the parser to know a bit
/// about the construct it is nested within during parsing. This is used
/// specifically to provide additional verification during parsing, e.g. to
/// prevent using rewrites within a match context, matcher constraints within
/// a rewrite section, etc.
enum class ParserContext {
  /// The parser is in the global context.
  Global,
  /// The parser is currently within a Constraint, which disallows all types
  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
  Constraint,
  /// The parser is currently within the matcher portion of a Pattern, which
  /// allows a terminal operation rewrite statement but no other rewrite
  /// transformations.
  PatternMatch,
  /// The parser is currently within a Rewrite, which disallows calls to
  /// constraints, requires operation expressions to have names, etc.
  Rewrite,
};
80
/// The current specification context of an operation's result types. This
/// indicates how the result types of an operation may be inferred.
enum class OpResultTypeContext {
  /// The result types of the operation are not known to be inferred.
  Explicit,
  /// The result types of the operation are inferred from the root input of a
  /// `replace` statement.
  Replacement,
  /// The result types of the operation are inferred by using the
  /// `InferTypeOpInterface` interface provided by the operation.
  Interface,
};
93
94 //===--------------------------------------------------------------------===//
95 // Parsing
96 //===--------------------------------------------------------------------===//
97
98 /// Push a new decl scope onto the lexer.
pushDeclScope()99 ast::DeclScope *pushDeclScope() {
100 ast::DeclScope *newScope =
101 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102 return (curDeclScope = newScope);
103 }
pushDeclScope(ast::DeclScope * scope)104 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105
106 /// Pop the last decl scope from the lexer.
popDeclScope()107 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108
109 /// Parse the body of an AST module.
110 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111
112 /// Try to convert the given expression to `type`. Returns failure and emits
113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114 /// invoked to attach notes to the emitted error diagnostic. On success,
115 /// `expr` is updated to the expression used to convert to `type`.
116 LogicalResult convertExpressionTo(
117 ast::Expr *&expr, ast::Type type,
118 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119
120 /// Given an operation expression, convert it to a Value or ValueRange
121 /// typed expression.
122 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
123
124 /// Lookup ODS information for the given operation, returns nullptr if no
125 /// information is found.
lookupODSOperation(Optional<StringRef> opName)126 const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
127 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
128 }
129
130 /// Process the given documentation string, or return an empty string if
131 /// documentation isn't enabled.
processDoc(StringRef doc)132 StringRef processDoc(StringRef doc) {
133 return enableDocumentation ? doc : StringRef();
134 }
135
136 /// Process the given documentation string and format it, or return an empty
137 /// string if documentation isn't enabled.
processAndFormatDoc(const Twine & doc)138 std::string processAndFormatDoc(const Twine &doc) {
139 if (!enableDocumentation)
140 return "";
141 std::string docStr;
142 {
143 llvm::raw_string_ostream docOS(docStr);
144 std::string tmpDocStr = doc.str();
145 raw_indented_ostream(docOS).printReindented(
146 StringRef(tmpDocStr).rtrim(" \t"));
147 }
148 return docStr;
149 }
150
151 //===--------------------------------------------------------------------===//
152 // Directives
153
154 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
155 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
156 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
157 SmallVectorImpl<ast::Decl *> &decls);
158
159 /// Process the records of a parsed tablegen include file.
160 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
161 SmallVectorImpl<ast::Decl *> &decls);
162
163 /// Create a user defined native constraint for a constraint imported from
164 /// ODS.
165 template <typename ConstraintT>
166 ast::Decl *
167 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
168 SMRange loc, ast::Type type,
169 StringRef nativeType, StringRef docString);
170 template <typename ConstraintT>
171 ast::Decl *
172 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
173 SMRange loc, ast::Type type,
174 StringRef nativeType);
175
176 //===--------------------------------------------------------------------===//
177 // Decls
178
/// This structure contains the set of pattern metadata that may be parsed.
struct ParsedPatternMetadata {
  /// The pattern benefit; left unset unless a value is assigned.
  Optional<uint16_t> benefit;
  /// Bounded recursion flag; defaults to false.
  // NOTE(review): presumably both fields are filled in by
  // `parsePatternDeclMetadata` — confirm against its definition.
  bool hasBoundedRecursion = false;
};
184
185 FailureOr<ast::Decl *> parseTopLevelDecl();
186 FailureOr<ast::NamedAttributeDecl *>
187 parseNamedAttributeDecl(Optional<StringRef> parentOpName);
188
189 /// Parse an argument variable as part of the signature of a
190 /// UserConstraintDecl or UserRewriteDecl.
191 FailureOr<ast::VariableDecl *> parseArgumentDecl();
192
193 /// Parse a result variable as part of the signature of a UserConstraintDecl
194 /// or UserRewriteDecl.
195 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
196
197 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
198 /// defined in a non-global context.
199 FailureOr<ast::UserConstraintDecl *>
200 parseUserConstraintDecl(bool isInline = false);
201
202 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
203 /// non-global context, such as within a Pattern/Constraint/etc.
204 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
205
/// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
/// using PDLL constructs.
208 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
209 const ast::Name &name, bool isInline,
210 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
211 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
212
/// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
214 /// defined in a non-global context.
215 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
216
217 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
218 /// non-global context, such as within a Pattern/Rewrite/etc.
219 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
220
221 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
222 /// PDLL constructs.
223 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
224 const ast::Name &name, bool isInline,
225 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
226 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
227
228 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
229 /// effectively the same syntax, and only differ on slight semantics (given
230 /// the different parsing contexts).
231 template <typename T, typename ParseUserPDLLDeclFnT>
232 FailureOr<T *> parseUserConstraintOrRewriteDecl(
233 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
234 StringRef anonymousNamePrefix, bool isInline);
235
236 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
237 /// These decls have effectively the same syntax.
238 template <typename T>
239 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
240 const ast::Name &name, bool isInline,
241 ArrayRef<ast::VariableDecl *> arguments,
242 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
243
244 /// Parse the functional signature (i.e. the arguments and results) of a
245 /// UserConstraintDecl or UserRewriteDecl.
246 LogicalResult parseUserConstraintOrRewriteSignature(
247 SmallVectorImpl<ast::VariableDecl *> &arguments,
248 SmallVectorImpl<ast::VariableDecl *> &results,
249 ast::DeclScope *&argumentScope, ast::Type &resultType);
250
251 /// Validate the return (which if present is specified by bodyIt) of a
252 /// UserConstraintDecl or UserRewriteDecl.
253 LogicalResult validateUserConstraintOrRewriteReturn(
254 StringRef declType, ast::CompoundStmt *body,
255 ArrayRef<ast::Stmt *>::iterator bodyIt,
256 ArrayRef<ast::Stmt *>::iterator bodyE,
257 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
258
259 FailureOr<ast::CompoundStmt *>
260 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
261 bool expectTerminalSemicolon = true);
262 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
263 FailureOr<ast::Decl *> parsePatternDecl();
264 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
265
/// Check to see if a decl has already been defined with the given name. If
/// one has, emit an error and return failure. Returns success otherwise.
268 LogicalResult checkDefineNamedDecl(const ast::Name &name);
269
270 /// Try to define a variable decl with the given components, returns the
271 /// variable on success.
272 FailureOr<ast::VariableDecl *>
273 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
274 ast::Expr *initExpr,
275 ArrayRef<ast::ConstraintRef> constraints);
276 FailureOr<ast::VariableDecl *>
277 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
278 ArrayRef<ast::ConstraintRef> constraints);
279
280 /// Parse the constraint reference list for a variable decl.
281 LogicalResult parseVariableDeclConstraintList(
282 SmallVectorImpl<ast::ConstraintRef> &constraints);
283
284 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
285 FailureOr<ast::Expr *> parseTypeConstraintExpr();
286
287 /// Try to parse a single reference to a constraint. `typeConstraint` is the
288 /// location of a previously parsed type constraint for the entity that will
289 /// be constrained by the parsed constraint. `existingConstraints` are any
290 /// existing constraints that have already been parsed for the same entity
291 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
292 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
293 /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
294 /// constraints) may be used with the variable.
295 FailureOr<ast::ConstraintRef>
296 parseConstraint(Optional<SMRange> &typeConstraint,
297 ArrayRef<ast::ConstraintRef> existingConstraints,
298 bool allowInlineTypeConstraints,
299 bool allowNonCoreConstraints);
300
301 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
302 /// argument or result variable. The constraints for these variables do not
303 /// allow inline type constraints, and only permit a single constraint.
304 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
305
306 //===--------------------------------------------------------------------===//
307 // Exprs
308
309 FailureOr<ast::Expr *> parseExpr();
310
311 /// Identifier expressions.
312 FailureOr<ast::Expr *> parseAttributeExpr();
313 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
314 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
315 FailureOr<ast::Expr *> parseIdentifierExpr();
316 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
317 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
318 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
319 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
320 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
321 FailureOr<ast::Expr *>
322 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
323 OpResultTypeContext::Explicit);
324 FailureOr<ast::Expr *> parseTupleExpr();
325 FailureOr<ast::Expr *> parseTypeExpr();
326 FailureOr<ast::Expr *> parseUnderscoreExpr();
327
328 //===--------------------------------------------------------------------===//
329 // Stmts
330
331 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
332 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
333 FailureOr<ast::EraseStmt *> parseEraseStmt();
334 FailureOr<ast::LetStmt *> parseLetStmt();
335 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
336 FailureOr<ast::ReturnStmt *> parseReturnStmt();
337 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
338
339 //===--------------------------------------------------------------------===//
340 // Creation+Analysis
341 //===--------------------------------------------------------------------===//
342
343 //===--------------------------------------------------------------------===//
344 // Decls
345
346 /// Try to extract a callable from the given AST node. Returns nullptr on
347 /// failure.
348 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
349
350 /// Try to create a pattern decl with the given components, returning the
351 /// Pattern on success.
352 FailureOr<ast::PatternDecl *>
353 createPatternDecl(SMRange loc, const ast::Name *name,
354 const ParsedPatternMetadata &metadata,
355 ast::CompoundStmt *body);
356
357 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
358 /// of results, defined as part of the signature.
359 ast::Type
360 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
361
362 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
363 template <typename T>
364 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
365 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
366 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
367 ast::CompoundStmt *body);
368
369 /// Try to create a variable decl with the given components, returning the
370 /// Variable on success.
371 FailureOr<ast::VariableDecl *>
372 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
373 ArrayRef<ast::ConstraintRef> constraints);
374
375 /// Create a variable for an argument or result defined as part of the
376 /// signature of a UserConstraintDecl/UserRewriteDecl.
377 FailureOr<ast::VariableDecl *>
378 createArgOrResultVariableDecl(StringRef name, SMRange loc,
379 const ast::ConstraintRef &constraint);
380
/// Validate the constraints used to constrain a variable decl.
382 /// `inferredType` is the type of the variable inferred by the constraints
383 /// within the list, and is updated to the most refined type as determined by
384 /// the constraints. Returns success if the constraint list is valid, failure
385 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
386 /// defined constraints) may be used with the variable.
387 LogicalResult
388 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
389 ast::Type &inferredType,
390 bool allowNonCoreConstraints = true);
391 /// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variable type and is refined within the type defined
393 /// by the constraint. Returns success if the constraint is valid, failure
394 /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
395 /// defined constraints) may be used with the variable.
396 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
397 ast::Type &inferredType,
398 bool allowNonCoreConstraints = true);
399 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
400 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
401
402 //===--------------------------------------------------------------------===//
403 // Exprs
404
405 FailureOr<ast::CallExpr *>
406 createCallExpr(SMRange loc, ast::Expr *parentExpr,
407 MutableArrayRef<ast::Expr *> arguments);
408 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
409 FailureOr<ast::DeclRefExpr *>
410 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
411 ArrayRef<ast::ConstraintRef> constraints);
412 FailureOr<ast::MemberAccessExpr *>
413 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
414
415 /// Validate the member access `name` into the given parent expression. On
416 /// success, this also returns the type of the member accessed.
417 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
418 StringRef name, SMRange loc);
419 FailureOr<ast::OperationExpr *>
420 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
421 OpResultTypeContext resultTypeContext,
422 MutableArrayRef<ast::Expr *> operands,
423 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
424 MutableArrayRef<ast::Expr *> results);
425 LogicalResult
426 validateOperationOperands(SMRange loc, Optional<StringRef> name,
427 const ods::Operation *odsOp,
428 MutableArrayRef<ast::Expr *> operands);
429 LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
430 const ods::Operation *odsOp,
431 MutableArrayRef<ast::Expr *> results);
432 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
433 const ods::Operation *odsOp);
434 LogicalResult validateOperationOperandsOrResults(
435 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
436 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
437 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
438 ast::Type rangeTy);
439 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
440 ArrayRef<ast::Expr *> elements,
441 ArrayRef<StringRef> elementNames);
442
443 //===--------------------------------------------------------------------===//
444 // Stmts
445
446 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
447 FailureOr<ast::ReplaceStmt *>
448 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
449 MutableArrayRef<ast::Expr *> replValues);
450 FailureOr<ast::RewriteStmt *>
451 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
452 ast::CompoundStmt *rewriteBody);
453
454 //===--------------------------------------------------------------------===//
455 // Code Completion
456 //===--------------------------------------------------------------------===//
457
458 /// The set of various code completion methods. Every completion method
459 /// returns `failure` to stop the parsing process after providing completion
460 /// results.
461
462 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
463 LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
464 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
465 bool allowNonCoreConstraints,
466 bool allowInlineTypeConstraints);
467 LogicalResult codeCompleteDialectName();
468 LogicalResult codeCompleteOperationName(StringRef dialectName);
469 LogicalResult codeCompletePatternMetadata();
470 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
471
472 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
473 void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
474 unsigned currentNumOperands);
475 void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
476 unsigned currentNumResults);
477
478 //===--------------------------------------------------------------------===//
479 // Lexer Utilities
480 //===--------------------------------------------------------------------===//
481
482 /// If the current token has the specified kind, consume it and return true.
483 /// If not, return false.
consumeIf(Token::Kind kind)484 bool consumeIf(Token::Kind kind) {
485 if (curToken.isNot(kind))
486 return false;
487 consumeToken(kind);
488 return true;
489 }
490
/// Advance the current lexer onto the next token. The current token must not
/// be an EOF or error token; this is only checked by an assertion.
void consumeToken() {
  assert(curToken.isNot(Token::eof, Token::error) &&
         "shouldn't advance past EOF or errors");
  // Replace the current token with the next one produced by the lexer.
  curToken = lexer.lexToken();
}

/// Advance the current lexer onto the next token, asserting what the expected
/// current token is. This is preferred to the above method because it leads
/// to more self-documenting code with better checking.
void consumeToken(Token::Kind kind) {
  assert(curToken.is(kind) && "consumed an unexpected token");
  consumeToken();
}
505
506 /// Reset the lexer to the location at the given position.
resetToken(SMRange tokLoc)507 void resetToken(SMRange tokLoc) {
508 lexer.resetPointer(tokLoc.Start.getPointer());
509 curToken = lexer.lexToken();
510 }
511
512 /// Consume the specified token if present and return success. On failure,
513 /// output a diagnostic and return failure.
parseToken(Token::Kind kind,const Twine & msg)514 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
515 if (curToken.getKind() != kind)
516 return emitError(curToken.getLoc(), msg);
517 consumeToken();
518 return success();
519 }
emitError(SMRange loc,const Twine & msg)520 LogicalResult emitError(SMRange loc, const Twine &msg) {
521 lexer.emitError(loc, msg);
522 return failure();
523 }
emitError(const Twine & msg)524 LogicalResult emitError(const Twine &msg) {
525 return emitError(curToken.getLoc(), msg);
526 }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)527 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
528 const Twine ¬e) {
529 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
530 return failure();
531 }
532
533 //===--------------------------------------------------------------------===//
534 // Fields
535 //===--------------------------------------------------------------------===//
536
537 /// The owning AST context.
538 ast::Context &ctx;
539
540 /// The lexer of this parser.
541 Lexer lexer;
542
543 /// The current token within the lexer.
544 Token curToken;
545
546 /// A flag indicating if the parser should add documentation to AST nodes when
547 /// viable.
548 bool enableDocumentation;
549
550 /// The most recently defined decl scope.
551 ast::DeclScope *curDeclScope = nullptr;
552 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
553
554 /// The current context of the parser.
555 ParserContext parserContext = ParserContext::Global;
556
557 /// Cached types to simplify verification and expression creation.
558 ast::Type valueTy, valueRangeTy;
559 ast::Type typeTy, typeRangeTy;
560 ast::Type attrTy;
561
562 /// A counter used when naming anonymous constraints and rewrites.
563 unsigned anonymousDeclNameCounter = 0;
564
565 /// The optional code completion context.
566 CodeCompleteContext *codeCompleteContext;
567 };
568 } // namespace
569
parseModule()570 FailureOr<ast::Module *> Parser::parseModule() {
571 SMLoc moduleLoc = curToken.getStartLoc();
572 pushDeclScope();
573
574 // Parse the top-level decls of the module.
575 SmallVector<ast::Decl *> decls;
576 if (failed(parseModuleBody(decls)))
577 return popDeclScope(), failure();
578
579 popDeclScope();
580 return ast::Module::create(ctx, moduleLoc, decls);
581 }
582
parseModuleBody(SmallVectorImpl<ast::Decl * > & decls)583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
584 while (curToken.isNot(Token::eof)) {
585 if (curToken.is(Token::directive)) {
586 if (failed(parseDirective(decls)))
587 return failure();
588 continue;
589 }
590
591 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
592 if (failed(decl))
593 return failure();
594 decls.push_back(*decl);
595 }
596 return success();
597 }
598
/// Wrap the given operation expression in an access to all of its results,
/// typed as a ValueRange and located at the operation expression.
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
  SMRange opLoc = opExpr->getLoc();
  return ast::AllResultsMemberAccessExpr::create(ctx, opLoc, opExpr,
                                                 valueRangeTy);
}
603
/// Try to convert `expr` to `type`, rewriting `expr` in place when the
/// conversion needs a new expression. The handled cases are:
///   * identical types: nothing to do;
///   * Operation -> Operation: only when the expected type has no name;
///   * Operation -> ValueRange / Value: via an all-results access;
///   * Value <-> ValueRange and Type <-> TypeRange: accepted as-is;
///   * Tuple -> Tuple of the same size: converted element by element.
/// Anything else emits an "unable to convert" error at the location of
/// `expr`. If `noteAttachFn` is provided, it is invoked on that error so the
/// caller can attach additional notes.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type exprType = expr->getType();
  if (exprType == type)
    return success();

  // Helper that emits the conversion error and gives the caller-provided
  // callback a chance to attach notes to it.
  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      exprType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
    // Two operation types are compatible if they have the same name, or if the
    // expected type is more general.
    if (auto opType = type.dyn_cast<ast::OperationType>()) {
      // The identical-type case was handled above, so a named expected type
      // here is a mismatch.
      if (opType.getName())
        return emitConvertError();
      return success();
    }

    // An operation can always convert to a ValueRange.
    if (type == valueRangeTy) {
      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueRangeTy);
      return success();
    }

    // Allow conversion to a single value by constraining the result range.
    if (type == valueTy) {
      // If the operation is registered, we can verify if it can ever have a
      // single result.
      if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
        if (odsOp->getResults().empty()) {
          return emitConvertError()->attachNote(
              llvm::formatv("see the definition of `{0}`, which was defined "
                            "with zero results",
                            odsOp->getName()),
              odsOp->getLoc());
        }

        // Count the results whose variable length kind is `Single`; more than
        // one of them rules out a single-value conversion.
        unsigned numSingleResults = llvm::count_if(
            odsOp->getResults(), [](const ods::OperandOrResult &result) {
              return result.getVariableLengthKind() ==
                     ods::VariableLengthKind::Single;
            });
        if (numSingleResults > 1) {
          return emitConvertError()->attachNote(
              llvm::formatv("see the definition of `{0}`, which was defined "
                            "with at least {1} results",
                            odsOp->getName(), numSingleResults),
              odsOp->getLoc());
        }
      }

      expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                                     valueTy);
      return success();
    }
    return emitConvertError();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  if ((exprType == valueTy || exprType == valueRangeTy) &&
      (type == valueTy || type == valueRangeTy))
    return success();
  if ((exprType == typeTy || exprType == typeRangeTy) &&
      (type == typeTy || type == typeRangeTy))
    return success();

  // Handle tuple types.
  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
    // The expected type must also be a tuple, with the same number of
    // elements.
    auto tupleType = type.dyn_cast<ast::TupleType>();
    if (!tupleType || tupleType.size() != exprTupleType.size())
      return emitConvertError();

    // Build a new tuple expression using each of the elements of the current
    // tuple.
    SmallVector<ast::Expr *> newExprs;
    for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
      // Access element `i` of the original tuple by its index.
      newExprs.push_back(ast::MemberAccessExpr::create(
          ctx, expr->getLoc(), expr, llvm::to_string(i),
          exprTupleType.getElementTypes()[i]));

      // On failure, note which element was being converted before forwarding
      // to the caller-provided callback.
      auto diagFn = [&](ast::Diagnostic &diag) {
        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
                                      i, exprTupleType));
        if (noteAttachFn)
          noteAttachFn(diag);
      };
      if (failed(convertExpressionTo(newExprs.back(),
                                     tupleType.getElementTypes()[i], diagFn)))
        return failure();
    }
    // The rebuilt tuple takes its element names from the expected tuple type.
    expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
                                  tupleType.getElementNames());
    return success();
  }

  return emitConvertError();
}
714
715 //===----------------------------------------------------------------------===//
716 // Directives
717
parseDirective(SmallVectorImpl<ast::Decl * > & decls)718 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
719 StringRef directive = curToken.getSpelling();
720 if (directive == "#include")
721 return parseInclude(decls);
722
723 return emitError("unknown directive `" + directive + "`");
724 }
725
parseInclude(SmallVectorImpl<ast::Decl * > & decls)726 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
727 SMRange loc = curToken.getLoc();
728 consumeToken(Token::directive);
729
730 // Handle code completion of the include file path.
731 if (curToken.is(Token::code_complete_string))
732 return codeCompleteIncludeFilename(curToken.getStringValue());
733
734 // Parse the file being included.
735 if (!curToken.isString())
736 return emitError(loc,
737 "expected string file name after `include` directive");
738 SMRange fileLoc = curToken.getLoc();
739 std::string filenameStr = curToken.getStringValue();
740 StringRef filename = filenameStr;
741 consumeToken();
742
743 // Check the type of include. If ending with `.pdll`, this is another pdl file
744 // to be parsed along with the current module.
745 if (filename.endswith(".pdll")) {
746 if (failed(lexer.pushInclude(filename, fileLoc)))
747 return emitError(fileLoc,
748 "unable to open include file `" + filename + "`");
749
750 // If we added the include successfully, parse it into the current module.
751 // Make sure to update to the next token after we finish parsing the nested
752 // file.
753 curToken = lexer.lexToken();
754 LogicalResult result = parseModuleBody(decls);
755 curToken = lexer.lexToken();
756 return result;
757 }
758
759 // Otherwise, this must be a `.td` include.
760 if (filename.endswith(".td"))
761 return parseTdInclude(filename, fileLoc, decls);
762
763 return emitError(fileLoc,
764 "expected include filename to end with `.pdll` or `.td`");
765 }
766
parseTdInclude(StringRef filename,llvm::SMRange fileLoc,SmallVectorImpl<ast::Decl * > & decls)767 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
768 SmallVectorImpl<ast::Decl *> &decls) {
769 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
770
771 // Use the source manager to open the file, but don't yet add it.
772 std::string includedFile;
773 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
774 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
775 if (!includeBuffer)
776 return emitError(fileLoc, "unable to open include file `" + filename + "`");
777
778 // Setup the source manager for parsing the tablegen file.
779 llvm::SourceMgr tdSrcMgr;
780 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
781 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
782
783 // This class provides a context argument for the llvm::SourceMgr diagnostic
784 // handler.
785 struct DiagHandlerContext {
786 Parser &parser;
787 StringRef filename;
788 llvm::SMRange loc;
789 } handlerContext{*this, filename, fileLoc};
790
791 // Set the diagnostic handler for the tablegen source manager.
792 tdSrcMgr.setDiagHandler(
793 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
794 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
795 (void)ctx->parser.emitError(
796 ctx->loc,
797 llvm::formatv("error while processing include file `{0}`: {1}",
798 ctx->filename, diag.getMessage()));
799 },
800 &handlerContext);
801
802 // Parse the tablegen file.
803 llvm::RecordKeeper tdRecords;
804 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
805 return failure();
806
807 // Process the parsed records.
808 processTdIncludeRecords(tdRecords, decls);
809
810 // After we are done processing, move all of the tablegen source buffers to
811 // the main parser source mgr. This allows for directly using source locations
812 // from the .td files without needing to remap them.
813 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
814 return success();
815 }
816
processTdIncludeRecords(llvm::RecordKeeper & tdRecords,SmallVectorImpl<ast::Decl * > & decls)817 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
818 SmallVectorImpl<ast::Decl *> &decls) {
819 // Return the length kind of the given value.
820 auto getLengthKind = [](const auto &value) {
821 if (value.isOptional())
822 return ods::VariableLengthKind::Optional;
823 return value.isVariadic() ? ods::VariableLengthKind::Variadic
824 : ods::VariableLengthKind::Single;
825 };
826
827 // Insert a type constraint into the ODS context.
828 ods::Context &odsContext = ctx.getODSContext();
829 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
830 -> const ods::TypeConstraint & {
831 return odsContext.insertTypeConstraint(
832 cst.constraint.getUniqueDefName(),
833 processDoc(cst.constraint.getSummary()),
834 cst.constraint.getCPPClassName());
835 };
836 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
837 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
838 };
839
840 // Process the parsed tablegen records to build ODS information.
841 /// Operations.
842 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
843 tblgen::Operator op(def);
844
845 // Check to see if this operation is known to support type inferrence.
846 bool supportsResultTypeInferrence =
847 op.getTrait("::mlir::InferTypeOpInterface::Trait");
848
849 bool inserted = false;
850 ods::Operation *odsOp = nullptr;
851 std::tie(odsOp, inserted) = odsContext.insertOperation(
852 op.getOperationName(), processDoc(op.getSummary()),
853 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
854 supportsResultTypeInferrence, op.getLoc().front());
855
856 // Ignore operations that have already been added.
857 if (!inserted)
858 continue;
859
860 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
861 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
862 odsContext.insertAttributeConstraint(
863 attr.attr.getUniqueDefName(),
864 processDoc(attr.attr.getSummary()),
865 attr.attr.getStorageType()));
866 }
867 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
868 odsOp->appendOperand(operand.name, getLengthKind(operand),
869 addTypeConstraint(operand));
870 }
871 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
872 odsOp->appendResult(result.name, getLengthKind(result),
873 addTypeConstraint(result));
874 }
875 }
876
877 auto shouldBeSkipped = [this](llvm::Record *def) {
878 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
879 def->isSubClassOf("DeclareInterfaceMethods");
880 };
881
882 /// Attr constraints.
883 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
884 if (shouldBeSkipped(def))
885 continue;
886
887 tblgen::Attribute constraint(def);
888 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
889 constraint, convertLocToRange(def->getLoc().front()), attrTy,
890 constraint.getStorageType()));
891 }
892 /// Type constraints.
893 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
894 if (shouldBeSkipped(def))
895 continue;
896
897 tblgen::TypeConstraint constraint(def);
898 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
899 constraint, convertLocToRange(def->getLoc().front()), typeTy,
900 constraint.getCPPClassName()));
901 }
902 /// OpInterfaces.
903 ast::Type opTy = ast::OperationType::get(ctx);
904 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
905 if (shouldBeSkipped(def))
906 continue;
907
908 SMRange loc = convertLocToRange(def->getLoc().front());
909
910 std::string cppClassName =
911 llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
912 def->getValueAsString("cppInterfaceName"))
913 .str();
914 std::string codeBlock =
915 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
916 cppClassName)
917 .str();
918
919 std::string desc =
920 processAndFormatDoc(def->getValueAsString("description"));
921 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
922 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
923 }
924 }
925
926 template <typename ConstraintT>
createODSNativePDLLConstraintDecl(StringRef name,StringRef codeBlock,SMRange loc,ast::Type type,StringRef nativeType,StringRef docString)927 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
928 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
929 StringRef nativeType, StringRef docString) {
930 // Build the single input parameter.
931 ast::DeclScope *argScope = pushDeclScope();
932 auto *paramVar = ast::VariableDecl::create(
933 ctx, ast::Name::create(ctx, "self", loc), type,
934 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
935 argScope->add(paramVar);
936 popDeclScope();
937
938 // Build the native constraint.
939 auto *constraintDecl = ast::UserConstraintDecl::createNative(
940 ctx, ast::Name::create(ctx, name, loc), paramVar,
941 /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
942 constraintDecl->setDocComment(ctx, docString);
943 curDeclScope->add(constraintDecl);
944 return constraintDecl;
945 }
946
947 template <typename ConstraintT>
948 ast::Decl *
createODSNativePDLLConstraintDecl(const tblgen::Constraint & constraint,SMRange loc,ast::Type type,StringRef nativeType)949 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
950 SMRange loc, ast::Type type,
951 StringRef nativeType) {
952 // Format the condition template.
953 tblgen::FmtContext fmtContext;
954 fmtContext.withSelf("self");
955 std::string codeBlock = tblgen::tgfmt(
956 "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
957 &fmtContext);
958
959 // If documentation was enabled, build the doc string for the generated
960 // constraint. It would be nice to do this lazily, but TableGen information is
961 // destroyed after we finish parsing the file.
962 std::string docString;
963 if (enableDocumentation) {
964 StringRef desc = constraint.getDescription();
965 docString = processAndFormatDoc(
966 constraint.getSummary() +
967 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
968 }
969
970 return createODSNativePDLLConstraintDecl<ConstraintT>(
971 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
972 docString);
973 }
974
975 //===----------------------------------------------------------------------===//
976 // Decls
977
parseTopLevelDecl()978 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
979 FailureOr<ast::Decl *> decl;
980 switch (curToken.getKind()) {
981 case Token::kw_Constraint:
982 decl = parseUserConstraintDecl();
983 break;
984 case Token::kw_Pattern:
985 decl = parsePatternDecl();
986 break;
987 case Token::kw_Rewrite:
988 decl = parseUserRewriteDecl();
989 break;
990 default:
991 return emitError("expected top-level declaration, such as a `Pattern`");
992 }
993 if (failed(decl))
994 return failure();
995
996 // If the decl has a name, add it to the current scope.
997 if (const ast::Name *name = (*decl)->getName()) {
998 if (failed(checkDefineNamedDecl(*name)))
999 return failure();
1000 curDeclScope->add(*decl);
1001 }
1002 return decl;
1003 }
1004
1005 FailureOr<ast::NamedAttributeDecl *>
parseNamedAttributeDecl(Optional<StringRef> parentOpName)1006 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
1007 // Check for name code completion.
1008 if (curToken.is(Token::code_complete))
1009 return codeCompleteAttributeName(parentOpName);
1010
1011 std::string attrNameStr;
1012 if (curToken.isString())
1013 attrNameStr = curToken.getStringValue();
1014 else if (curToken.is(Token::identifier) || curToken.isKeyword())
1015 attrNameStr = curToken.getSpelling().str();
1016 else
1017 return emitError("expected identifier or string attribute name");
1018 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1019 consumeToken();
1020
1021 // Check for a value of the attribute.
1022 ast::Expr *attrValue = nullptr;
1023 if (consumeIf(Token::equal)) {
1024 FailureOr<ast::Expr *> attrExpr = parseExpr();
1025 if (failed(attrExpr))
1026 return failure();
1027 attrValue = *attrExpr;
1028 } else {
1029 // If there isn't a concrete value, create an expression representing a
1030 // UnitAttr.
1031 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1032 }
1033
1034 return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1035 }
1036
parseLambdaBody(function_ref<LogicalResult (ast::Stmt * &)> processStatementFn,bool expectTerminalSemicolon)1037 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1038 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1039 bool expectTerminalSemicolon) {
1040 consumeToken(Token::equal_arrow);
1041
1042 // Parse the single statement of the lambda body.
1043 SMLoc bodyStartLoc = curToken.getStartLoc();
1044 pushDeclScope();
1045 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1046 bool failedToParse =
1047 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1048 popDeclScope();
1049 if (failedToParse)
1050 return failure();
1051
1052 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1053 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1054 }
1055
parseArgumentDecl()1056 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1057 // Ensure that the argument is named.
1058 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1059 return emitError("expected identifier argument name");
1060
1061 // Parse the argument similarly to a normal variable.
1062 StringRef name = curToken.getSpelling();
1063 SMRange nameLoc = curToken.getLoc();
1064 consumeToken();
1065
1066 if (failed(
1067 parseToken(Token::colon, "expected `:` before argument constraint")))
1068 return failure();
1069
1070 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1071 if (failed(cst))
1072 return failure();
1073
1074 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1075 }
1076
parseResultDecl(unsigned resultNum)1077 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1078 // Check to see if this result is named.
1079 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1080 // Check to see if this name actually refers to a Constraint.
1081 ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1082 if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1083 // If yes, and this is a Rewrite, give a nice error message as non-Core
1084 // constraints are not supported on Rewrite results.
1085 if (parserContext == ParserContext::Rewrite) {
1086 return emitError(
1087 "`Rewrite` results are only permitted to use core constraints, "
1088 "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1089 }
1090
1091 // Otherwise, parse this as an unnamed result variable.
1092 } else {
1093 // If it wasn't a constraint, parse the result similarly to a variable. If
1094 // there is already an existing decl, we will emit an error when defining
1095 // this variable later.
1096 StringRef name = curToken.getSpelling();
1097 SMRange nameLoc = curToken.getLoc();
1098 consumeToken();
1099
1100 if (failed(parseToken(Token::colon,
1101 "expected `:` before result constraint")))
1102 return failure();
1103
1104 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1105 if (failed(cst))
1106 return failure();
1107
1108 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1109 }
1110 }
1111
1112 // If it isn't named, we parse the constraint directly and create an unnamed
1113 // result variable.
1114 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1115 if (failed(cst))
1116 return failure();
1117
1118 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1119 }
1120
1121 FailureOr<ast::UserConstraintDecl *>
parseUserConstraintDecl(bool isInline)1122 Parser::parseUserConstraintDecl(bool isInline) {
1123 // Constraints and rewrites have very similar formats, dispatch to a shared
1124 // interface for parsing.
1125 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1126 [&](auto &&...args) {
1127 return this->parseUserPDLLConstraintDecl(args...);
1128 },
1129 ParserContext::Constraint, "constraint", isInline);
1130 }
1131
parseInlineUserConstraintDecl()1132 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1133 FailureOr<ast::UserConstraintDecl *> decl =
1134 parseUserConstraintDecl(/*isInline=*/true);
1135 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1136 return failure();
1137
1138 curDeclScope->add(*decl);
1139 return decl;
1140 }
1141
parseUserPDLLConstraintDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ast::DeclScope * argumentScope,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1142 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1143 const ast::Name &name, bool isInline,
1144 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1145 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1146 // Push the argument scope back onto the list, so that the body can
1147 // reference arguments.
1148 pushDeclScope(argumentScope);
1149
1150 // Parse the body of the constraint. The body is either defined as a compound
1151 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1152 ast::CompoundStmt *body;
1153 if (curToken.is(Token::equal_arrow)) {
1154 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1155 [&](ast::Stmt *&stmt) -> LogicalResult {
1156 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1157 if (!stmtExpr) {
1158 return emitError(stmt->getLoc(),
1159 "expected `Constraint` lambda body to contain a "
1160 "single expression");
1161 }
1162 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1163 return success();
1164 },
1165 /*expectTerminalSemicolon=*/!isInline);
1166 if (failed(bodyResult))
1167 return failure();
1168 body = *bodyResult;
1169 } else {
1170 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1171 if (failed(bodyResult))
1172 return failure();
1173 body = *bodyResult;
1174
1175 // Verify the structure of the body.
1176 auto bodyIt = body->begin(), bodyE = body->end();
1177 for (; bodyIt != bodyE; ++bodyIt)
1178 if (isa<ast::ReturnStmt>(*bodyIt))
1179 break;
1180 if (failed(validateUserConstraintOrRewriteReturn(
1181 "Constraint", body, bodyIt, bodyE, results, resultType)))
1182 return failure();
1183 }
1184 popDeclScope();
1185
1186 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1187 name, arguments, results, resultType, body);
1188 }
1189
parseUserRewriteDecl(bool isInline)1190 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1191 // Constraints and rewrites have very similar formats, dispatch to a shared
1192 // interface for parsing.
1193 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1194 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1195 ParserContext::Rewrite, "rewrite", isInline);
1196 }
1197
parseInlineUserRewriteDecl()1198 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1199 FailureOr<ast::UserRewriteDecl *> decl =
1200 parseUserRewriteDecl(/*isInline=*/true);
1201 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1202 return failure();
1203
1204 curDeclScope->add(*decl);
1205 return decl;
1206 }
1207
parseUserPDLLRewriteDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ast::DeclScope * argumentScope,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1208 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1209 const ast::Name &name, bool isInline,
1210 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1211 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1212 // Push the argument scope back onto the list, so that the body can
1213 // reference arguments.
1214 curDeclScope = argumentScope;
1215 ast::CompoundStmt *body;
1216 if (curToken.is(Token::equal_arrow)) {
1217 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1218 [&](ast::Stmt *&statement) -> LogicalResult {
1219 if (isa<ast::OpRewriteStmt>(statement))
1220 return success();
1221
1222 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1223 if (!statementExpr) {
1224 return emitError(
1225 statement->getLoc(),
1226 "expected `Rewrite` lambda body to contain a single expression "
1227 "or an operation rewrite statement; such as `erase`, "
1228 "`replace`, or `rewrite`");
1229 }
1230 statement =
1231 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1232 return success();
1233 },
1234 /*expectTerminalSemicolon=*/!isInline);
1235 if (failed(bodyResult))
1236 return failure();
1237 body = *bodyResult;
1238 } else {
1239 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1240 if (failed(bodyResult))
1241 return failure();
1242 body = *bodyResult;
1243 }
1244 popDeclScope();
1245
1246 // Verify the structure of the body.
1247 auto bodyIt = body->begin(), bodyE = body->end();
1248 for (; bodyIt != bodyE; ++bodyIt)
1249 if (isa<ast::ReturnStmt>(*bodyIt))
1250 break;
1251 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1252 bodyE, results, resultType)))
1253 return failure();
1254 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1255 name, arguments, results, resultType, body);
1256 }
1257
1258 template <typename T, typename ParseUserPDLLDeclFnT>
parseUserConstraintOrRewriteDecl(ParseUserPDLLDeclFnT && parseUserPDLLFn,ParserContext declContext,StringRef anonymousNamePrefix,bool isInline)1259 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1260 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1261 StringRef anonymousNamePrefix, bool isInline) {
1262 SMRange loc = curToken.getLoc();
1263 consumeToken();
1264 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1265
1266 // Parse the name of the decl.
1267 const ast::Name *name = nullptr;
1268 if (curToken.isNot(Token::identifier)) {
1269 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1270 // in C++, so being unnamed is fine.
1271 if (!isInline)
1272 return emitError("expected identifier name");
1273
1274 // Create a unique anonymous name to use, as the name for this decl is not
1275 // important.
1276 std::string anonName =
1277 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1278 anonymousDeclNameCounter++)
1279 .str();
1280 name = &ast::Name::create(ctx, anonName, loc);
1281 } else {
1282 // If a name was provided, we can use it directly.
1283 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1284 consumeToken(Token::identifier);
1285 }
1286
1287 // Parse the functional signature of the decl.
1288 SmallVector<ast::VariableDecl *> arguments, results;
1289 ast::DeclScope *argumentScope;
1290 ast::Type resultType;
1291 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1292 argumentScope, resultType)))
1293 return failure();
1294
1295 // Check to see which type of constraint this is. If the constraint contains a
1296 // compound body, this is a PDLL decl.
1297 if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1298 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1299 resultType);
1300
1301 // Otherwise, this is a native decl.
1302 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1303 results, resultType);
1304 }
1305
1306 template <typename T>
parseUserNativeConstraintOrRewriteDecl(const ast::Name & name,bool isInline,ArrayRef<ast::VariableDecl * > arguments,ArrayRef<ast::VariableDecl * > results,ast::Type resultType)1307 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1308 const ast::Name &name, bool isInline,
1309 ArrayRef<ast::VariableDecl *> arguments,
1310 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1311 // If followed by a string, the native code body has also been specified.
1312 std::string codeStrStorage;
1313 Optional<StringRef> optCodeStr;
1314 if (curToken.isString()) {
1315 codeStrStorage = curToken.getStringValue();
1316 optCodeStr = codeStrStorage;
1317 consumeToken();
1318 } else if (isInline) {
1319 return emitError(name.getLoc(),
1320 "external declarations must be declared in global scope");
1321 } else if (curToken.is(Token::error)) {
1322 return failure();
1323 }
1324 if (failed(parseToken(Token::semicolon,
1325 "expected `;` after native declaration")))
1326 return failure();
1327 // TODO: PDL should be able to support constraint results in certain
1328 // situations, we should revise this.
1329 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1330 return emitError(
1331 "native Constraints currently do not support returning results");
1332 }
1333 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1334 }
1335
parseUserConstraintOrRewriteSignature(SmallVectorImpl<ast::VariableDecl * > & arguments,SmallVectorImpl<ast::VariableDecl * > & results,ast::DeclScope * & argumentScope,ast::Type & resultType)1336 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1337 SmallVectorImpl<ast::VariableDecl *> &arguments,
1338 SmallVectorImpl<ast::VariableDecl *> &results,
1339 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1340 // Parse the argument list of the decl.
1341 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1342 return failure();
1343
1344 argumentScope = pushDeclScope();
1345 if (curToken.isNot(Token::r_paren)) {
1346 do {
1347 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1348 if (failed(argument))
1349 return failure();
1350 arguments.emplace_back(*argument);
1351 } while (consumeIf(Token::comma));
1352 }
1353 popDeclScope();
1354 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1355 return failure();
1356
1357 // Parse the results of the decl.
1358 pushDeclScope();
1359 if (consumeIf(Token::arrow)) {
1360 auto parseResultFn = [&]() -> LogicalResult {
1361 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1362 if (failed(result))
1363 return failure();
1364 results.emplace_back(*result);
1365 return success();
1366 };
1367
1368 // Check for a list of results.
1369 if (consumeIf(Token::l_paren)) {
1370 do {
1371 if (failed(parseResultFn()))
1372 return failure();
1373 } while (consumeIf(Token::comma));
1374 if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1375 return failure();
1376
1377 // Otherwise, there is only one result.
1378 } else if (failed(parseResultFn())) {
1379 return failure();
1380 }
1381 }
1382 popDeclScope();
1383
1384 // Compute the result type of the decl.
1385 resultType = createUserConstraintRewriteResultType(results);
1386
1387 // Verify that results are only named if there are more than one.
1388 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1389 return emitError(
1390 results.front()->getLoc(),
1391 "cannot create a single-element tuple with an element label");
1392 }
1393 return success();
1394 }
1395
validateUserConstraintOrRewriteReturn(StringRef declType,ast::CompoundStmt * body,ArrayRef<ast::Stmt * >::iterator bodyIt,ArrayRef<ast::Stmt * >::iterator bodyE,ArrayRef<ast::VariableDecl * > results,ast::Type & resultType)1396 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1397 StringRef declType, ast::CompoundStmt *body,
1398 ArrayRef<ast::Stmt *>::iterator bodyIt,
1399 ArrayRef<ast::Stmt *>::iterator bodyE,
1400 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1401 // Handle if a `return` was provided.
1402 if (bodyIt != bodyE) {
1403 // Emit an error if we have trailing statements after the return.
1404 if (std::next(bodyIt) != bodyE) {
1405 return emitError(
1406 (*std::next(bodyIt))->getLoc(),
1407 llvm::formatv("`return` terminated the `{0}` body, but found "
1408 "trailing statements afterwards",
1409 declType));
1410 }
1411
1412 // Otherwise if a return wasn't provided, check that no results are
1413 // expected.
1414 } else if (!results.empty()) {
1415 return emitError(
1416 {body->getLoc().End, body->getLoc().End},
1417 llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1418 declType, resultType));
1419 }
1420 return success();
1421 }
1422
parsePatternLambdaBody()1423 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1424 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1425 if (isa<ast::OpRewriteStmt>(statement))
1426 return success();
1427 return emitError(
1428 statement->getLoc(),
1429 "expected Pattern lambda body to contain a single operation "
1430 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1431 });
1432 }
1433
parsePatternDecl()1434 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1435 SMRange loc = curToken.getLoc();
1436 consumeToken(Token::kw_Pattern);
1437 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1438 ParserContext::PatternMatch);
1439
1440 // Check for an optional identifier for the pattern name.
1441 const ast::Name *name = nullptr;
1442 if (curToken.is(Token::identifier)) {
1443 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1444 consumeToken(Token::identifier);
1445 }
1446
1447 // Parse any pattern metadata.
1448 ParsedPatternMetadata metadata;
1449 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1450 return failure();
1451
1452 // Parse the pattern body.
1453 ast::CompoundStmt *body;
1454
1455 // Handle a lambda body.
1456 if (curToken.is(Token::equal_arrow)) {
1457 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1458 if (failed(bodyResult))
1459 return failure();
1460 body = *bodyResult;
1461 } else {
1462 if (curToken.isNot(Token::l_brace))
1463 return emitError("expected `{` or `=>` to start pattern body");
1464 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1465 if (failed(bodyResult))
1466 return failure();
1467 body = *bodyResult;
1468
1469 // Verify the body of the pattern.
1470 auto bodyIt = body->begin(), bodyE = body->end();
1471 for (; bodyIt != bodyE; ++bodyIt) {
1472 if (isa<ast::ReturnStmt>(*bodyIt)) {
1473 return emitError((*bodyIt)->getLoc(),
1474 "`return` statements are only permitted within a "
1475 "`Constraint` or `Rewrite` body");
1476 }
1477 // Break when we've found the rewrite statement.
1478 if (isa<ast::OpRewriteStmt>(*bodyIt))
1479 break;
1480 }
1481 if (bodyIt == bodyE) {
1482 return emitError(loc,
1483 "expected Pattern body to terminate with an operation "
1484 "rewrite statement, such as `erase`");
1485 }
1486 if (std::next(bodyIt) != bodyE) {
1487 return emitError((*std::next(bodyIt))->getLoc(),
1488 "Pattern body was terminated by an operation "
1489 "rewrite statement, but found trailing statements");
1490 }
1491 }
1492
1493 return createPatternDecl(loc, name, metadata, body);
1494 }
1495
1496 LogicalResult
parsePatternDeclMetadata(ParsedPatternMetadata & metadata)1497 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1498 Optional<SMRange> benefitLoc;
1499 Optional<SMRange> hasBoundedRecursionLoc;
1500
1501 do {
1502 // Handle metadata code completion.
1503 if (curToken.is(Token::code_complete))
1504 return codeCompletePatternMetadata();
1505
1506 if (curToken.isNot(Token::identifier))
1507 return emitError("expected pattern metadata identifier");
1508 StringRef metadataStr = curToken.getSpelling();
1509 SMRange metadataLoc = curToken.getLoc();
1510 consumeToken(Token::identifier);
1511
1512 // Parse the benefit metadata: benefit(<integer-value>)
1513 if (metadataStr == "benefit") {
1514 if (benefitLoc) {
1515 return emitErrorAndNote(metadataLoc,
1516 "pattern benefit has already been specified",
1517 *benefitLoc, "see previous definition here");
1518 }
1519 if (failed(parseToken(Token::l_paren,
1520 "expected `(` before pattern benefit")))
1521 return failure();
1522
1523 uint16_t benefitValue = 0;
1524 if (curToken.isNot(Token::integer))
1525 return emitError("expected integral pattern benefit");
1526 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1527 return emitError(
1528 "expected pattern benefit to fit within a 16-bit integer");
1529 consumeToken(Token::integer);
1530
1531 metadata.benefit = benefitValue;
1532 benefitLoc = metadataLoc;
1533
1534 if (failed(
1535 parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1536 return failure();
1537 continue;
1538 }
1539
1540 // Parse the bounded recursion metadata: recursion
1541 if (metadataStr == "recursion") {
1542 if (hasBoundedRecursionLoc) {
1543 return emitErrorAndNote(
1544 metadataLoc,
1545 "pattern recursion metadata has already been specified",
1546 *hasBoundedRecursionLoc, "see previous definition here");
1547 }
1548 metadata.hasBoundedRecursion = true;
1549 hasBoundedRecursionLoc = metadataLoc;
1550 continue;
1551 }
1552
1553 return emitError(metadataLoc, "unknown pattern metadata");
1554 } while (consumeIf(Token::comma));
1555
1556 return success();
1557 }
1558
parseTypeConstraintExpr()1559 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1560 consumeToken(Token::less);
1561
1562 FailureOr<ast::Expr *> typeExpr = parseExpr();
1563 if (failed(typeExpr) ||
1564 failed(parseToken(Token::greater,
1565 "expected `>` after variable type constraint")))
1566 return failure();
1567 return typeExpr;
1568 }
1569
checkDefineNamedDecl(const ast::Name & name)1570 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1571 assert(curDeclScope && "defining decl outside of a decl scope");
1572 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1573 return emitErrorAndNote(
1574 name.getLoc(), "`" + name.getName() + "` has already been defined",
1575 lastDecl->getName()->getLoc(), "see previous definition here");
1576 }
1577 return success();
1578 }
1579
1580 FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name,SMRange nameLoc,ast::Type type,ast::Expr * initExpr,ArrayRef<ast::ConstraintRef> constraints)1581 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1582 ast::Expr *initExpr,
1583 ArrayRef<ast::ConstraintRef> constraints) {
1584 assert(curDeclScope && "defining variable outside of decl scope");
1585 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1586
1587 // If the name of the variable indicates a special variable, we don't add it
1588 // to the scope. This variable is local to the definition point.
1589 if (name.empty() || name == "_") {
1590 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1591 constraints);
1592 }
1593 if (failed(checkDefineNamedDecl(nameDecl)))
1594 return failure();
1595
1596 auto *varDecl =
1597 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1598 curDeclScope->add(varDecl);
1599 return varDecl;
1600 }
1601
1602 FailureOr<ast::VariableDecl *>
defineVariableDecl(StringRef name,SMRange nameLoc,ast::Type type,ArrayRef<ast::ConstraintRef> constraints)1603 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1604 ArrayRef<ast::ConstraintRef> constraints) {
1605 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1606 constraints);
1607 }
1608
parseVariableDeclConstraintList(SmallVectorImpl<ast::ConstraintRef> & constraints)1609 LogicalResult Parser::parseVariableDeclConstraintList(
1610 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1611 Optional<SMRange> typeConstraint;
1612 auto parseSingleConstraint = [&] {
1613 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1614 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1615 /*allowNonCoreConstraints=*/true);
1616 if (failed(constraint))
1617 return failure();
1618 constraints.push_back(*constraint);
1619 return success();
1620 };
1621
1622 // Check to see if this is a single constraint, or a list.
1623 if (!consumeIf(Token::l_square))
1624 return parseSingleConstraint();
1625
1626 do {
1627 if (failed(parseSingleConstraint()))
1628 return failure();
1629 } while (consumeIf(Token::comma));
1630 return parseToken(Token::r_square, "expected `]` after constraint list");
1631 }
1632
1633 FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> & typeConstraint,ArrayRef<ast::ConstraintRef> existingConstraints,bool allowInlineTypeConstraints,bool allowNonCoreConstraints)1634 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1635 ArrayRef<ast::ConstraintRef> existingConstraints,
1636 bool allowInlineTypeConstraints,
1637 bool allowNonCoreConstraints) {
1638 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1639 if (!allowInlineTypeConstraints) {
1640 return emitError(
1641 curToken.getLoc(),
1642 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1643 "permitted on arguments or results");
1644 }
1645 if (typeConstraint)
1646 return emitErrorAndNote(
1647 curToken.getLoc(),
1648 "the type of this variable has already been constrained",
1649 *typeConstraint, "see previous constraint location here");
1650 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1651 if (failed(constraintExpr))
1652 return failure();
1653 typeExpr = *constraintExpr;
1654 typeConstraint = typeExpr->getLoc();
1655 return success();
1656 };
1657
1658 SMRange loc = curToken.getLoc();
1659 switch (curToken.getKind()) {
1660 case Token::kw_Attr: {
1661 consumeToken(Token::kw_Attr);
1662
1663 // Check for a type constraint.
1664 ast::Expr *typeExpr = nullptr;
1665 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1666 return failure();
1667 return ast::ConstraintRef(
1668 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1669 }
1670 case Token::kw_Op: {
1671 consumeToken(Token::kw_Op);
1672
1673 // Parse an optional operation name. If the name isn't provided, this refers
1674 // to "any" operation.
1675 FailureOr<ast::OpNameDecl *> opName =
1676 parseWrappedOperationName(/*allowEmptyName=*/true);
1677 if (failed(opName))
1678 return failure();
1679
1680 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1681 loc);
1682 }
1683 case Token::kw_Type:
1684 consumeToken(Token::kw_Type);
1685 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1686 case Token::kw_TypeRange:
1687 consumeToken(Token::kw_TypeRange);
1688 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1689 loc);
1690 case Token::kw_Value: {
1691 consumeToken(Token::kw_Value);
1692
1693 // Check for a type constraint.
1694 ast::Expr *typeExpr = nullptr;
1695 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1696 return failure();
1697
1698 return ast::ConstraintRef(
1699 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1700 }
1701 case Token::kw_ValueRange: {
1702 consumeToken(Token::kw_ValueRange);
1703
1704 // Check for a type constraint.
1705 ast::Expr *typeExpr = nullptr;
1706 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1707 return failure();
1708
1709 return ast::ConstraintRef(
1710 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1711 }
1712
1713 case Token::kw_Constraint: {
1714 // Handle an inline constraint.
1715 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1716 if (failed(decl))
1717 return failure();
1718 return ast::ConstraintRef(*decl, loc);
1719 }
1720 case Token::identifier: {
1721 StringRef constraintName = curToken.getSpelling();
1722 consumeToken(Token::identifier);
1723
1724 // Lookup the referenced constraint.
1725 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1726 if (!cstDecl) {
1727 return emitError(loc, "unknown reference to constraint `" +
1728 constraintName + "`");
1729 }
1730
1731 // Handle a reference to a proper constraint.
1732 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1733 return ast::ConstraintRef(cst, loc);
1734
1735 return emitErrorAndNote(
1736 loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1737 "see the definition of `" + constraintName + "` here");
1738 }
1739 // Handle single entity constraint code completion.
1740 case Token::code_complete: {
1741 // Try to infer the current type for use by code completion.
1742 ast::Type inferredType;
1743 if (failed(validateVariableConstraints(existingConstraints, inferredType,
1744 allowNonCoreConstraints)))
1745 return failure();
1746
1747 return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
1748 allowInlineTypeConstraints);
1749 }
1750 default:
1751 break;
1752 }
1753 return emitError(loc, "expected identifier constraint");
1754 }
1755
parseArgOrResultConstraint()1756 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1757 // Constraint arguments may apply more complex constraints via the arguments.
1758 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1759
1760 Optional<SMRange> typeConstraint;
1761 return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1762 /*allowInlineTypeConstraints=*/false,
1763 allowNonCoreConstraints);
1764 }
1765
1766 //===----------------------------------------------------------------------===//
1767 // Exprs
1768
parseExpr()1769 FailureOr<ast::Expr *> Parser::parseExpr() {
1770 if (curToken.is(Token::underscore))
1771 return parseUnderscoreExpr();
1772
1773 // Parse the LHS expression.
1774 FailureOr<ast::Expr *> lhsExpr;
1775 switch (curToken.getKind()) {
1776 case Token::kw_attr:
1777 lhsExpr = parseAttributeExpr();
1778 break;
1779 case Token::kw_Constraint:
1780 lhsExpr = parseInlineConstraintLambdaExpr();
1781 break;
1782 case Token::identifier:
1783 lhsExpr = parseIdentifierExpr();
1784 break;
1785 case Token::kw_op:
1786 lhsExpr = parseOperationExpr();
1787 break;
1788 case Token::kw_Rewrite:
1789 lhsExpr = parseInlineRewriteLambdaExpr();
1790 break;
1791 case Token::kw_type:
1792 lhsExpr = parseTypeExpr();
1793 break;
1794 case Token::l_paren:
1795 lhsExpr = parseTupleExpr();
1796 break;
1797 default:
1798 return emitError("expected expression");
1799 }
1800 if (failed(lhsExpr))
1801 return failure();
1802
1803 // Check for an operator expression.
1804 while (true) {
1805 switch (curToken.getKind()) {
1806 case Token::dot:
1807 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1808 break;
1809 case Token::l_paren:
1810 lhsExpr = parseCallExpr(*lhsExpr);
1811 break;
1812 default:
1813 return lhsExpr;
1814 }
1815 if (failed(lhsExpr))
1816 return failure();
1817 }
1818 }
1819
parseAttributeExpr()1820 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1821 SMRange loc = curToken.getLoc();
1822 consumeToken(Token::kw_attr);
1823
1824 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1825 // identifier.
1826 if (!consumeIf(Token::less)) {
1827 resetToken(loc);
1828 return parseIdentifierExpr();
1829 }
1830
1831 if (!curToken.isString())
1832 return emitError("expected string literal containing MLIR attribute");
1833 std::string attrExpr = curToken.getStringValue();
1834 consumeToken();
1835
1836 loc.End = curToken.getEndLoc();
1837 if (failed(
1838 parseToken(Token::greater, "expected `>` after attribute literal")))
1839 return failure();
1840 return ast::AttributeExpr::create(ctx, loc, attrExpr);
1841 }
1842
parseCallExpr(ast::Expr * parentExpr)1843 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1844 consumeToken(Token::l_paren);
1845
1846 // Parse the arguments of the call.
1847 SmallVector<ast::Expr *> arguments;
1848 if (curToken.isNot(Token::r_paren)) {
1849 do {
1850 // Handle code completion for the call arguments.
1851 if (curToken.is(Token::code_complete)) {
1852 codeCompleteCallSignature(parentExpr, arguments.size());
1853 return failure();
1854 }
1855
1856 FailureOr<ast::Expr *> argument = parseExpr();
1857 if (failed(argument))
1858 return failure();
1859 arguments.push_back(*argument);
1860 } while (consumeIf(Token::comma));
1861 }
1862
1863 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1864 if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1865 return failure();
1866
1867 return createCallExpr(loc, parentExpr, arguments);
1868 }
1869
parseDeclRefExpr(StringRef name,SMRange loc)1870 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1871 ast::Decl *decl = curDeclScope->lookup(name);
1872 if (!decl)
1873 return emitError(loc, "undefined reference to `" + name + "`");
1874
1875 return createDeclRefExpr(loc, decl);
1876 }
1877
parseIdentifierExpr()1878 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1879 StringRef name = curToken.getSpelling();
1880 SMRange nameLoc = curToken.getLoc();
1881 consumeToken();
1882
1883 // Check to see if this is a decl ref expression that defines a variable
1884 // inline.
1885 if (consumeIf(Token::colon)) {
1886 SmallVector<ast::ConstraintRef> constraints;
1887 if (failed(parseVariableDeclConstraintList(constraints)))
1888 return failure();
1889 ast::Type type;
1890 if (failed(validateVariableConstraints(constraints, type)))
1891 return failure();
1892 return createInlineVariableExpr(type, name, nameLoc, constraints);
1893 }
1894
1895 return parseDeclRefExpr(name, nameLoc);
1896 }
1897
parseInlineConstraintLambdaExpr()1898 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1899 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1900 if (failed(decl))
1901 return failure();
1902
1903 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1904 ast::ConstraintType::get(ctx));
1905 }
1906
parseInlineRewriteLambdaExpr()1907 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1908 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1909 if (failed(decl))
1910 return failure();
1911
1912 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1913 ast::RewriteType::get(ctx));
1914 }
1915
parseMemberAccessExpr(ast::Expr * parentExpr)1916 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1917 SMRange dotLoc = curToken.getLoc();
1918 consumeToken(Token::dot);
1919
1920 // Check for code completion of the member name.
1921 if (curToken.is(Token::code_complete))
1922 return codeCompleteMemberAccess(parentExpr);
1923
1924 // Parse the member name.
1925 Token memberNameTok = curToken;
1926 if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1927 !memberNameTok.isKeyword())
1928 return emitError(dotLoc, "expected identifier or numeric member name");
1929 StringRef memberName = memberNameTok.getSpelling();
1930 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1931 consumeToken();
1932
1933 return createMemberAccessExpr(parentExpr, memberName, loc);
1934 }
1935
parseOperationName(bool allowEmptyName)1936 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1937 SMRange loc = curToken.getLoc();
1938
1939 // Check for code completion for the dialect name.
1940 if (curToken.is(Token::code_complete))
1941 return codeCompleteDialectName();
1942
1943 // Handle the case of an no operation name.
1944 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1945 if (allowEmptyName)
1946 return ast::OpNameDecl::create(ctx, SMRange());
1947 return emitError("expected dialect namespace");
1948 }
1949 StringRef name = curToken.getSpelling();
1950 consumeToken();
1951
1952 // Otherwise, this is a literal operation name.
1953 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1954 return failure();
1955
1956 // Check for code completion for the operation name.
1957 if (curToken.is(Token::code_complete))
1958 return codeCompleteOperationName(name);
1959
1960 if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1961 return emitError("expected operation name after dialect namespace");
1962
1963 name = StringRef(name.data(), name.size() + 1);
1964 do {
1965 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1966 loc.End = curToken.getEndLoc();
1967 consumeToken();
1968 } while (curToken.isAny(Token::identifier, Token::dot) ||
1969 curToken.isKeyword());
1970 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1971 }
1972
1973 FailureOr<ast::OpNameDecl *>
parseWrappedOperationName(bool allowEmptyName)1974 Parser::parseWrappedOperationName(bool allowEmptyName) {
1975 if (!consumeIf(Token::less))
1976 return ast::OpNameDecl::create(ctx, SMRange());
1977
1978 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1979 if (failed(opNameDecl))
1980 return failure();
1981
1982 if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1983 return failure();
1984 return opNameDecl;
1985 }
1986
1987 FailureOr<ast::Expr *>
parseOperationExpr(OpResultTypeContext inputResultTypeContext)1988 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
1989 SMRange loc = curToken.getLoc();
1990 consumeToken(Token::kw_op);
1991
1992 // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1993 // identifier.
1994 if (curToken.isNot(Token::less)) {
1995 resetToken(loc);
1996 return parseIdentifierExpr();
1997 }
1998
1999 // Parse the operation name. The name may be elided, in which case the
2000 // operation refers to "any" operation(i.e. a difference between `MyOp` and
2001 // `Operation*`). Operation names within a rewrite context must be named.
2002 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2003 FailureOr<ast::OpNameDecl *> opNameDecl =
2004 parseWrappedOperationName(allowEmptyName);
2005 if (failed(opNameDecl))
2006 return failure();
2007 Optional<StringRef> opName = (*opNameDecl)->getName();
2008
2009 // Functor used to create an implicit range variable, used for implicit "all"
2010 // operand or results variables.
2011 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2012 FailureOr<ast::VariableDecl *> rangeVar =
2013 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2014 assert(succeeded(rangeVar) && "expected range variable to be valid");
2015 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2016 };
2017
2018 // Check for the optional list of operands.
2019 SmallVector<ast::Expr *> operands;
2020 if (!consumeIf(Token::l_paren)) {
2021 // If the operand list isn't specified and we are in a match context, define
2022 // an inplace unconstrained operand range corresponding to all of the
2023 // operands of the operation. This avoids treating zero operands the same
2024 // way as "unconstrained operands".
2025 if (parserContext != ParserContext::Rewrite) {
2026 operands.push_back(createImplicitRangeVar(
2027 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2028 }
2029 } else if (!consumeIf(Token::r_paren)) {
2030 // If the operand list was specified and non-empty, parse the operands.
2031 do {
2032 // Check for operand signature code completion.
2033 if (curToken.is(Token::code_complete)) {
2034 codeCompleteOperationOperandsSignature(opName, operands.size());
2035 return failure();
2036 }
2037
2038 FailureOr<ast::Expr *> operand = parseExpr();
2039 if (failed(operand))
2040 return failure();
2041 operands.push_back(*operand);
2042 } while (consumeIf(Token::comma));
2043
2044 if (failed(parseToken(Token::r_paren,
2045 "expected `)` after operation operand list")))
2046 return failure();
2047 }
2048
2049 // Check for the optional list of attributes.
2050 SmallVector<ast::NamedAttributeDecl *> attributes;
2051 if (consumeIf(Token::l_brace)) {
2052 do {
2053 FailureOr<ast::NamedAttributeDecl *> decl =
2054 parseNamedAttributeDecl(opName);
2055 if (failed(decl))
2056 return failure();
2057 attributes.emplace_back(*decl);
2058 } while (consumeIf(Token::comma));
2059
2060 if (failed(parseToken(Token::r_brace,
2061 "expected `}` after operation attribute list")))
2062 return failure();
2063 }
2064
2065 // Handle the result types of the operation.
2066 SmallVector<ast::Expr *> resultTypes;
2067 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2068
2069 // Check for an explicit list of result types.
2070 if (consumeIf(Token::arrow)) {
2071 if (failed(parseToken(Token::l_paren,
2072 "expected `(` before operation result type list")))
2073 return failure();
2074
2075 // If result types are provided, initially assume that the operation does
2076 // not rely on type inferrence. We don't assert that it isn't, because we
2077 // may be inferring the value of some type/type range variables, but given
2078 // that these variables may be defined in calls we can't always discern when
2079 // this is the case.
2080 resultTypeContext = OpResultTypeContext::Explicit;
2081
2082 // Handle the case of an empty result list.
2083 if (!consumeIf(Token::r_paren)) {
2084 do {
2085 // Check for result signature code completion.
2086 if (curToken.is(Token::code_complete)) {
2087 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2088 return failure();
2089 }
2090
2091 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2092 if (failed(resultTypeExpr))
2093 return failure();
2094 resultTypes.push_back(*resultTypeExpr);
2095 } while (consumeIf(Token::comma));
2096
2097 if (failed(parseToken(Token::r_paren,
2098 "expected `)` after operation result type list")))
2099 return failure();
2100 }
2101 } else if (parserContext != ParserContext::Rewrite) {
2102 // If the result list isn't specified and we are in a match context, define
2103 // an inplace unconstrained result range corresponding to all of the results
2104 // of the operation. This avoids treating zero results the same way as
2105 // "unconstrained results".
2106 resultTypes.push_back(createImplicitRangeVar(
2107 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2108 } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2109 // If the result list isn't specified and we are in a rewrite, try to infer
2110 // them at runtime instead.
2111 resultTypeContext = OpResultTypeContext::Interface;
2112 }
2113
2114 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2115 attributes, resultTypes);
2116 }
2117
parseTupleExpr()2118 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2119 SMRange loc = curToken.getLoc();
2120 consumeToken(Token::l_paren);
2121
2122 DenseMap<StringRef, SMRange> usedNames;
2123 SmallVector<StringRef> elementNames;
2124 SmallVector<ast::Expr *> elements;
2125 if (curToken.isNot(Token::r_paren)) {
2126 do {
2127 // Check for the optional element name assignment before the value.
2128 StringRef elementName;
2129 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2130 Token elementNameTok = curToken;
2131 consumeToken();
2132
2133 // The element name is only present if followed by an `=`.
2134 if (consumeIf(Token::equal)) {
2135 elementName = elementNameTok.getSpelling();
2136
2137 // Check to see if this name is already used.
2138 auto elementNameIt =
2139 usedNames.try_emplace(elementName, elementNameTok.getLoc());
2140 if (!elementNameIt.second) {
2141 return emitErrorAndNote(
2142 elementNameTok.getLoc(),
2143 llvm::formatv("duplicate tuple element label `{0}`",
2144 elementName),
2145 elementNameIt.first->getSecond(),
2146 "see previous label use here");
2147 }
2148 } else {
2149 // Otherwise, we treat this as part of an expression so reset the
2150 // lexer.
2151 resetToken(elementNameTok.getLoc());
2152 }
2153 }
2154 elementNames.push_back(elementName);
2155
2156 // Parse the tuple element value.
2157 FailureOr<ast::Expr *> element = parseExpr();
2158 if (failed(element))
2159 return failure();
2160 elements.push_back(*element);
2161 } while (consumeIf(Token::comma));
2162 }
2163 loc.End = curToken.getEndLoc();
2164 if (failed(
2165 parseToken(Token::r_paren, "expected `)` after tuple element list")))
2166 return failure();
2167 return createTupleExpr(loc, elements, elementNames);
2168 }
2169
parseTypeExpr()2170 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2171 SMRange loc = curToken.getLoc();
2172 consumeToken(Token::kw_type);
2173
2174 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2175 // identifier.
2176 if (!consumeIf(Token::less)) {
2177 resetToken(loc);
2178 return parseIdentifierExpr();
2179 }
2180
2181 if (!curToken.isString())
2182 return emitError("expected string literal containing MLIR type");
2183 std::string attrExpr = curToken.getStringValue();
2184 consumeToken();
2185
2186 loc.End = curToken.getEndLoc();
2187 if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2188 return failure();
2189 return ast::TypeExpr::create(ctx, loc, attrExpr);
2190 }
2191
parseUnderscoreExpr()2192 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2193 StringRef name = curToken.getSpelling();
2194 SMRange nameLoc = curToken.getLoc();
2195 consumeToken(Token::underscore);
2196
2197 // Underscore expressions require a constraint list.
2198 if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2199 return failure();
2200
2201 // Parse the constraints for the expression.
2202 SmallVector<ast::ConstraintRef> constraints;
2203 if (failed(parseVariableDeclConstraintList(constraints)))
2204 return failure();
2205
2206 ast::Type type;
2207 if (failed(validateVariableConstraints(constraints, type)))
2208 return failure();
2209 return createInlineVariableExpr(type, name, nameLoc, constraints);
2210 }
2211
2212 //===----------------------------------------------------------------------===//
2213 // Stmts
2214
parseStmt(bool expectTerminalSemicolon)2215 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2216 FailureOr<ast::Stmt *> stmt;
2217 switch (curToken.getKind()) {
2218 case Token::kw_erase:
2219 stmt = parseEraseStmt();
2220 break;
2221 case Token::kw_let:
2222 stmt = parseLetStmt();
2223 break;
2224 case Token::kw_replace:
2225 stmt = parseReplaceStmt();
2226 break;
2227 case Token::kw_return:
2228 stmt = parseReturnStmt();
2229 break;
2230 case Token::kw_rewrite:
2231 stmt = parseRewriteStmt();
2232 break;
2233 default:
2234 stmt = parseExpr();
2235 break;
2236 }
2237 if (failed(stmt) ||
2238 (expectTerminalSemicolon &&
2239 failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2240 return failure();
2241 return stmt;
2242 }
2243
parseCompoundStmt()2244 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2245 SMLoc startLoc = curToken.getStartLoc();
2246 consumeToken(Token::l_brace);
2247
2248 // Push a new block scope and parse any nested statements.
2249 pushDeclScope();
2250 SmallVector<ast::Stmt *> statements;
2251 while (curToken.isNot(Token::r_brace)) {
2252 FailureOr<ast::Stmt *> statement = parseStmt();
2253 if (failed(statement))
2254 return popDeclScope(), failure();
2255 statements.push_back(*statement);
2256 }
2257 popDeclScope();
2258
2259 // Consume the end brace.
2260 SMRange location(startLoc, curToken.getEndLoc());
2261 consumeToken(Token::r_brace);
2262
2263 return ast::CompoundStmt::create(ctx, location, statements);
2264 }
2265
parseEraseStmt()2266 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2267 if (parserContext == ParserContext::Constraint)
2268 return emitError("`erase` cannot be used within a Constraint");
2269 SMRange loc = curToken.getLoc();
2270 consumeToken(Token::kw_erase);
2271
2272 // Parse the root operation expression.
2273 FailureOr<ast::Expr *> rootOp = parseExpr();
2274 if (failed(rootOp))
2275 return failure();
2276
2277 return createEraseStmt(loc, *rootOp);
2278 }
2279
parseLetStmt()2280 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2281 SMRange loc = curToken.getLoc();
2282 consumeToken(Token::kw_let);
2283
2284 // Parse the name of the new variable.
2285 SMRange varLoc = curToken.getLoc();
2286 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2287 // `_` is a reserved variable name.
2288 if (curToken.is(Token::underscore)) {
2289 return emitError(varLoc,
2290 "`_` may only be used to define \"inline\" variables");
2291 }
2292 return emitError(varLoc,
2293 "expected identifier after `let` to name a new variable");
2294 }
2295 StringRef varName = curToken.getSpelling();
2296 consumeToken();
2297
2298 // Parse the optional set of constraints.
2299 SmallVector<ast::ConstraintRef> constraints;
2300 if (consumeIf(Token::colon) &&
2301 failed(parseVariableDeclConstraintList(constraints)))
2302 return failure();
2303
2304 // Parse the optional initializer expression.
2305 ast::Expr *initializer = nullptr;
2306 if (consumeIf(Token::equal)) {
2307 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2308 if (failed(initOrFailure))
2309 return failure();
2310 initializer = *initOrFailure;
2311
2312 // Check that the constraints are compatible with having an initializer,
2313 // e.g. type constraints cannot be used with initializers.
2314 for (ast::ConstraintRef constraint : constraints) {
2315 LogicalResult result =
2316 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2317 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2318 ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2319 if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2320 return this->emitError(
2321 constraint.referenceLoc,
2322 "type constraints are not permitted on variables with "
2323 "initializers");
2324 }
2325 return success();
2326 })
2327 .Default(success());
2328 if (failed(result))
2329 return failure();
2330 }
2331 }
2332
2333 FailureOr<ast::VariableDecl *> varDecl =
2334 createVariableDecl(varName, varLoc, initializer, constraints);
2335 if (failed(varDecl))
2336 return failure();
2337 return ast::LetStmt::create(ctx, loc, *varDecl);
2338 }
2339
parseReplaceStmt()2340 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2341 if (parserContext == ParserContext::Constraint)
2342 return emitError("`replace` cannot be used within a Constraint");
2343 SMRange loc = curToken.getLoc();
2344 consumeToken(Token::kw_replace);
2345
2346 // Parse the root operation expression.
2347 FailureOr<ast::Expr *> rootOp = parseExpr();
2348 if (failed(rootOp))
2349 return failure();
2350
2351 if (failed(
2352 parseToken(Token::kw_with, "expected `with` after root operation")))
2353 return failure();
2354
2355 // The replacement portion of this statement is within a rewrite context.
2356 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2357 ParserContext::Rewrite);
2358
2359 // Parse the replacement values.
2360 SmallVector<ast::Expr *> replValues;
2361 if (consumeIf(Token::l_paren)) {
2362 if (consumeIf(Token::r_paren)) {
2363 return emitError(
2364 loc, "expected at least one replacement value, consider using "
2365 "`erase` if no replacement values are desired");
2366 }
2367
2368 do {
2369 FailureOr<ast::Expr *> replExpr = parseExpr();
2370 if (failed(replExpr))
2371 return failure();
2372 replValues.emplace_back(*replExpr);
2373 } while (consumeIf(Token::comma));
2374
2375 if (failed(parseToken(Token::r_paren,
2376 "expected `)` after replacement values")))
2377 return failure();
2378 } else {
2379 // Handle replacement with an operation uniquely, as the replacement
2380 // operation supports type inferrence from the root operation.
2381 FailureOr<ast::Expr *> replExpr;
2382 if (curToken.is(Token::kw_op))
2383 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2384 else
2385 replExpr = parseExpr();
2386 if (failed(replExpr))
2387 return failure();
2388 replValues.emplace_back(*replExpr);
2389 }
2390
2391 return createReplaceStmt(loc, *rootOp, replValues);
2392 }
2393
parseReturnStmt()2394 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2395 SMRange loc = curToken.getLoc();
2396 consumeToken(Token::kw_return);
2397
2398 // Parse the result value.
2399 FailureOr<ast::Expr *> resultExpr = parseExpr();
2400 if (failed(resultExpr))
2401 return failure();
2402
2403 return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2404 }
2405
parseRewriteStmt()2406 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2407 if (parserContext == ParserContext::Constraint)
2408 return emitError("`rewrite` cannot be used within a Constraint");
2409 SMRange loc = curToken.getLoc();
2410 consumeToken(Token::kw_rewrite);
2411
2412 // Parse the root operation.
2413 FailureOr<ast::Expr *> rootOp = parseExpr();
2414 if (failed(rootOp))
2415 return failure();
2416
2417 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2418 return failure();
2419
2420 if (curToken.isNot(Token::l_brace))
2421 return emitError("expected `{` to start rewrite body");
2422
2423 // The rewrite body of this statement is within a rewrite context.
2424 llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2425 ParserContext::Rewrite);
2426
2427 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2428 if (failed(rewriteBody))
2429 return failure();
2430
2431 // Verify the rewrite body.
2432 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2433 if (isa<ast::ReturnStmt>(stmt)) {
2434 return emitError(stmt->getLoc(),
2435 "`return` statements are only permitted within a "
2436 "`Constraint` or `Rewrite` body");
2437 }
2438 }
2439
2440 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2441 }
2442
2443 //===----------------------------------------------------------------------===//
2444 // Creation+Analysis
2445 //===----------------------------------------------------------------------===//
2446
2447 //===----------------------------------------------------------------------===//
2448 // Decls
2449
tryExtractCallableDecl(ast::Node * node)2450 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2451 // Unwrap reference expressions.
2452 if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2453 node = init->getDecl();
2454 return dyn_cast<ast::CallableDecl>(node);
2455 }
2456
2457 FailureOr<ast::PatternDecl *>
createPatternDecl(SMRange loc,const ast::Name * name,const ParsedPatternMetadata & metadata,ast::CompoundStmt * body)2458 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2459 const ParsedPatternMetadata &metadata,
2460 ast::CompoundStmt *body) {
2461 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2462 metadata.hasBoundedRecursion, body);
2463 }
2464
createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl * > results)2465 ast::Type Parser::createUserConstraintRewriteResultType(
2466 ArrayRef<ast::VariableDecl *> results) {
2467 // Single result decls use the type of the single result.
2468 if (results.size() == 1)
2469 return results[0]->getType();
2470
2471 // Multiple results use a tuple type, with the types and names grabbed from
2472 // the result variable decls.
2473 auto resultTypes = llvm::map_range(
2474 results, [&](const auto *result) { return result->getType(); });
2475 auto resultNames = llvm::map_range(
2476 results, [&](const auto *result) { return result->getName().getName(); });
2477 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2478 llvm::to_vector(resultNames));
2479 }
2480
2481 template <typename T>
createUserPDLLConstraintOrRewriteDecl(const ast::Name & name,ArrayRef<ast::VariableDecl * > arguments,ArrayRef<ast::VariableDecl * > results,ast::Type resultType,ast::CompoundStmt * body)2482 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2483 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2484 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2485 ast::CompoundStmt *body) {
2486 if (!body->getChildren().empty()) {
2487 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2488 ast::Expr *resultExpr = retStmt->getResultExpr();
2489
2490 // Process the result of the decl. If no explicit signature results
2491 // were provided, check for return type inference. Otherwise, check that
2492 // the return expression can be converted to the expected type.
2493 if (results.empty())
2494 resultType = resultExpr->getType();
2495 else if (failed(convertExpressionTo(resultExpr, resultType)))
2496 return failure();
2497 else
2498 retStmt->setResultExpr(resultExpr);
2499 }
2500 }
2501 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2502 }
2503
2504 FailureOr<ast::VariableDecl *>
createVariableDecl(StringRef name,SMRange loc,ast::Expr * initializer,ArrayRef<ast::ConstraintRef> constraints)2505 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2506 ArrayRef<ast::ConstraintRef> constraints) {
2507 // The type of the variable, which is expected to be inferred by either a
2508 // constraint or an initializer expression.
2509 ast::Type type;
2510 if (failed(validateVariableConstraints(constraints, type)))
2511 return failure();
2512
2513 if (initializer) {
2514 // Update the variable type based on the initializer, or try to convert the
2515 // initializer to the existing type.
2516 if (!type)
2517 type = initializer->getType();
2518 else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2519 type = mergedType;
2520 else if (failed(convertExpressionTo(initializer, type)))
2521 return failure();
2522
2523 // Otherwise, if there is no initializer check that the type has already
2524 // been resolved from the constraint list.
2525 } else if (!type) {
2526 return emitErrorAndNote(
2527 loc, "unable to infer type for variable `" + name + "`", loc,
2528 "the type of a variable must be inferable from the constraint "
2529 "list or the initializer");
2530 }
2531
2532 // Constraint types cannot be used when defining variables.
2533 if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2534 return emitError(
2535 loc, llvm::formatv("unable to define variable of `{0}` type", type));
2536 }
2537
2538 // Try to define a variable with the given name.
2539 FailureOr<ast::VariableDecl *> varDecl =
2540 defineVariableDecl(name, loc, type, initializer, constraints);
2541 if (failed(varDecl))
2542 return failure();
2543
2544 return *varDecl;
2545 }
2546
2547 FailureOr<ast::VariableDecl *>
createArgOrResultVariableDecl(StringRef name,SMRange loc,const ast::ConstraintRef & constraint)2548 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2549 const ast::ConstraintRef &constraint) {
2550 // Constraint arguments may apply more complex constraints via the arguments.
2551 bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2552 ast::Type argType;
2553 if (failed(validateVariableConstraint(constraint, argType,
2554 allowNonCoreConstraints)))
2555 return failure();
2556 return defineVariableDecl(name, loc, argType, constraint);
2557 }
2558
2559 LogicalResult
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,ast::Type & inferredType,bool allowNonCoreConstraints)2560 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2561 ast::Type &inferredType,
2562 bool allowNonCoreConstraints) {
2563 for (const ast::ConstraintRef &ref : constraints)
2564 if (failed(validateVariableConstraint(ref, inferredType,
2565 allowNonCoreConstraints)))
2566 return failure();
2567 return success();
2568 }
2569
validateVariableConstraint(const ast::ConstraintRef & ref,ast::Type & inferredType,bool allowNonCoreConstraints)2570 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2571 ast::Type &inferredType,
2572 bool allowNonCoreConstraints) {
2573 ast::Type constraintType;
2574 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2575 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2576 if (failed(validateTypeConstraintExpr(typeExpr)))
2577 return failure();
2578 }
2579 constraintType = ast::AttributeType::get(ctx);
2580 } else if (const auto *cst =
2581 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2582 constraintType = ast::OperationType::get(
2583 ctx, cst->getName(), lookupODSOperation(cst->getName()));
2584 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2585 constraintType = typeTy;
2586 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2587 constraintType = typeRangeTy;
2588 } else if (const auto *cst =
2589 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2590 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2591 if (failed(validateTypeConstraintExpr(typeExpr)))
2592 return failure();
2593 }
2594 constraintType = valueTy;
2595 } else if (const auto *cst =
2596 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2597 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2598 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2599 return failure();
2600 }
2601 constraintType = valueRangeTy;
2602 } else if (const auto *cst =
2603 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2604 if (!allowNonCoreConstraints) {
2605 return emitError(ref.referenceLoc,
2606 "`Rewrite` arguments and results are only permitted to "
2607 "use core constraints, such as `Attr`, `Op`, `Type`, "
2608 "`TypeRange`, `Value`, `ValueRange`");
2609 }
2610
2611 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2612 if (inputs.size() != 1) {
2613 return emitErrorAndNote(ref.referenceLoc,
2614 "`Constraint`s applied via a variable constraint "
2615 "list must take a single input, but got " +
2616 Twine(inputs.size()),
2617 cst->getLoc(),
2618 "see definition of constraint here");
2619 }
2620 constraintType = inputs.front()->getType();
2621 } else {
2622 llvm_unreachable("unknown constraint type");
2623 }
2624
2625 // Check that the constraint type is compatible with the current inferred
2626 // type.
2627 if (!inferredType) {
2628 inferredType = constraintType;
2629 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2630 inferredType = mergedTy;
2631 } else {
2632 return emitError(ref.referenceLoc,
2633 llvm::formatv("constraint type `{0}` is incompatible "
2634 "with the previously inferred type `{1}`",
2635 constraintType, inferredType));
2636 }
2637 return success();
2638 }
2639
validateTypeConstraintExpr(const ast::Expr * typeExpr)2640 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2641 ast::Type typeExprType = typeExpr->getType();
2642 if (typeExprType != typeTy) {
2643 return emitError(typeExpr->getLoc(),
2644 "expected expression of `Type` in type constraint");
2645 }
2646 return success();
2647 }
2648
2649 LogicalResult
validateTypeRangeConstraintExpr(const ast::Expr * typeExpr)2650 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2651 ast::Type typeExprType = typeExpr->getType();
2652 if (typeExprType != typeRangeTy) {
2653 return emitError(typeExpr->getLoc(),
2654 "expected expression of `TypeRange` in type constraint");
2655 }
2656 return success();
2657 }
2658
2659 //===----------------------------------------------------------------------===//
2660 // Exprs
2661
2662 FailureOr<ast::CallExpr *>
createCallExpr(SMRange loc,ast::Expr * parentExpr,MutableArrayRef<ast::Expr * > arguments)2663 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2664 MutableArrayRef<ast::Expr *> arguments) {
2665 ast::Type parentType = parentExpr->getType();
2666
2667 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2668 if (!callableDecl) {
2669 return emitError(loc,
2670 llvm::formatv("expected a reference to a callable "
2671 "`Constraint` or `Rewrite`, but got: `{0}`",
2672 parentType));
2673 }
2674 if (parserContext == ParserContext::Rewrite) {
2675 if (isa<ast::UserConstraintDecl>(callableDecl))
2676 return emitError(
2677 loc, "unable to invoke `Constraint` within a rewrite section");
2678 } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2679 return emitError(loc, "unable to invoke `Rewrite` within a match section");
2680 }
2681
2682 // Verify the arguments of the call.
2683 /// Handle size mismatch.
2684 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2685 if (callArgs.size() != arguments.size()) {
2686 return emitErrorAndNote(
2687 loc,
2688 llvm::formatv("invalid number of arguments for {0} call; expected "
2689 "{1}, but got {2}",
2690 callableDecl->getCallableType(), callArgs.size(),
2691 arguments.size()),
2692 callableDecl->getLoc(),
2693 llvm::formatv("see the definition of {0} here",
2694 callableDecl->getName()->getName()));
2695 }
2696
2697 /// Handle argument type mismatch.
2698 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2699 diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2700 callableDecl->getName()->getName()),
2701 callableDecl->getLoc());
2702 };
2703 for (auto it : llvm::zip(callArgs, arguments)) {
2704 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2705 attachDiagFn)))
2706 return failure();
2707 }
2708
2709 return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2710 callableDecl->getResultType());
2711 }
2712
createDeclRefExpr(SMRange loc,ast::Decl * decl)2713 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2714 ast::Decl *decl) {
2715 // Check the type of decl being referenced.
2716 ast::Type declType;
2717 if (isa<ast::ConstraintDecl>(decl))
2718 declType = ast::ConstraintType::get(ctx);
2719 else if (isa<ast::UserRewriteDecl>(decl))
2720 declType = ast::RewriteType::get(ctx);
2721 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2722 declType = varDecl->getType();
2723 else
2724 return emitError(loc, "invalid reference to `" +
2725 decl->getName()->getName() + "`");
2726
2727 return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2728 }
2729
2730 FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type,StringRef name,SMRange loc,ArrayRef<ast::ConstraintRef> constraints)2731 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2732 ArrayRef<ast::ConstraintRef> constraints) {
2733 FailureOr<ast::VariableDecl *> decl =
2734 defineVariableDecl(name, loc, type, constraints);
2735 if (failed(decl))
2736 return failure();
2737 return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2738 }
2739
2740 FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr * parentExpr,StringRef name,SMRange loc)2741 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2742 SMRange loc) {
2743 // Validate the member name for the given parent expression.
2744 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2745 if (failed(memberType))
2746 return failure();
2747
2748 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2749 }
2750
validateMemberAccess(ast::Expr * parentExpr,StringRef name,SMRange loc)2751 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2752 StringRef name, SMRange loc) {
2753 ast::Type parentType = parentExpr->getType();
2754 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2755 if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2756 return valueRangeTy;
2757
2758 // Verify member access based on the operation type.
2759 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2760 auto results = odsOp->getResults();
2761
2762 // Handle indexed results.
2763 unsigned index = 0;
2764 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2765 index < results.size()) {
2766 return results[index].isVariadic() ? valueRangeTy : valueTy;
2767 }
2768
2769 // Handle named results.
2770 const auto *it = llvm::find_if(results, [&](const auto &result) {
2771 return result.getName() == name;
2772 });
2773 if (it != results.end())
2774 return it->isVariadic() ? valueRangeTy : valueTy;
2775 } else if (llvm::isDigit(name[0])) {
2776 // Allow unchecked numeric indexing of the results of unregistered
2777 // operations. It returns a single value.
2778 return valueTy;
2779 }
2780 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2781 // Handle indexed results.
2782 unsigned index = 0;
2783 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2784 index < tupleType.size()) {
2785 return tupleType.getElementTypes()[index];
2786 }
2787
2788 // Handle named results.
2789 auto elementNames = tupleType.getElementNames();
2790 const auto *it = llvm::find(elementNames, name);
2791 if (it != elementNames.end())
2792 return tupleType.getElementTypes()[it - elementNames.begin()];
2793 }
2794 return emitError(
2795 loc,
2796 llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2797 name, parentType));
2798 }
2799
createOperationExpr(SMRange loc,const ast::OpNameDecl * name,OpResultTypeContext resultTypeContext,MutableArrayRef<ast::Expr * > operands,MutableArrayRef<ast::NamedAttributeDecl * > attributes,MutableArrayRef<ast::Expr * > results)2800 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2801 SMRange loc, const ast::OpNameDecl *name,
2802 OpResultTypeContext resultTypeContext,
2803 MutableArrayRef<ast::Expr *> operands,
2804 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2805 MutableArrayRef<ast::Expr *> results) {
2806 Optional<StringRef> opNameRef = name->getName();
2807 const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2808
2809 // Verify the inputs operands.
2810 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2811 return failure();
2812
2813 // Verify the attribute list.
2814 for (ast::NamedAttributeDecl *attr : attributes) {
2815 // Check for an attribute type, or a type awaiting resolution.
2816 ast::Type attrType = attr->getValue()->getType();
2817 if (!attrType.isa<ast::AttributeType>()) {
2818 return emitError(
2819 attr->getValue()->getLoc(),
2820 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2821 }
2822 }
2823
2824 assert(
2825 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2826 "unexpected inferrence when results were explicitly specified");
2827
2828 // If we aren't relying on type inferrence, or explicit results were provided,
2829 // validate them.
2830 if (resultTypeContext == OpResultTypeContext::Explicit) {
2831 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2832 return failure();
2833
2834 // Validate the use of interface based type inferrence for this operation.
2835 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2836 assert(opNameRef &&
2837 "expected valid operation name when inferring operation results");
2838 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2839 }
2840
2841 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2842 attributes);
2843 }
2844
2845 LogicalResult
validateOperationOperands(SMRange loc,Optional<StringRef> name,const ods::Operation * odsOp,MutableArrayRef<ast::Expr * > operands)2846 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2847 const ods::Operation *odsOp,
2848 MutableArrayRef<ast::Expr *> operands) {
2849 return validateOperationOperandsOrResults(
2850 "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2851 operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2852 valueRangeTy);
2853 }
2854
2855 LogicalResult
validateOperationResults(SMRange loc,Optional<StringRef> name,const ods::Operation * odsOp,MutableArrayRef<ast::Expr * > results)2856 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2857 const ods::Operation *odsOp,
2858 MutableArrayRef<ast::Expr *> results) {
2859 return validateOperationOperandsOrResults(
2860 "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2861 results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2862 }
2863
checkOperationResultTypeInferrence(SMRange loc,StringRef opName,const ods::Operation * odsOp)2864 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2865 const ods::Operation *odsOp) {
2866 // If the operation might not have inferrence support, emit a warning to the
2867 // user. We don't emit an error because the interface might be added to the
2868 // operation at runtime. It's rare, but it could still happen. We emit a
2869 // warning here instead.
2870
2871 // Handle inferrence warnings for unknown operations.
2872 if (!odsOp) {
2873 ctx.getDiagEngine().emitWarning(
2874 loc, llvm::formatv(
2875 "operation result types are marked to be inferred, but "
2876 "`{0}` is unknown. Ensure that `{0}` supports zero "
2877 "results or implements `InferTypeOpInterface`. Include "
2878 "the ODS definition of this operation to remove this warning.",
2879 opName));
2880 return;
2881 }
2882
2883 // Handle inferrence warnings for known operations that expected at least one
2884 // result, but don't have inference support. An elided results list can mean
2885 // "zero-results", and we don't want to warn when that is the expected
2886 // behavior.
2887 bool requiresInferrence =
2888 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2889 return !result.isVariableLength();
2890 });
2891 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2892 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2893 loc,
2894 llvm::formatv("operation result types are marked to be inferred, but "
2895 "`{0}` does not provide an implementation of "
2896 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2897 "`InferTypeOpInterface` at runtime, or add support to "
2898 "the ODS definition to remove this warning.",
2899 opName));
2900 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2901 odsOp->getLoc());
2902 return;
2903 }
2904 }
2905
validateOperationOperandsOrResults(StringRef groupName,SMRange loc,Optional<SMRange> odsOpLoc,Optional<StringRef> name,MutableArrayRef<ast::Expr * > values,ArrayRef<ods::OperandOrResult> odsValues,ast::Type singleTy,ast::Type rangeTy)2906 LogicalResult Parser::validateOperationOperandsOrResults(
2907 StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2908 Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
2909 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2910 ast::Type rangeTy) {
2911 // All operation types accept a single range parameter.
2912 if (values.size() == 1) {
2913 if (failed(convertExpressionTo(values[0], rangeTy)))
2914 return failure();
2915 return success();
2916 }
2917
2918 /// If the operation has ODS information, we can more accurately verify the
2919 /// values.
2920 if (odsOpLoc) {
2921 if (odsValues.size() != values.size()) {
2922 return emitErrorAndNote(
2923 loc,
2924 llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2925 "{2}, but got {3}",
2926 groupName, *name, odsValues.size(), values.size()),
2927 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2928 }
2929 auto diagFn = [&](ast::Diagnostic &diag) {
2930 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2931 *odsOpLoc);
2932 };
2933 for (unsigned i = 0, e = values.size(); i < e; ++i) {
2934 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2935 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2936 return failure();
2937 }
2938 return success();
2939 }
2940
2941 // Otherwise, accept the value groups as they have been defined and just
2942 // ensure they are one of the expected types.
2943 for (ast::Expr *&valueExpr : values) {
2944 ast::Type valueExprType = valueExpr->getType();
2945
2946 // Check if this is one of the expected types.
2947 if (valueExprType == rangeTy || valueExprType == singleTy)
2948 continue;
2949
2950 // If the operand is an Operation, allow converting to a Value or
2951 // ValueRange. This situations arises quite often with nested operation
2952 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2953 if (singleTy == valueTy) {
2954 if (valueExprType.isa<ast::OperationType>()) {
2955 valueExpr = convertOpToValue(valueExpr);
2956 continue;
2957 }
2958 }
2959
2960 return emitError(
2961 valueExpr->getLoc(),
2962 llvm::formatv(
2963 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2964 singleTy, rangeTy, valueExprType));
2965 }
2966 return success();
2967 }
2968
2969 FailureOr<ast::TupleExpr *>
createTupleExpr(SMRange loc,ArrayRef<ast::Expr * > elements,ArrayRef<StringRef> elementNames)2970 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2971 ArrayRef<StringRef> elementNames) {
2972 for (const ast::Expr *element : elements) {
2973 ast::Type eleTy = element->getType();
2974 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
2975 return emitError(
2976 element->getLoc(),
2977 llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2978 }
2979 }
2980 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2981 }
2982
2983 //===----------------------------------------------------------------------===//
2984 // Stmts
2985
createEraseStmt(SMRange loc,ast::Expr * rootOp)2986 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2987 ast::Expr *rootOp) {
2988 // Check that root is an Operation.
2989 ast::Type rootType = rootOp->getType();
2990 if (!rootType.isa<ast::OperationType>())
2991 return emitError(rootOp->getLoc(), "expected `Op` expression");
2992
2993 return ast::EraseStmt::create(ctx, loc, rootOp);
2994 }
2995
2996 FailureOr<ast::ReplaceStmt *>
createReplaceStmt(SMRange loc,ast::Expr * rootOp,MutableArrayRef<ast::Expr * > replValues)2997 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2998 MutableArrayRef<ast::Expr *> replValues) {
2999 // Check that root is an Operation.
3000 ast::Type rootType = rootOp->getType();
3001 if (!rootType.isa<ast::OperationType>()) {
3002 return emitError(
3003 rootOp->getLoc(),
3004 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3005 }
3006
3007 // If there are multiple replacement values, we implicitly convert any Op
3008 // expressions to the value form.
3009 bool shouldConvertOpToValues = replValues.size() > 1;
3010 for (ast::Expr *&replExpr : replValues) {
3011 ast::Type replType = replExpr->getType();
3012
3013 // Check that replExpr is an Operation, Value, or ValueRange.
3014 if (replType.isa<ast::OperationType>()) {
3015 if (shouldConvertOpToValues)
3016 replExpr = convertOpToValue(replExpr);
3017 continue;
3018 }
3019
3020 if (replType != valueTy && replType != valueRangeTy) {
3021 return emitError(replExpr->getLoc(),
3022 llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3023 "expression, but got `{0}`",
3024 replType));
3025 }
3026 }
3027
3028 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3029 }
3030
3031 FailureOr<ast::RewriteStmt *>
createRewriteStmt(SMRange loc,ast::Expr * rootOp,ast::CompoundStmt * rewriteBody)3032 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3033 ast::CompoundStmt *rewriteBody) {
3034 // Check that root is an Operation.
3035 ast::Type rootType = rootOp->getType();
3036 if (!rootType.isa<ast::OperationType>()) {
3037 return emitError(
3038 rootOp->getLoc(),
3039 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3040 }
3041
3042 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3043 }
3044
3045 //===----------------------------------------------------------------------===//
3046 // Code Completion
3047 //===----------------------------------------------------------------------===//
3048
codeCompleteMemberAccess(ast::Expr * parentExpr)3049 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3050 ast::Type parentType = parentExpr->getType();
3051 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3052 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3053 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3054 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3055 return failure();
3056 }
3057
codeCompleteAttributeName(Optional<StringRef> opName)3058 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3059 if (opName)
3060 codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3061 return failure();
3062 }
3063
3064 LogicalResult
codeCompleteConstraintName(ast::Type inferredType,bool allowNonCoreConstraints,bool allowInlineTypeConstraints)3065 Parser::codeCompleteConstraintName(ast::Type inferredType,
3066 bool allowNonCoreConstraints,
3067 bool allowInlineTypeConstraints) {
3068 codeCompleteContext->codeCompleteConstraintName(
3069 inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3070 curDeclScope);
3071 return failure();
3072 }
3073
codeCompleteDialectName()3074 LogicalResult Parser::codeCompleteDialectName() {
3075 codeCompleteContext->codeCompleteDialectName();
3076 return failure();
3077 }
3078
codeCompleteOperationName(StringRef dialectName)3079 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3080 codeCompleteContext->codeCompleteOperationName(dialectName);
3081 return failure();
3082 }
3083
codeCompletePatternMetadata()3084 LogicalResult Parser::codeCompletePatternMetadata() {
3085 codeCompleteContext->codeCompletePatternMetadata();
3086 return failure();
3087 }
3088
codeCompleteIncludeFilename(StringRef curPath)3089 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3090 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3091 return failure();
3092 }
3093
codeCompleteCallSignature(ast::Node * parent,unsigned currentNumArgs)3094 void Parser::codeCompleteCallSignature(ast::Node *parent,
3095 unsigned currentNumArgs) {
3096 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3097 if (!callableDecl)
3098 return;
3099
3100 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3101 }
3102
codeCompleteOperationOperandsSignature(Optional<StringRef> opName,unsigned currentNumOperands)3103 void Parser::codeCompleteOperationOperandsSignature(
3104 Optional<StringRef> opName, unsigned currentNumOperands) {
3105 codeCompleteContext->codeCompleteOperationOperandsSignature(
3106 opName, currentNumOperands);
3107 }
3108
codeCompleteOperationResultsSignature(Optional<StringRef> opName,unsigned currentNumResults)3109 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3110 unsigned currentNumResults) {
3111 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3112 currentNumResults);
3113 }
3114
3115 //===----------------------------------------------------------------------===//
3116 // Parser
3117 //===----------------------------------------------------------------------===//
3118
3119 FailureOr<ast::Module *>
parsePDLLAST(ast::Context & ctx,llvm::SourceMgr & sourceMgr,bool enableDocumentation,CodeCompleteContext * codeCompleteContext)3120 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3121 bool enableDocumentation,
3122 CodeCompleteContext *codeCompleteContext) {
3123 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3124 return parser.parseModule();
3125 }
3126