1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/Endian.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
/// Parse an arbitrary attribute.
///
///  attribute-value ::= `unit`
///                    | bool-literal
///                    | integer-literal (`:` (index-type | integer-type))?
///                    | float-literal (`:` float-type)?
///                    | string-literal (`:` type)?
///                    | type
///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
///                    | symbol-ref-id (`::` symbol-ref-id)*
///                    | `dense` `<` tensor-literal `>` `:`
///                      (tensor-type | vector-type)
///                    | `sparse` `<` attribute-value `,` attribute-value `>`
///                      `:` (tensor-type | vector-type)
///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
///                      `>` `:` (tensor-type | vector-type)
///                    | extended-attribute
///
/// In addition to the forms listed above, this function also accepts
/// `affine_map` `<` ... `>`, `affine_set` `<` ... `>` and `loc` `(` ... `)`,
/// and handles the code-completion token.
///
/// The function dispatches on the kind of the current token. `type` is
/// forwarded to the sub-parsers that accept one (dense/opaque/sparse elements,
/// extended attributes, numeric literals); for string literals a trailing
/// `: type` is only parsed when `type` is null. A null attribute is returned
/// on every failure path.
Attribute Parser::parseAttribute(Type type) {
  switch (getToken().getKind()) {
  // Parse an AffineMap or IntegerSet attribute.
  case Token::kw_affine_map: {
    consumeToken(Token::kw_affine_map);

    // The map itself is wrapped in angle brackets: `affine_map<...>`.
    AffineMap map;
    if (parseToken(Token::less, "expected '<' in affine map") ||
        parseAffineMapReference(map) ||
        parseToken(Token::greater, "expected '>' in affine map"))
      return Attribute();
    return AffineMapAttr::get(map);
  }
  case Token::kw_affine_set: {
    consumeToken(Token::kw_affine_set);

    // The set itself is wrapped in angle brackets: `affine_set<...>`.
    IntegerSet set;
    if (parseToken(Token::less, "expected '<' in integer set") ||
        parseIntegerSetReference(set) ||
        parseToken(Token::greater, "expected '>' in integer set"))
      return Attribute();
    return IntegerSetAttr::get(set);
  }

  // Parse an array attribute.
  case Token::l_square: {
    consumeToken(Token::l_square);
    // A ':' directly after '[' selects the dense array form instead of a
    // generic array of attributes.
    if (consumeIf(Token::colon))
      return parseDenseArrayAttr();
    SmallVector<Attribute, 4> elements;
    // Each element is parsed without a type hint; a null element aborts the
    // list.
    auto parseElt = [&]() -> ParseResult {
      elements.push_back(parseAttribute());
      return elements.back() ? success() : failure();
    };

    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
      return nullptr;
    return builder.getArrayAttr(elements);
  }

  // Parse a boolean attribute.
  case Token::kw_false:
    consumeToken(Token::kw_false);
    return builder.getBoolAttr(false);
  case Token::kw_true:
    consumeToken(Token::kw_true);
    return builder.getBoolAttr(true);

  // Parse a dense elements attribute.
  case Token::kw_dense:
    return parseDenseElementsAttr(type);

  // Parse a dictionary attribute.
  case Token::l_brace: {
    NamedAttrList elements;
    if (parseAttributeDict(elements))
      return nullptr;
    return elements.getDictionary(getContext());
  }

  // Parse an extended attribute, i.e. alias or dialect attribute.
  case Token::hash_identifier:
    return parseExtendedAttr(type);

  // Parse floating point and integer attributes.
  case Token::floatliteral:
    return parseFloatAttr(type, /*isNegative=*/false);
  case Token::integer:
    return parseDecOrHexAttr(type, /*isNegative=*/false);
  case Token::minus: {
    // A leading '-' must be followed by an integer or floating point literal;
    // the sign is passed down to the literal parser.
    consumeToken(Token::minus);
    if (getToken().is(Token::integer))
      return parseDecOrHexAttr(type, /*isNegative=*/true);
    if (getToken().is(Token::floatliteral))
      return parseFloatAttr(type, /*isNegative=*/true);

    return (emitWrongTokenError(
                "expected constant integer or floating point value"),
            nullptr);
  }

  // Parse a location attribute.
  case Token::kw_loc: {
    consumeToken(Token::kw_loc);

    // The location instance is wrapped in parentheses: `loc(...)`.
    LocationAttr locAttr;
    if (parseToken(Token::l_paren, "expected '(' in inline location") ||
        parseLocationInstance(locAttr) ||
        parseToken(Token::r_paren, "expected ')' in inline location"))
      return Attribute();
    return locAttr;
  }

  // Parse an opaque elements attribute.
  case Token::kw_opaque:
    return parseOpaqueElementsAttr(type);

  // Parse a sparse elements attribute.
  case Token::kw_sparse:
    return parseSparseElementsAttr(type);

  // Parse a string attribute.
  case Token::string: {
    // Grab the string value before the token is consumed.
    auto val = getToken().getStringValue();
    consumeToken(Token::string);
    // Parse the optional trailing colon type if one wasn't explicitly provided.
    if (!type && consumeIf(Token::colon) && !(type = parseType()))
      return Attribute();

    // Build a typed string attribute only when a type is available.
    return type ? StringAttr::get(val, type)
                : StringAttr::get(getContext(), val);
  }

  // Parse a symbol reference attribute.
  case Token::at_identifier: {
    // When populating the parser state, this is a list of locations for all of
    // the nested references.
    SmallVector<SMRange> referenceLocations;
    if (state.asmState)
      referenceLocations.push_back(getToken().getLocRange());

    // Parse the top-level reference.
    std::string nameStr = getToken().getSymbolReference();
    consumeToken(Token::at_identifier);

    // Parse any nested references.
    std::vector<FlatSymbolRefAttr> nestedRefs;
    while (getToken().is(Token::colon)) {
      // Check for the '::' prefix. The position of the first ':' is saved so
      // that the lexer can be rewound if it turns out to be a lone ':'.
      const char *curPointer = getToken().getLoc().getPointer();
      consumeToken(Token::colon);
      if (!consumeIf(Token::colon)) {
        // Only a single ':' was found, so it is not part of this reference.
        // Unless the lexer hit end-of-file or an error, rewind to the saved
        // position and advance; this should leave the ':' as the current
        // token again for whoever parses next.
        if (getToken().isNot(Token::eof, Token::error)) {
          state.lex.resetPointer(curPointer);
          consumeToken();
        }
        break;
      }
      // Parse the reference itself.
      auto curLoc = getToken().getLoc();
      if (getToken().isNot(Token::at_identifier)) {
        emitError(curLoc, "expected nested symbol reference identifier");
        return Attribute();
      }

      // If we are populating the assembly state, add the location for this
      // reference.
      if (state.asmState)
        referenceLocations.push_back(getToken().getLocRange());

      // Note: this local shadows the outer `nameStr` holding the root
      // reference name.
      std::string nameStr = getToken().getSymbolReference();
      consumeToken(Token::at_identifier);
      nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
    }
    SymbolRefAttr symbolRefAttr =
        SymbolRefAttr::get(getContext(), nameStr, nestedRefs);

    // If we are populating the assembly state, record this symbol reference.
    if (state.asmState)
      state.asmState->addUses(symbolRefAttr, referenceLocations);
    return symbolRefAttr;
  }

  // Parse a 'unit' attribute.
  case Token::kw_unit:
    consumeToken(Token::kw_unit);
    return builder.getUnitAttr();

  // Handle completion of an attribute. A completion token that stands in for
  // a hash identifier is routed to the extended attribute parser.
  case Token::code_complete:
    if (getToken().isCodeCompletionFor(Token::hash_identifier))
      return parseExtendedAttr(type);
    return codeCompleteAttribute();

  default:
    // Parse a type attribute. We parse `Optional` here to allow for providing a
    // better error message.
    // Note: this local shadows the `type` parameter, which is not consulted on
    // this path.
    Type type;
    OptionalParseResult result = parseOptionalType(type);
    if (!result.hasValue())
      return emitWrongTokenError("expected attribute value"), Attribute();
    return failed(*result) ? Attribute() : TypeAttr::get(type);
  }
}
232 
233 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)234 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
235                                                    Type type) {
236   switch (getToken().getKind()) {
237   case Token::at_identifier:
238   case Token::floatliteral:
239   case Token::integer:
240   case Token::hash_identifier:
241   case Token::kw_affine_map:
242   case Token::kw_affine_set:
243   case Token::kw_dense:
244   case Token::kw_false:
245   case Token::kw_loc:
246   case Token::kw_opaque:
247   case Token::kw_sparse:
248   case Token::kw_true:
249   case Token::kw_unit:
250   case Token::l_brace:
251   case Token::l_square:
252   case Token::minus:
253   case Token::string:
254     attribute = parseAttribute(type);
255     return success(attribute != nullptr);
256 
257   default:
258     // Parse an optional type attribute.
259     Type type;
260     OptionalParseResult result = parseOptionalType(type);
261     if (result.hasValue() && succeeded(*result))
262       attribute = TypeAttr::get(type);
263     return result;
264   }
265 }
parseOptionalAttribute(ArrayAttr & attribute,Type type)266 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
267                                                    Type type) {
268   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
269 }
parseOptionalAttribute(StringAttr & attribute,Type type)270 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
271                                                    Type type) {
272   return parseOptionalAttributeWithToken(Token::string, attribute, type);
273 }
274 
275 /// Attribute dictionary.
276 ///
277 ///   attribute-dict ::= `{` `}`
278 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
279 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
280 ///
parseAttributeDict(NamedAttrList & attributes)281 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
282   llvm::SmallDenseSet<StringAttr> seenKeys;
283   auto parseElt = [&]() -> ParseResult {
284     // The name of an attribute can either be a bare identifier, or a string.
285     Optional<StringAttr> nameId;
286     if (getToken().is(Token::string))
287       nameId = builder.getStringAttr(getToken().getStringValue());
288     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
289              getToken().isKeyword())
290       nameId = builder.getStringAttr(getTokenSpelling());
291     else
292       return emitWrongTokenError("expected attribute name");
293 
294     if (nameId->size() == 0)
295       return emitError("expected valid attribute name");
296 
297     if (!seenKeys.insert(*nameId).second)
298       return emitError("duplicate key '")
299              << nameId->getValue() << "' in dictionary attribute";
300     consumeToken();
301 
302     // Lazy load a dialect in the context if there is a possible namespace.
303     auto splitName = nameId->strref().split('.');
304     if (!splitName.second.empty())
305       getContext()->getOrLoadDialect(splitName.first);
306 
307     // Try to parse the '=' for the attribute value.
308     if (!consumeIf(Token::equal)) {
309       // If there is no '=', we treat this as a unit attribute.
310       attributes.push_back({*nameId, builder.getUnitAttr()});
311       return success();
312     }
313 
314     auto attr = parseAttribute();
315     if (!attr)
316       return failure();
317     attributes.push_back({*nameId, attr});
318     return success();
319   };
320 
321   return parseCommaSeparatedList(Delimiter::Braces, parseElt,
322                                  " in attribute dictionary");
323 }
324 
325 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)326 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
327   auto val = getToken().getFloatingPointValue();
328   if (!val)
329     return (emitError("floating point value too large for attribute"), nullptr);
330   consumeToken(Token::floatliteral);
331   if (!type) {
332     // Default to F64 when no type is specified.
333     if (!consumeIf(Token::colon))
334       type = builder.getF64Type();
335     else if (!(type = parseType()))
336       return nullptr;
337   }
338   if (!type.isa<FloatType>())
339     return (emitError("floating point value not valid for specified type"),
340             nullptr);
341   return FloatAttr::get(type, isNegative ? -*val : *val);
342 }
343 
344 /// Construct an APint from a parsed value, a known attribute type and
345 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)346 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
347                                            StringRef spelling) {
348   // Parse the integer value into an APInt that is big enough to hold the value.
349   APInt result;
350   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
351   if (spelling.getAsInteger(isHex ? 0 : 10, result))
352     return llvm::None;
353 
354   // Extend or truncate the bitwidth to the right size.
355   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
356                                   : type.getIntOrFloatBitWidth();
357 
358   if (width > result.getBitWidth()) {
359     result = result.zext(width);
360   } else if (width < result.getBitWidth()) {
361     // The parser can return an unnecessarily wide result with leading zeros.
362     // This isn't a problem, but truncating off bits is bad.
363     if (result.countLeadingZeros() < result.getBitWidth() - width)
364       return llvm::None;
365 
366     result = result.trunc(width);
367   }
368 
369   if (width == 0) {
370     // 0 bit integers cannot be negative and manipulation of their sign bit will
371     // assert, so short-cut validation here.
372     if (isNegative)
373       return llvm::None;
374   } else if (isNegative) {
375     // The value is negative, we have an overflow if the sign bit is not set
376     // in the negated apInt.
377     result.negate();
378     if (!result.isSignBitSet())
379       return llvm::None;
380   } else if ((type.isSignedInteger() || type.isIndex()) &&
381              result.isSignBitSet()) {
382     // The value is a positive signed integer or index,
383     // we have an overflow if the sign bit is set.
384     return llvm::None;
385   }
386 
387   return result;
388 }
389 
390 /// Parse a decimal or a hexadecimal literal, which can be either an integer
391 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)392 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
393   Token tok = getToken();
394   StringRef spelling = tok.getSpelling();
395   SMLoc loc = tok.getLoc();
396 
397   consumeToken(Token::integer);
398   if (!type) {
399     // Default to i64 if not type is specified.
400     if (!consumeIf(Token::colon))
401       type = builder.getIntegerType(64);
402     else if (!(type = parseType()))
403       return nullptr;
404   }
405 
406   if (auto floatType = type.dyn_cast<FloatType>()) {
407     Optional<APFloat> result;
408     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
409                                             floatType.getFloatSemantics(),
410                                             floatType.getWidth())))
411       return Attribute();
412     return FloatAttr::get(floatType, *result);
413   }
414 
415   if (!type.isa<IntegerType, IndexType>())
416     return emitError(loc, "integer literal not valid for specified type"),
417            nullptr;
418 
419   if (isNegative && type.isUnsignedInteger()) {
420     emitError(loc,
421               "negative integer literal not valid for unsigned integer type");
422     return nullptr;
423   }
424 
425   Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
426   if (!apInt)
427     return emitError(loc, "integer constant out of range for attribute"),
428            nullptr;
429   return builder.getIntegerAttr(type, *apInt);
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // TensorLiteralParser
434 //===----------------------------------------------------------------------===//
435 
436 /// Parse elements values stored within a hex string. On success, the values are
437 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)438 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
439                                              std::string &result) {
440   if (Optional<std::string> value = tok.getHexStringValue()) {
441     result = std::move(*value);
442     return success();
443   }
444   return parser.emitError(
445       tok.getLoc(), "expected string containing hex digits starting with `0x`");
446 }
447 
namespace {
/// This class implements a parser for TensorLiterals. A tensor literal is
/// either a single element (e.g, 5) or a multi-dimensional list of elements
/// (e.g., [[5, 5]]).
///
/// Usage is two-phase: `parse` records the literal's tokens and inferred
/// shape, and `getAttr` later turns them into a DenseElementsAttr once the
/// shaped type is known.
class TensorLiteralParser {
public:
  /// Create a literal parser that reads tokens from, and reports errors
  /// through, the given parser.
  TensorLiteralParser(Parser &p) : p(p) {}

  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
  /// may also parse a tensor literal that is stored as a hex string.
  ParseResult parse(bool allowHex);

