1 //===- AsmParserImpl.h - MLIR AsmParserImpl Class ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
10 #define MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
11 
12 #include "Parser.h"
13 #include "mlir/AsmParser/AsmParserState.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/OpImplementation.h"
16 
17 namespace mlir {
18 namespace detail {
19 //===----------------------------------------------------------------------===//
20 // AsmParserImpl
21 //===----------------------------------------------------------------------===//
22 
23 /// This class provides the implementation of the generic parser methods within
24 /// AsmParser.
25 template <typename BaseT>
26 class AsmParserImpl : public BaseT {
27 public:
AsmParserImpl(SMLoc nameLoc,Parser & parser)28   AsmParserImpl(SMLoc nameLoc, Parser &parser)
29       : nameLoc(nameLoc), parser(parser) {}
30   ~AsmParserImpl() override = default;
31 
32   /// Return the location of the original name token.
getNameLoc()33   SMLoc getNameLoc() const override { return nameLoc; }
34 
35   //===--------------------------------------------------------------------===//
36   // Utilities
37   //===--------------------------------------------------------------------===//
38 
39   /// Return if any errors were emitted during parsing.
didEmitError()40   bool didEmitError() const { return emittedError; }
41 
42   /// Emit a diagnostic at the specified location and return failure.
emitError(SMLoc loc,const Twine & message)43   InFlightDiagnostic emitError(SMLoc loc, const Twine &message) override {
44     emittedError = true;
45     return parser.emitError(loc, message);
46   }
47 
48   /// Return a builder which provides useful access to MLIRContext, global
49   /// objects like types and attributes.
getBuilder()50   Builder &getBuilder() const override { return parser.builder; }
51 
52   /// Get the location of the next token and store it into the argument.  This
53   /// always succeeds.
getCurrentLocation()54   SMLoc getCurrentLocation() override { return parser.getToken().getLoc(); }
55 
56   /// Re-encode the given source location as an MLIR location and return it.
getEncodedSourceLoc(SMLoc loc)57   Location getEncodedSourceLoc(SMLoc loc) override {
58     return parser.getEncodedSourceLocation(loc);
59   }
60 
61   //===--------------------------------------------------------------------===//
62   // Token Parsing
63   //===--------------------------------------------------------------------===//
64 
65   using Delimiter = AsmParser::Delimiter;
66 
67   /// Parse a `->` token.
parseArrow()68   ParseResult parseArrow() override {
69     return parser.parseToken(Token::arrow, "expected '->'");
70   }
71 
72   /// Parses a `->` if present.
parseOptionalArrow()73   ParseResult parseOptionalArrow() override {
74     return success(parser.consumeIf(Token::arrow));
75   }
76 
77   /// Parse a '{' token.
parseLBrace()78   ParseResult parseLBrace() override {
79     return parser.parseToken(Token::l_brace, "expected '{'");
80   }
81 
82   /// Parse a '{' token if present
parseOptionalLBrace()83   ParseResult parseOptionalLBrace() override {
84     return success(parser.consumeIf(Token::l_brace));
85   }
86 
87   /// Parse a `}` token.
parseRBrace()88   ParseResult parseRBrace() override {
89     return parser.parseToken(Token::r_brace, "expected '}'");
90   }
91 
92   /// Parse a `}` token if present
parseOptionalRBrace()93   ParseResult parseOptionalRBrace() override {
94     return success(parser.consumeIf(Token::r_brace));
95   }
96 
97   /// Parse a `:` token.
parseColon()98   ParseResult parseColon() override {
99     return parser.parseToken(Token::colon, "expected ':'");
100   }
101 
102   /// Parse a `:` token if present.
parseOptionalColon()103   ParseResult parseOptionalColon() override {
104     return success(parser.consumeIf(Token::colon));
105   }
106 
107   /// Parse a `,` token.
parseComma()108   ParseResult parseComma() override {
109     return parser.parseToken(Token::comma, "expected ','");
110   }
111 
112   /// Parse a `,` token if present.
parseOptionalComma()113   ParseResult parseOptionalComma() override {
114     return success(parser.consumeIf(Token::comma));
115   }
116 
117   /// Parses a `...` if present.
parseOptionalEllipsis()118   ParseResult parseOptionalEllipsis() override {
119     return success(parser.consumeIf(Token::ellipsis));
120   }
121 
122   /// Parse a `=` token.
parseEqual()123   ParseResult parseEqual() override {
124     return parser.parseToken(Token::equal, "expected '='");
125   }
126 
127   /// Parse a `=` token if present.
parseOptionalEqual()128   ParseResult parseOptionalEqual() override {
129     return success(parser.consumeIf(Token::equal));
130   }
131 
132   /// Parse a '<' token.
parseLess()133   ParseResult parseLess() override {
134     return parser.parseToken(Token::less, "expected '<'");
135   }
136 
137   /// Parse a `<` token if present.
parseOptionalLess()138   ParseResult parseOptionalLess() override {
139     return success(parser.consumeIf(Token::less));
140   }
141 
142   /// Parse a '>' token.
parseGreater()143   ParseResult parseGreater() override {
144     return parser.parseToken(Token::greater, "expected '>'");
145   }
146 
147   /// Parse a `>` token if present.
parseOptionalGreater()148   ParseResult parseOptionalGreater() override {
149     return success(parser.consumeIf(Token::greater));
150   }
151 
152   /// Parse a `(` token.
parseLParen()153   ParseResult parseLParen() override {
154     return parser.parseToken(Token::l_paren, "expected '('");
155   }
156 
157   /// Parses a '(' if present.
parseOptionalLParen()158   ParseResult parseOptionalLParen() override {
159     return success(parser.consumeIf(Token::l_paren));
160   }
161 
162   /// Parse a `)` token.
parseRParen()163   ParseResult parseRParen() override {
164     return parser.parseToken(Token::r_paren, "expected ')'");
165   }
166 
167   /// Parses a ')' if present.
parseOptionalRParen()168   ParseResult parseOptionalRParen() override {
169     return success(parser.consumeIf(Token::r_paren));
170   }
171 
172   /// Parse a `[` token.
parseLSquare()173   ParseResult parseLSquare() override {
174     return parser.parseToken(Token::l_square, "expected '['");
175   }
176 
177   /// Parses a '[' if present.
parseOptionalLSquare()178   ParseResult parseOptionalLSquare() override {
179     return success(parser.consumeIf(Token::l_square));
180   }
181 
182   /// Parse a `]` token.
parseRSquare()183   ParseResult parseRSquare() override {
184     return parser.parseToken(Token::r_square, "expected ']'");
185   }
186 
187   /// Parses a ']' if present.
parseOptionalRSquare()188   ParseResult parseOptionalRSquare() override {
189     return success(parser.consumeIf(Token::r_square));
190   }
191 
192   /// Parses a '?' token.
parseQuestion()193   ParseResult parseQuestion() override {
194     return parser.parseToken(Token::question, "expected '?'");
195   }
196 
197   /// Parses a '?' if present.
parseOptionalQuestion()198   ParseResult parseOptionalQuestion() override {
199     return success(parser.consumeIf(Token::question));
200   }
201 
202   /// Parses a '*' token.
parseStar()203   ParseResult parseStar() override {
204     return parser.parseToken(Token::star, "expected '*'");
205   }
206 
207   /// Parses a '*' if present.
parseOptionalStar()208   ParseResult parseOptionalStar() override {
209     return success(parser.consumeIf(Token::star));
210   }
211 
212   /// Parses a '+' token.
parsePlus()213   ParseResult parsePlus() override {
214     return parser.parseToken(Token::plus, "expected '+'");
215   }
216 
217   /// Parses a '+' token if present.
parseOptionalPlus()218   ParseResult parseOptionalPlus() override {
219     return success(parser.consumeIf(Token::plus));
220   }
221 
222   /// Parse a '|' token.
parseVerticalBar()223   ParseResult parseVerticalBar() override {
224     return parser.parseToken(Token::vertical_bar, "expected '|'");
225   }
226 
227   /// Parse a '|' token if present.
parseOptionalVerticalBar()228   ParseResult parseOptionalVerticalBar() override {
229     return success(parser.consumeIf(Token::vertical_bar));
230   }
231 
232   /// Parses a quoted string token if present.
parseOptionalString(std::string * string)233   ParseResult parseOptionalString(std::string *string) override {
234     if (!parser.getToken().is(Token::string))
235       return failure();
236 
237     if (string)
238       *string = parser.getToken().getStringValue();
239     parser.consumeToken();
240     return success();
241   }
242 
243   /// Parse a floating point value from the stream.
parseFloat(double & result)244   ParseResult parseFloat(double &result) override {
245     bool isNegative = parser.consumeIf(Token::minus);
246     Token curTok = parser.getToken();
247     SMLoc loc = curTok.getLoc();
248 
249     // Check for a floating point value.
250     if (curTok.is(Token::floatliteral)) {
251       auto val = curTok.getFloatingPointValue();
252       if (!val)
253         return emitError(loc, "floating point value too large");
254       parser.consumeToken(Token::floatliteral);
255       result = isNegative ? -*val : *val;
256       return success();
257     }
258 
259     // Check for a hexadecimal float value.
260     if (curTok.is(Token::integer)) {
261       Optional<APFloat> apResult;
262       if (failed(parser.parseFloatFromIntegerLiteral(
263               apResult, curTok, isNegative, APFloat::IEEEdouble(),
264               /*typeSizeInBits=*/64)))
265         return failure();
266 
267       parser.consumeToken(Token::integer);
268       result = apResult->convertToDouble();
269       return success();
270     }
271 
272     return emitError(loc, "expected floating point literal");
273   }
274 
275   /// Parse an optional integer value from the stream.
parseOptionalInteger(APInt & result)276   OptionalParseResult parseOptionalInteger(APInt &result) override {
277     return parser.parseOptionalInteger(result);
278   }
279 
280   /// Parse a list of comma-separated items with an optional delimiter.  If a
281   /// delimiter is provided, then an empty list is allowed.  If not, then at
282   /// least one element will be parsed.
parseCommaSeparatedList(Delimiter delimiter,function_ref<ParseResult ()> parseElt,StringRef contextMessage)283   ParseResult parseCommaSeparatedList(Delimiter delimiter,
284                                       function_ref<ParseResult()> parseElt,
285                                       StringRef contextMessage) override {
286     return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
287   }
288 
289   //===--------------------------------------------------------------------===//
290   // Keyword Parsing
291   //===--------------------------------------------------------------------===//
292 
parseKeyword(StringRef keyword,const Twine & msg)293   ParseResult parseKeyword(StringRef keyword, const Twine &msg) override {
294     if (parser.getToken().isCodeCompletion())
295       return parser.codeCompleteExpectedTokens(keyword);
296 
297     auto loc = getCurrentLocation();
298     if (parseOptionalKeyword(keyword))
299       return emitError(loc, "expected '") << keyword << "'" << msg;
300     return success();
301   }
302   using AsmParser::parseKeyword;
303 
304   /// Parse the given keyword if present.
parseOptionalKeyword(StringRef keyword)305   ParseResult parseOptionalKeyword(StringRef keyword) override {
306     if (parser.getToken().isCodeCompletion())
307       return parser.codeCompleteOptionalTokens(keyword);
308 
309     // Check that the current token has the same spelling.
310     if (!parser.isCurrentTokenAKeyword() ||
311         parser.getTokenSpelling() != keyword)
312       return failure();
313     parser.consumeToken();
314     return success();
315   }
316 
317   /// Parse a keyword, if present, into 'keyword'.
parseOptionalKeyword(StringRef * keyword)318   ParseResult parseOptionalKeyword(StringRef *keyword) override {
319     // Check that the current token is a keyword.
320     if (!parser.isCurrentTokenAKeyword())
321       return failure();
322 
323     *keyword = parser.getTokenSpelling();
324     parser.consumeToken();
325     return success();
326   }
327 
328   /// Parse a keyword if it is one of the 'allowedKeywords'.
329   ParseResult
parseOptionalKeyword(StringRef * keyword,ArrayRef<StringRef> allowedKeywords)330   parseOptionalKeyword(StringRef *keyword,
331                        ArrayRef<StringRef> allowedKeywords) override {
332     if (parser.getToken().isCodeCompletion())
333       return parser.codeCompleteOptionalTokens(allowedKeywords);
334 
335     // Check that the current token is a keyword.
336     if (!parser.isCurrentTokenAKeyword())
337       return failure();
338 
339     StringRef currentKeyword = parser.getTokenSpelling();
340     if (llvm::is_contained(allowedKeywords, currentKeyword)) {
341       *keyword = currentKeyword;
342       parser.consumeToken();
343       return success();
344     }
345 
346     return failure();
347   }
348 
349   /// Parse an optional keyword or string and set instance into 'result'.`
parseOptionalKeywordOrString(std::string * result)350   ParseResult parseOptionalKeywordOrString(std::string *result) override {
351     StringRef keyword;
352     if (succeeded(parseOptionalKeyword(&keyword))) {
353       *result = keyword.str();
354       return success();
355     }
356 
357     return parseOptionalString(result);
358   }
359 
360   //===--------------------------------------------------------------------===//
361   // Attribute Parsing
362   //===--------------------------------------------------------------------===//
363 
364   /// Parse an arbitrary attribute and return it in result.
parseAttribute(Attribute & result,Type type)365   ParseResult parseAttribute(Attribute &result, Type type) override {
366     result = parser.parseAttribute(type);
367     return success(static_cast<bool>(result));
368   }
369 
370   /// Parse a custom attribute with the provided callback, unless the next
371   /// token is `#`, in which case the generic parser is invoked.
parseCustomAttributeWithFallback(Attribute & result,Type type,function_ref<ParseResult (Attribute & result,Type type)> parseAttribute)372   ParseResult parseCustomAttributeWithFallback(
373       Attribute &result, Type type,
374       function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
375       override {
376     if (parser.getToken().isNot(Token::hash_identifier))
377       return parseAttribute(result, type);
378     result = parser.parseAttribute(type);
379     return success(static_cast<bool>(result));
380   }
381 
382   /// Parse a custom attribute with the provided callback, unless the next
383   /// token is `#`, in which case the generic parser is invoked.
parseCustomTypeWithFallback(Type & result,function_ref<ParseResult (Type & result)> parseType)384   ParseResult parseCustomTypeWithFallback(
385       Type &result,
386       function_ref<ParseResult(Type &result)> parseType) override {
387     if (parser.getToken().isNot(Token::exclamation_identifier))
388       return parseType(result);
389     result = parser.parseType();
390     return success(static_cast<bool>(result));
391   }
392 
parseOptionalAttribute(Attribute & result,Type type)393   OptionalParseResult parseOptionalAttribute(Attribute &result,
394                                              Type type) override {
395     return parser.parseOptionalAttribute(result, type);
396   }
parseOptionalAttribute(ArrayAttr & result,Type type)397   OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
398                                              Type type) override {
399     return parser.parseOptionalAttribute(result, type);
400   }
parseOptionalAttribute(StringAttr & result,Type type)401   OptionalParseResult parseOptionalAttribute(StringAttr &result,
402                                              Type type) override {
403     return parser.parseOptionalAttribute(result, type);
404   }
405 
406   /// Parse a named dictionary into 'result' if it is present.
parseOptionalAttrDict(NamedAttrList & result)407   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
408     if (parser.getToken().isNot(Token::l_brace))
409       return success();
410     return parser.parseAttributeDict(result);
411   }
412 
413   /// Parse a named dictionary into 'result' if the `attributes` keyword is
414   /// present.
parseOptionalAttrDictWithKeyword(NamedAttrList & result)415   ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
416     if (failed(parseOptionalKeyword("attributes")))
417       return success();
418     return parser.parseAttributeDict(result);
419   }
420 
421   /// Parse an affine map instance into 'map'.
parseAffineMap(AffineMap & map)422   ParseResult parseAffineMap(AffineMap &map) override {
423     return parser.parseAffineMapReference(map);
424   }
425 
426   /// Parse an integer set instance into 'set'.
printIntegerSet(IntegerSet & set)427   ParseResult printIntegerSet(IntegerSet &set) override {
428     return parser.parseIntegerSetReference(set);
429   }
430 
431   //===--------------------------------------------------------------------===//
432   // Identifier Parsing
433   //===--------------------------------------------------------------------===//
434 
435   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
436   /// string attribute named 'attrName'.
parseOptionalSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)437   ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
438                                       NamedAttrList &attrs) override {
439     Token atToken = parser.getToken();
440     if (atToken.isNot(Token::at_identifier))
441       return failure();
442 
443     result = getBuilder().getStringAttr(atToken.getSymbolReference());
444     attrs.push_back(getBuilder().getNamedAttr(attrName, result));
445     parser.consumeToken();
446 
447     // If we are populating the assembly parser state, record this as a symbol
448     // reference.
449     if (parser.getState().asmState) {
450       parser.getState().asmState->addUses(SymbolRefAttr::get(result),
451                                           atToken.getLocRange());
452     }
453     return success();
454   }
455 
456   //===--------------------------------------------------------------------===//
457   // Resource Parsing
458   //===--------------------------------------------------------------------===//
459 
460   /// Parse a handle to a resource within the assembly format.
461   FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect * dialect)462   parseResourceHandle(Dialect *dialect) override {
463     const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
464     if (!interface) {
465       return parser.emitError() << "dialect '" << dialect->getNamespace()
466                                 << "' does not expect resource handles";
467     }
468     StringRef resourceName;
469     return parser.parseResourceHandle(interface, resourceName);
470   }
471 
472   //===--------------------------------------------------------------------===//
473   // Type Parsing
474   //===--------------------------------------------------------------------===//
475 
476   /// Parse a type.
parseType(Type & result)477   ParseResult parseType(Type &result) override {
478     return failure(!(result = parser.parseType()));
479   }
480 
481   /// Parse an optional type.
parseOptionalType(Type & result)482   OptionalParseResult parseOptionalType(Type &result) override {
483     return parser.parseOptionalType(result);
484   }
485 
486   /// Parse an arrow followed by a type list.
parseArrowTypeList(SmallVectorImpl<Type> & result)487   ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
488     if (parseArrow() || parser.parseFunctionResultTypes(result))
489       return failure();
490     return success();
491   }
492 
493   /// Parse an optional arrow followed by a type list.
494   ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> & result)495   parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
496     if (!parser.consumeIf(Token::arrow))
497       return success();
498     return parser.parseFunctionResultTypes(result);
499   }
500 
501   /// Parse a colon followed by a type.
parseColonType(Type & result)502   ParseResult parseColonType(Type &result) override {
503     return failure(parser.parseToken(Token::colon, "expected ':'") ||
504                    !(result = parser.parseType()));
505   }
506 
507   /// Parse a colon followed by a type list, which must have at least one type.
parseColonTypeList(SmallVectorImpl<Type> & result)508   ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
509     if (parser.parseToken(Token::colon, "expected ':'"))
510       return failure();
511     return parser.parseTypeListNoParens(result);
512   }
513 
514   /// Parse an optional colon followed by a type list, which if present must
515   /// have at least one type.
516   ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> & result)517   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
518     if (!parser.consumeIf(Token::colon))
519       return success();
520     return parser.parseTypeListNoParens(result);
521   }
522 
parseDimensionList(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic,bool withTrailingX)523   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
524                                  bool allowDynamic,
525                                  bool withTrailingX) override {
526     return parser.parseDimensionListRanked(dimensions, allowDynamic,
527                                            withTrailingX);
528   }
529 
parseXInDimensionList()530   ParseResult parseXInDimensionList() override {
531     return parser.parseXInDimensionList();
532   }
533 
534   //===--------------------------------------------------------------------===//
535   // Code Completion
536   //===--------------------------------------------------------------------===//
537 
538   /// Parse a keyword, or an empty string if the current location signals a code
539   /// completion.
parseKeywordOrCompletion(StringRef * keyword)540   ParseResult parseKeywordOrCompletion(StringRef *keyword) override {
541     Token tok = parser.getToken();
542     if (tok.isCodeCompletion() && tok.getSpelling().empty()) {
543       *keyword = "";
544       return success();
545     }
546     return parseKeyword(keyword);
547   }
548 
549   /// Signal the code completion of a set of expected tokens.
codeCompleteExpectedTokens(ArrayRef<StringRef> tokens)550   void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) override {
551     Token tok = parser.getToken();
552     if (tok.isCodeCompletion() && tok.getSpelling().empty())
553       (void)parser.codeCompleteExpectedTokens(tokens);
554   }
555 
556 protected:
557   /// The source location of the dialect symbol.
558   SMLoc nameLoc;
559 
560   /// The main parser.
561   Parser &parser;
562 
563   /// A flag that indicates if any errors were emitted during parsing.
564   bool emittedError = false;
565 };
566 } // namespace detail
567 } // namespace mlir
568 
569 #endif // MLIR_LIB_ASMPARSER_ASMPARSERIMPL_H
570