  /// Build a dense attribute instance with the parsed elements and the given
  /// shaped type.
  DenseElementsAttr getAttr(SMLoc loc, ShapedType type);

  /// Return the shape inferred from the parsed elements.
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  /// Get the parsed elements for an integer attribute.
  ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
                                 std::vector<APInt> &intValues);

  /// Get the parsed elements for a float attribute.
  ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
                                   std::vector<APFloat> &floatValues);

  /// Build a Dense String attribute for the given type.
  DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);

  /// Build a Dense attribute with hex data for the given type.
  DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);

  /// Parse a single element, returning failure if it isn't a valid element
  /// literal. For example:
  /// parseElement(1) -> Success, 1
  /// parseElement([1]) -> Failure
  ParseResult parseElement();

  /// Parse a list of either lists or elements, returning the dimensions of the
  /// parsed sub-tensors in dims. For example:
  ///   parseList([1, 2, 3]) -> Success, [3]
  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  ///   parseList([[1, 2], 3]) -> Failure
  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  ParseResult parseList(SmallVectorImpl<int64_t> &dims);

  /// Parse a literal that was printed as a hex string.
  ParseResult parseHexElements();

  /// The parser used to read tokens and emit diagnostics.
  Parser &p;

  /// The shape inferred from the parsed elements.
  SmallVector<int64_t, 4> shape;

  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
  std::vector<std::pair<bool, Token>> storage;

  /// Storage used when parsing elements that were stored as hex values.
  Optional<Token> hexStorage;
};
} // namespace
510 
511 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
512 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)513 ParseResult TensorLiteralParser::parse(bool allowHex) {
514   // If hex is allowed, check for a string literal.
515   if (allowHex && p.getToken().is(Token::string)) {
516     hexStorage = p.getToken();
517     p.consumeToken(Token::string);
518     return success();
519   }
520   // Otherwise, parse a list or an individual element.
521   if (p.getToken().is(Token::l_square))
522     return parseList(shape);
523   return parseElement();
524 }
525 
526 /// Build a dense attribute instance with the parsed elements and the given
527 /// shaped type.
getAttr(SMLoc loc,ShapedType type)528 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
529   Type eltType = type.getElementType();
530 
531   // Check to see if we parse the literal from a hex string.
532   if (hexStorage &&
533       (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
534     return getHexAttr(loc, type);
535 
536   // Check that the parsed storage size has the same number of elements to the
537   // type, or is a known splat.
538   if (!shape.empty() && getShape() != type.getShape()) {
539     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
540                      << "]) does not match type ([" << type.getShape() << "])";
541     return nullptr;
542   }
543 
544   // Handle the case where no elements were parsed.
545   if (!hexStorage && storage.empty() && type.getNumElements()) {
546     p.emitError(loc) << "parsed zero elements, but type (" << type
547                      << ") expected at least 1";
548     return nullptr;
549   }
550 
551   // Handle complex types in the specific element type cases below.
552   bool isComplex = false;
553   if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
554     eltType = complexTy.getElementType();
555     isComplex = true;
556   }
557 
558   // Handle integer and index types.
559   if (eltType.isIntOrIndex()) {
560     std::vector<APInt> intValues;
561     if (failed(getIntAttrElements(loc, eltType, intValues)))
562       return nullptr;
563     if (isComplex) {
564       // If this is a complex, treat the parsed values as complex values.
565       auto complexData = llvm::makeArrayRef(
566           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
567           intValues.size() / 2);
568       return DenseElementsAttr::get(type, complexData);
569     }
570     return DenseElementsAttr::get(type, intValues);
571   }
572   // Handle floating point types.
573   if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
574     std::vector<APFloat> floatValues;
575     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
576       return nullptr;
577     if (isComplex) {
578       // If this is a complex, treat the parsed values as complex values.
579       auto complexData = llvm::makeArrayRef(
580           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
581           floatValues.size() / 2);
582       return DenseElementsAttr::get(type, complexData);
583     }
584     return DenseElementsAttr::get(type, floatValues);
585   }
586 
587   // Other types are assumed to be string representations.
588   return getStringAttr(loc, type, type.getElementType());
589 }
590 
591 /// Build a Dense Integer attribute for the given type.
592 ParseResult
getIntAttrElements(SMLoc loc,Type eltTy,std::vector<APInt> & intValues)593 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
594                                         std::vector<APInt> &intValues) {
595   intValues.reserve(storage.size());
596   bool isUintType = eltTy.isUnsignedInteger();
597   for (const auto &signAndToken : storage) {
598     bool isNegative = signAndToken.first;
599     const Token &token = signAndToken.second;
600     auto tokenLoc = token.getLoc();
601 
602     if (isNegative && isUintType) {
603       return p.emitError(tokenLoc)
604              << "expected unsigned integer elements, but parsed negative value";
605     }
606 
607     // Check to see if floating point values were parsed.
608     if (token.is(Token::floatliteral)) {
609       return p.emitError(tokenLoc)
610              << "expected integer elements, but parsed floating-point";
611     }
612 
613     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
614            "unexpected token type");
615     if (token.isAny(Token::kw_true, Token::kw_false)) {
616       if (!eltTy.isInteger(1)) {
617         return p.emitError(tokenLoc)
618                << "expected i1 type for 'true' or 'false' values";
619       }
620       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
621       intValues.push_back(apInt);
622       continue;
623     }
624 
625     // Create APInt values for each element with the correct bitwidth.
626     Optional<APInt> apInt =
627         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
628     if (!apInt)
629       return p.emitError(tokenLoc, "integer constant out of range for type");
630     intValues.push_back(*apInt);
631   }
632   return success();
633 }
634 
635 /// Build a Dense Float attribute for the given type.
636 ParseResult
getFloatAttrElements(SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)637 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
638                                           std::vector<APFloat> &floatValues) {
639   floatValues.reserve(storage.size());
640   for (const auto &signAndToken : storage) {
641     bool isNegative = signAndToken.first;
642     const Token &token = signAndToken.second;
643 
644     // Handle hexadecimal float literals.
645     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
646       Optional<APFloat> result;
647       if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
648                                                 eltTy.getFloatSemantics(),
649                                                 eltTy.getWidth())))
650         return failure();
651 
652       floatValues.push_back(*result);
653       continue;
654     }
655 
656     // Check to see if any decimal integers or booleans were parsed.
657     if (!token.is(Token::floatliteral))
658       return p.emitError()
659              << "expected floating-point elements, but parsed integer";
660 
661     // Build the float values from tokens.
662     auto val = token.getFloatingPointValue();
663     if (!val)
664       return p.emitError("floating point value too large for attribute");
665 
666     APFloat apVal(isNegative ? -*val : *val);
667     if (!eltTy.isF64()) {
668       bool unused;
669       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
670                     &unused);
671     }
672     floatValues.push_back(apVal);
673   }
674   return success();
675 }
676 
677 /// Build a Dense String attribute for the given type.
getStringAttr(SMLoc loc,ShapedType type,Type eltTy)678 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
679                                                      Type eltTy) {
680   if (hexStorage.has_value()) {
681     auto stringValue = hexStorage.value().getStringValue();
682     return DenseStringElementsAttr::get(type, {stringValue});
683   }
684 
685   std::vector<std::string> stringValues;
686   std::vector<StringRef> stringRefValues;
687   stringValues.reserve(storage.size());
688   stringRefValues.reserve(storage.size());
689 
690   for (auto val : storage) {
691     stringValues.push_back(val.second.getStringValue());
692     stringRefValues.emplace_back(stringValues.back());
693   }
694 
695   return DenseStringElementsAttr::get(type, stringRefValues);
696 }
697 
698 /// Build a Dense attribute with hex data for the given type.
getHexAttr(SMLoc loc,ShapedType type)699 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
700   Type elementType = type.getElementType();
701   if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
702     p.emitError(loc)
703         << "expected floating-point, integer, or complex element type, got "
704         << elementType;
705     return nullptr;
706   }
707 
708   std::string data;
709   if (parseElementAttrHexValues(p, *hexStorage, data))
710     return nullptr;
711 
712   ArrayRef<char> rawData(data.data(), data.size());
713   bool detectedSplat = false;
714   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
715     p.emitError(loc) << "elements hex data size is invalid for provided type: "
716                      << type;
717     return nullptr;
718   }
719 
720   if (llvm::support::endian::system_endianness() ==
721       llvm::support::endianness::big) {
722     // Convert endianess in big-endian(BE) machines. `rawData` is
723     // little-endian(LE) because HEX in raw data of dense element attribute
724     // is always LE format. It is converted into BE here to be used in BE
725     // machines.
726     SmallVector<char, 64> outDataVec(rawData.size());
727     MutableArrayRef<char> convRawData(outDataVec);
728     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
729         rawData, convRawData, type);
730     return DenseElementsAttr::getFromRawBuffer(type, convRawData);
731   }
732 
733   return DenseElementsAttr::getFromRawBuffer(type, rawData);
734 }
735 
parseElement()736 ParseResult TensorLiteralParser::parseElement() {
737   switch (p.getToken().getKind()) {
738   // Parse a boolean element.
739   case Token::kw_true:
740   case Token::kw_false:
741   case Token::floatliteral:
742   case Token::integer:
743     storage.emplace_back(/*isNegative=*/false, p.getToken());
744     p.consumeToken();
745     break;
746 
747   // Parse a signed integer or a negative floating-point element.
748   case Token::minus:
749     p.consumeToken(Token::minus);
750     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
751       return p.emitError("expected integer or floating point literal");
752     storage.emplace_back(/*isNegative=*/true, p.getToken());
753     p.consumeToken();
754     break;
755 
756   case Token::string:
757     storage.emplace_back(/*isNegative=*/false, p.getToken());
758     p.consumeToken();
759     break;
760 
761   // Parse a complex element of the form '(' element ',' element ')'.
762   case Token::l_paren:
763     p.consumeToken(Token::l_paren);
764     if (parseElement() ||
765         p.parseToken(Token::comma, "expected ',' between complex elements") ||
766         parseElement() ||
767         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
768       return failure();
769     break;
770 
771   default:
772     return p.emitError("expected element literal of primitive type");
773   }
774 
775   return success();
776 }
777 
778 /// Parse a list of either lists or elements, returning the dimensions of the
779 /// parsed sub-tensors in dims. For example:
780 ///   parseList([1, 2, 3]) -> Success, [3]
781 ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
782 ///   parseList([[1, 2], 3]) -> Failure
783 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)784 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
785   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
786                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
787     if (prevDims == newDims)
788       return success();
789     return p.emitError("tensor literal is invalid; ranks are not consistent "
790                        "between elements");
791   };
792 
793   bool first = true;
794   SmallVector<int64_t, 4> newDims;
795   unsigned size = 0;
796   auto parseOneElement = [&]() -> ParseResult {
797     SmallVector<int64_t, 4> thisDims;
798     if (p.getToken().getKind() == Token::l_square) {
799       if (parseList(thisDims))
800         return failure();
801     } else if (parseElement()) {
802       return failure();
803     }
804     ++size;
805     if (!first)
806       return checkDims(newDims, thisDims);
807     newDims = thisDims;
808     first = false;
809     return success();
810   };
811   if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
812     return failure();
813 
814   // Return the sublists' dimensions with 'size' prepended.
815   dims.clear();
816   dims.push_back(size);
817   dims.append(newDims.begin(), newDims.end());
818   return success();
819 }
820 
821 //===----------------------------------------------------------------------===//
822 // ElementsAttr Parser
823 //===----------------------------------------------------------------------===//
824 
825 namespace {
826 /// This class provides an implementation of AsmParser, allowing to call back
827 /// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
828 /// in libMLIRIR.
829 class CustomAsmParser : public AsmParserImpl<AsmParser> {
830 public:
CustomAsmParser(Parser & parser)831   CustomAsmParser(Parser &parser)
832       : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
833 };
834 } // namespace
835 
836 /// Parse a dense array attribute.
parseDenseArrayAttr()837 Attribute Parser::parseDenseArrayAttr() {
838   auto typeLoc = getToken().getLoc();
839   auto type = parseType();
840   if (!type)
841     return {};
842   CustomAsmParser parser(*this);
843   Attribute result;
844   // Check for empty list.
845   bool isEmptyList = getToken().is(Token::r_square);
846 
847   if (auto intType = type.dyn_cast<IntegerType>()) {
848     switch (type.getIntOrFloatBitWidth()) {
849     case 8:
850       if (isEmptyList)
851         result = DenseI8ArrayAttr::get(parser.getContext(), {});
852       else
853         result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
854       break;
855     case 16:
856       if (isEmptyList)
857         result = DenseI16ArrayAttr::get(parser.getContext(), {});
858       else
859         result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
860       break;
861     case 32:
862       if (isEmptyList)
863         result = DenseI32ArrayAttr::get(parser.getContext(), {});
864       else
865         result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
866       break;
867     case 64:
868       if (isEmptyList)
869         result = DenseI64ArrayAttr::get(parser.getContext(), {});
870       else
871         result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
872       break;
873     default:
874       emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
875       return {};
876     }
877   } else if (auto floatType = type.dyn_cast<FloatType>()) {
878     switch (type.getIntOrFloatBitWidth()) {
879     case 32:
880       if (isEmptyList)
881         result = DenseF32ArrayAttr::get(parser.getContext(), {});
882       else
883         result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
884       break;
885     case 64:
886       if (isEmptyList)
887         result = DenseF64ArrayAttr::get(parser.getContext(), {});
888       else
889         result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
890       break;
891     default:
892       emitError(typeLoc, "expected f32 or f64 but got: ") << type;
893       return {};
894     }
895   } else {
896     emitError(typeLoc, "expected integer or float type, got: ") << type;
897     return {};
898   }
899   if (!consumeIf(Token::r_square)) {
900     emitError("expected ']' to close an array attribute");
901     return {};
902   }
903   return result;
904 }
905 
906 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)907 Attribute Parser::parseDenseElementsAttr(Type attrType) {
908   auto attribLoc = getToken().getLoc();
909   consumeToken(Token::kw_dense);
910   if (parseToken(Token::less, "expected '<' after 'dense'"))
911     return nullptr;
912 
913   // Parse the literal data if necessary.
914   TensorLiteralParser literalParser(*this);
915   if (!consumeIf(Token::greater)) {
916     if (literalParser.parse(/*allowHex=*/true) ||
917         parseToken(Token::greater, "expected '>'"))
918       return nullptr;
919   }
920 
921   // If the type is specified `parseElementsLiteralType` will not parse a type.
922   // Use the attribute location as the location for error reporting in that
923   // case.
924   auto loc = attrType ? attribLoc : getToken().getLoc();
925   auto type = parseElementsLiteralType(attrType);
926   if (!type)
927     return nullptr;
928   return literalParser.getAttr(loc, type);
929 }
930 
931 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)932 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
933   SMLoc loc = getToken().getLoc();
934   consumeToken(Token::kw_opaque);
935   if (parseToken(Token::less, "expected '<' after 'opaque'"))
936     return nullptr;
937 
938   if (getToken().isNot(Token::string))
939     return (emitError("expected dialect namespace"), nullptr);
940 
941   std::string name = getToken().getStringValue();
942   consumeToken(Token::string);
943 
944   if (parseToken(Token::comma, "expected ','"))
945     return nullptr;
946 
947   Token hexTok = getToken();
948   if (parseToken(Token::string, "elements hex string should start with '0x'") ||
949       parseToken(Token::greater, "expected '>'"))
950     return nullptr;
951   auto type = parseElementsLiteralType(attrType);
952   if (!type)
953     return nullptr;
954 
955   std::string data;
956   if (parseElementAttrHexValues(*this, hexTok, data))
957     return nullptr;
958   return getChecked<OpaqueElementsAttr>(loc, builder.getStringAttr(name), type,
959                                         data);
960 }
961 
962 /// Shaped type for elements attribute.
963 ///
964 ///   elements-literal-type ::= vector-type | ranked-tensor-type
965 ///
966 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)967 ShapedType Parser::parseElementsLiteralType(Type type) {
968   // If the user didn't provide a type, parse the colon type for the literal.
969   if (!type) {
970     if (parseToken(Token::colon, "expected ':'"))
971       return nullptr;
972     if (!(type = parseType()))
973       return nullptr;
974   }
975 
976   if (!type.isa<RankedTensorType, VectorType>()) {
977     emitError("elements literal must be a ranked tensor or vector type");
978     return nullptr;
979   }
980 
981   auto sType = type.cast<ShapedType>();
982   if (!sType.hasStaticShape())
983     return (emitError("elements literal type must have static shape"), nullptr);
984 
985   return sType;
986 }
987 
988 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)989 Attribute Parser::parseSparseElementsAttr(Type attrType) {
990   SMLoc loc = getToken().getLoc();
991   consumeToken(Token::kw_sparse);
992   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
993     return nullptr;
994 
995   // Check for the case where all elements are sparse. The indices are
996   // represented by a 2-dimensional shape where the second dimension is the rank
997   // of the type.
998   Type indiceEltType = builder.getIntegerType(64);
999   if (consumeIf(Token::greater)) {
1000     ShapedType type = parseElementsLiteralType(attrType);
1001     if (!type)
1002       return nullptr;
1003 
1004     // Construct the sparse elements attr using zero element indice/value
1005     // attributes.
1006     ShapedType indicesType =
1007         RankedTensorType::get({0, type.getRank()}, indiceEltType);
1008     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1009     return getChecked<SparseElementsAttr>(
1010         loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1011         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1012   }
1013 
1014   /// Parse the indices. We don't allow hex values here as we may need to use
1015   /// the inferred shape.
1016   auto indicesLoc = getToken().getLoc();
1017   TensorLiteralParser indiceParser(*this);
1018   if (indiceParser.parse(/*allowHex=*/false))
1019     return nullptr;
1020 
1021   if (parseToken(Token::comma, "expected ','"))
1022     return nullptr;
1023 
1024   /// Parse the values.
1025   auto valuesLoc = getToken().getLoc();
1026   TensorLiteralParser valuesParser(*this);
1027   if (valuesParser.parse(/*allowHex=*/true))
1028     return nullptr;
1029 
1030   if (parseToken(Token::greater, "expected '>'"))
1031     return nullptr;
1032 
1033   auto type = parseElementsLiteralType(attrType);
1034   if (!type)
1035     return nullptr;
1036 
1037   // If the indices are a splat, i.e. the literal parser parsed an element and
1038   // not a list, we set the shape explicitly. The indices are represented by a
1039   // 2-dimensional shape where the second dimension is the rank of the type.
1040   // Given that the parsed indices is a splat, we know that we only have one
1041   // indice and thus one for the first dimension.
1042   ShapedType indicesType;
1043   if (indiceParser.getShape().empty()) {
1044     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1045   } else {
1046     // Otherwise, set the shape to the one parsed by the literal parser.
1047     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1048   }
1049   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1050 
1051   // If the values are a splat, set the shape explicitly based on the number of
1052   // indices. The number of indices is encoded in the first dimension of the
1053   // indice shape type.
1054   auto valuesEltType = type.getElementType();
1055   ShapedType valuesType =
1056       valuesParser.getShape().empty()
1057           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1058           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1059   auto values = valuesParser.getAttr(valuesLoc, valuesType);
1060 
1061   // Build the sparse elements attribute by the indices and values.
1062   return getChecked<SparseElementsAttr>(loc, type, indices, values);
1063 }
1064